in opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java [86:214]
public void learn() {
  int currentEpoch = -1;
  int n = 0;
  int p = 0;
  // per-parameter gradient caches, used by both the Adagrad and the RMSprop updates
  INDArray mWxh = Nd4j.zerosLike(wxh);
  INDArray mWxh2 = Nd4j.zerosLike(wxh2);
  INDArray mWhh = Nd4j.zerosLike(whh);
  INDArray mWhh2 = Nd4j.zerosLike(whh2);
  INDArray mWh2y = Nd4j.zerosLike(wh2y);
  INDArray mbh = Nd4j.zerosLike(bh);
  INDArray mbh2 = Nd4j.zerosLike(bh2);
  INDArray mby = Nd4j.zerosLike(by);
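  // an untrained model that assigns uniform probability 1/vocabSize to every character has
  // cross-entropy -log(1 / vocabSize) per character, so the smoothed loss starts at that
  // value times seqLength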
  // loss at iteration 0
  double smoothLoss = -Math.log(1.0 / vocabSize) * seqLength;
  while (true) {
    // prepare inputs (we're sweeping from left to right in steps seqLength long)
    if (p + seqLength + 1 >= data.size() || n == 0) {
      hPrev = Nd4j.zeros(hiddenLayerSize, 1); // reset memory of the first RNN layer
      hPrev2 = Nd4j.zeros(hiddenLayerSize, 1); // reset memory of the second RNN layer
      p = 0; // go from start of data
      currentEpoch++;
      if (currentEpoch == epochs) {
        System.out.println("training finished: e:" + epochs + ", l: " + smoothLoss + ", h:("
            + learningRate + ", " + seqLength + ", " + hiddenLayerSize + ")");
        break;
      }
    }
    INDArray inputs = getSequence(p);
    INDArray targets = getSequence(p + 1);

    // sample from the model now and then
    if (n % 1000 == 0 && n > 0) {
      for (int i = 0; i < 3; i++) {
        String txt = sample(inputs.getInt(0));
        System.out.printf("\n---\n %s \n----\n", txt);
      }
    }
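    // fresh gradient buffers, one per parameter, filled in place by lossFun() below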
    INDArray dWxh = Nd4j.zerosLike(wxh);
    INDArray dWxh2 = Nd4j.zerosLike(wxh2);
    INDArray dWhh = Nd4j.zerosLike(whh);
    INDArray dWhh2 = Nd4j.zerosLike(whh2);
    INDArray dWh2y = Nd4j.zerosLike(wh2y);
    INDArray dbh = Nd4j.zerosLike(bh);
    INDArray dbh2 = Nd4j.zerosLike(bh2);
    INDArray dby = Nd4j.zerosLike(by);
    // forward seqLength characters through the net and fetch gradient
    double loss = lossFun(inputs, targets, dWxh, dWhh, dWxh2, dWhh2, dWh2y, dbh, dbh2, dby);

    // exponential moving average of the loss, for smoother progress reporting
    double newLoss = smoothLoss * 0.999 + loss * 0.001;
    if (newLoss > smoothLoss) {
      // anneal the learning rate slightly whenever the smoothed loss stops improving
      learningRate *= 0.999;
    }
    smoothLoss = newLoss;
    if (Double.isNaN(smoothLoss) || Double.isInfinite(smoothLoss)) {
      System.out.println("loss is " + smoothLoss + " (" + loss + ") (over/underflow occurred, try adjusting hyperparameters)");
      break;
    }
    if (n % 100 == 0) {
      System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress
    }
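    // parameters are updated every `batch` iterations; both optimizers apply the same
    // element-wise adaptive step
    //   param -= learningRate * grad / (sqrt(cache) + eps)
    // and differ only in how the cache of squared gradients is maintained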
    if (n % batch == 0) {
      if (rmsProp) {
        // perform parameter update with RMSprop
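        // RMSprop keeps an exponentially decaying average of squared gradients:
        //   cache = decay * cache + (1 - decay) * grad^2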
        mWxh = mWxh.mul(decay).add(dWxh.mul(dWxh).mul(1 - decay));
        wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh).add(eps)));
        mWxh2 = mWxh2.mul(decay).add(dWxh2.mul(dWxh2).mul(1 - decay));
        wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2).add(eps)));
        mWhh = mWhh.mul(decay).add(dWhh.mul(dWhh).mul(1 - decay));
        whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh).add(eps)));
        mWhh2 = mWhh2.mul(decay).add(dWhh2.mul(dWhh2).mul(1 - decay));
        whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2).add(eps)));
        mbh2 = mbh2.mul(decay).add(dbh2.mul(dbh2).mul(1 - decay));
        bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2).add(eps)));
        mWh2y = mWh2y.mul(decay).add(dWh2y.mul(dWh2y).mul(1 - decay));
        wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y).add(eps)));
        mbh = mbh.mul(decay).add(dbh.mul(dbh).mul(1 - decay));
        bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh).add(eps)));
        mby = mby.mul(decay).add(dby.mul(dby).mul(1 - decay));
        by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(eps)));
      } else {
        // perform parameter update with Adagrad
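        // Adagrad accumulates the sum of all past squared gradients:
        //   cache += grad^2
        // so the effective per-parameter step size only ever shrinks over time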
        mWxh.addi(dWxh.mul(dWxh));
        wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh).add(eps)));
        mWxh2.addi(dWxh2.mul(dWxh2));
        wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2).add(eps)));
        mWhh.addi(dWhh.mul(dWhh));
        whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh).add(eps)));
        mWhh2.addi(dWhh2.mul(dWhh2));
        whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2).add(eps)));
        mbh2.addi(dbh2.mul(dbh2));
        bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2).add(eps)));
        mWh2y.addi(dWh2y.mul(dWh2y));
        wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y).add(eps)));
        mbh.addi(dbh.mul(dbh));
        bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh).add(eps)));
        mby.addi(dby.mul(dby));
        by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(eps)));
      }
    }
    p += seqLength; // move data pointer
    n++; // iteration counter
  }
}