async train()

in addition-rnn/index.js [257:342]


  async train(iterations, batchSize, numTestExamples) {
    const lossValues = [[], []];
    const accuracyValues = [[], []];
    for (let i = 0; i < iterations; ++i) {
      const beginMs = performance.now();
      const history = await this.model.fit(this.trainXs, this.trainYs, {
        epochs: 1,
        batchSize,
        validationData: [this.testXs, this.testYs],
        yieldEvery: 'epoch'
      });

      const elapsedMs = performance.now() - beginMs;
      const modelFitTime = elapsedMs / 1000;

      const trainLoss = history.history['loss'][0];
      const trainAccuracy = history.history['acc'][0];
      const valLoss = history.history['val_loss'][0];
      const valAccuracy = history.history['val_acc'][0];

      lossValues[0].push({'x': i, 'y': trainLoss});
      lossValues[1].push({'x': i, 'y': valLoss});

      accuracyValues[0].push({'x': i, 'y': trainAccuracy});
      accuracyValues[1].push({'x': i, 'y': valAccuracy});

      document.getElementById('trainStatus').textContent =
          `Iteration ${i + 1} of ${iterations}: ` +
          `Time per iteration: ${modelFitTime.toFixed(3)} (seconds)`;
      const lossContainer = document.getElementById('lossChart');
      tfvis.render.linechart(
          lossContainer, {values: lossValues, series: ['train', 'validation']},
          {
            width: 420,
            height: 300,
            xLabel: 'epoch',
            yLabel: 'loss',
          });

      const accuracyContainer = document.getElementById('accuracyChart');
      tfvis.render.linechart(
          accuracyContainer,
          {values: accuracyValues, series: ['train', 'validation']}, {
            width: 420,
            height: 300,
            xLabel: 'epoch',
            yLabel: 'accuracy',
          });

      if (this.testXsForDisplay == null ||
          this.testXsForDisplay.shape[0] !== numTestExamples) {
        if (this.textXsForDisplay) {
          this.textXsForDisplay.dispose();
        }
        this.testXsForDisplay = this.testXs.slice(
            [0, 0, 0],
            [numTestExamples, this.testXs.shape[1], this.testXs.shape[2]]);
      }

      const examples = [];
      const isCorrect = [];
      tf.tidy(() => {
        const predictOut = this.model.predict(this.testXsForDisplay);
        for (let k = 0; k < numTestExamples; ++k) {
          const scores =
              predictOut
                  .slice(
                      [k, 0, 0], [1, predictOut.shape[1], predictOut.shape[2]])
                  .as2D(predictOut.shape[1], predictOut.shape[2]);
          const decoded = this.charTable.decode(scores);
          examples.push(this.testData[k][0] + ' = ' + decoded);
          isCorrect.push(this.testData[k][1].trim() === decoded.trim());
        }
      });

      const examplesDiv = document.getElementById('testExamples');
      const examplesContent = examples.map(
          (example, i) =>
              `<div class="${
                  isCorrect[i] ? 'answer-correct' : 'answer-wrong'}">` +
              `${example}` +
              `</div>`);

      examplesDiv.innerHTML = examplesContent.join('\n');
    }
  }