public void learn()

in opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java [86:214]
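This method implements the training loop for the stacked (two hidden layer) character-level RNN: it sweeps the input data in windows of seqLength characters, computes the loss and gradients with lossFun, and updates every weight and bias matrix with either RMSprop or Adagrad, periodically printing sampled text and a smoothed loss.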


  public void learn() {

    int currentEpoch = -1;

    int n = 0; // iteration counter
    int p = 0; // position (data pointer) into the training data

    // per-parameter gradient caches (Adagrad accumulators, also reused as RMSprop running averages)
    INDArray mWxh = Nd4j.zerosLike(wxh);
    INDArray mWxh2 = Nd4j.zerosLike(wxh2);
    INDArray mWhh = Nd4j.zerosLike(whh);
    INDArray mWhh2 = Nd4j.zerosLike(whh2);
    INDArray mWh2y = Nd4j.zerosLike(wh2y);

    INDArray mbh = Nd4j.zerosLike(bh);
    INDArray mbh2 = Nd4j.zerosLike(bh2);
    INDArray mby = Nd4j.zerosLike(by);

    // smoothed loss at iteration 0: cross-entropy of a uniform prediction, -log(1/vocabSize), over seqLength characters
    double smoothLoss = -Math.log(1.0 / vocabSize) * seqLength;

    while (true) {
      // prepare inputs (we're sweeping from left to right in steps seqLength long)
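      // reset the hidden state and wrap back to the start of the data when fewer than seqLength + 1 characters remain (also true on the very first iteration, n == 0)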
      if (p + seqLength + 1 >= data.size() || n == 0) {
        hPrev = Nd4j.zeros(hiddenLayerSize, 1); // reset first hidden layer memory
        hPrev2 = Nd4j.zeros(hiddenLayerSize, 1); // reset second hidden layer memory
        p = 0; // go from start of data
        currentEpoch++;
        if (currentEpoch == epochs) {
          System.out.println("training finished: e:" + epochs + ", l: " + smoothLoss + ", h:(" + learningRate + ", " + seqLength + ", " + hiddenLayerSize + ")");
          break;
        }
      }

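      // inputs holds the next seqLength character indices; targets is the same window shifted one position ahead (next-character prediction)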
      INDArray inputs = getSequence(p);
      INDArray targets = getSequence(p + 1);

      // sample from the model now and then
      if (n % 1000 == 0 && n > 0) {
        for (int i = 0; i < 3; i++) {
          String txt = sample(inputs.getInt(0));
          System.out.printf("\n---\n %s \n----\n", txt);
        }
      }

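      // gradient buffers, one per parameter, zeroed for every sequence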
      INDArray dWxh = Nd4j.zerosLike(wxh);
      INDArray dWxh2 = Nd4j.zerosLike(wxh2);
      INDArray dWhh = Nd4j.zerosLike(whh);
      INDArray dWhh2 = Nd4j.zerosLike(whh2);
      INDArray dWh2y = Nd4j.zerosLike(wh2y);

      INDArray dbh = Nd4j.zerosLike(bh);
      INDArray dbh2 = Nd4j.zerosLike(bh2);
      INDArray dby = Nd4j.zerosLike(by);

      // forward seqLength characters through the net and fetch gradient
      double loss = lossFun(inputs, targets, dWxh, dWhh, dWxh2, dWhh2, dWh2y, dbh, dbh2, dby);
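      // exponential moving average of the loss, for smoother progress reporting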
      double newLoss = smoothLoss * 0.999 + loss * 0.001;

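      // crude schedule: decay the learning rate slightly whenever the smoothed loss increases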
      if (newLoss > smoothLoss) {
        learningRate *= 0.999f;
      }
      smoothLoss = newLoss;
      if (Double.isNaN(smoothLoss) || Double.isInfinite(smoothLoss)) {
        System.out.println("loss is " + smoothLoss + "(" + loss + ") (over/underflow occurred, try adjusting hyperparameters)");
        break;
      }
      if (n % 100 == 0) {
        System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress
      }

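      // apply a parameter update only every 'batch' iterations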
      if (n % batch == 0) {
        if (rmsProp) {
          // perform parameter update with RMSprop
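          // cache = decay * cache + (1 - decay) * grad^2 ;  param -= learningRate * grad / (sqrt(cache) + eps)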
          mWxh = mWxh.mul(decay).add(dWxh.mul(dWxh).mul(1 - decay));
          wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh).add(eps)));

          mWxh2 = mWxh2.mul(decay).add(dWxh2.mul(dWxh2).mul(1 - decay));
          wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2).add(eps)));

          mWhh = mWhh.mul(decay).add(dWhh.mul(dWhh).mul(1 - decay));
          whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh).add(eps)));

          mWhh2 = mWhh2.mul(decay).add(dWhh2.mul(dWhh2).mul(1 - decay));
          whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2).add(eps)));

          mbh2 = mbh2.mul(decay).add(dbh2.mul(dbh2).mul(1 - decay));
          bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2).add(eps)));

          mWh2y = mWh2y.mul(decay).add(dWh2y.mul(dWh2y).mul(1 - decay));
          wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y).add(eps)));

          mbh = mbh.mul(decay).add(dbh.mul(dbh).mul(1 - decay));
          bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh).add(eps)));

          mby = mby.mul(decay).add(dby.mul(dby).mul(1 - decay));
          by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(eps)));
        } else {
          // perform parameter update with Adagrad
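          // cache += grad^2 ;  param -= learningRate * grad / (sqrt(cache) + eps)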
          mWxh.addi(dWxh.mul(dWxh));
          wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh).add(eps)));

          mWxh2.addi(dWxh2.mul(dWxh2));
          wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2).add(eps)));

          mWhh.addi(dWhh.mul(dWhh));
          whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh).add(eps)));

          mWhh2.addi(dWhh2.mul(dWhh2));
          whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2).add(eps)));

          mbh2.addi(dbh2.mul(dbh2));
          bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2).add(eps)));

          mWh2y.addi(dWh2y.mul(dWh2y));
          wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y).add(eps)));

          mbh.addi(dbh.mul(dbh));
          bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh).add(eps)));

          mby.addi(dby.mul(dby));
          by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(eps)));
        }
      }

      p += seqLength; // move data pointer
      n++; // iteration counter
    }
  }