in addition-rnn/index.js [257:342]
async train(iterations, batchSize, numTestExamples) {
const lossValues = [[], []];
const accuracyValues = [[], []];
for (let i = 0; i < iterations; ++i) {
const beginMs = performance.now();
const history = await this.model.fit(this.trainXs, this.trainYs, {
epochs: 1,
batchSize,
validationData: [this.testXs, this.testYs],
yieldEvery: 'epoch'
});
const elapsedMs = performance.now() - beginMs;
const modelFitTime = elapsedMs / 1000;
const trainLoss = history.history['loss'][0];
const trainAccuracy = history.history['acc'][0];
const valLoss = history.history['val_loss'][0];
const valAccuracy = history.history['val_acc'][0];
lossValues[0].push({'x': i, 'y': trainLoss});
lossValues[1].push({'x': i, 'y': valLoss});
accuracyValues[0].push({'x': i, 'y': trainAccuracy});
accuracyValues[1].push({'x': i, 'y': valAccuracy});
document.getElementById('trainStatus').textContent =
`Iteration ${i + 1} of ${iterations}: ` +
`Time per iteration: ${modelFitTime.toFixed(3)} (seconds)`;
const lossContainer = document.getElementById('lossChart');
tfvis.render.linechart(
lossContainer, {values: lossValues, series: ['train', 'validation']},
{
width: 420,
height: 300,
xLabel: 'epoch',
yLabel: 'loss',
});
const accuracyContainer = document.getElementById('accuracyChart');
tfvis.render.linechart(
accuracyContainer,
{values: accuracyValues, series: ['train', 'validation']}, {
width: 420,
height: 300,
xLabel: 'epoch',
yLabel: 'accuracy',
});
if (this.testXsForDisplay == null ||
this.testXsForDisplay.shape[0] !== numTestExamples) {
if (this.textXsForDisplay) {
this.textXsForDisplay.dispose();
}
this.testXsForDisplay = this.testXs.slice(
[0, 0, 0],
[numTestExamples, this.testXs.shape[1], this.testXs.shape[2]]);
}
const examples = [];
const isCorrect = [];
tf.tidy(() => {
const predictOut = this.model.predict(this.testXsForDisplay);
for (let k = 0; k < numTestExamples; ++k) {
const scores =
predictOut
.slice(
[k, 0, 0], [1, predictOut.shape[1], predictOut.shape[2]])
.as2D(predictOut.shape[1], predictOut.shape[2]);
const decoded = this.charTable.decode(scores);
examples.push(this.testData[k][0] + ' = ' + decoded);
isCorrect.push(this.testData[k][1].trim() === decoded.trim());
}
});
const examplesDiv = document.getElementById('testExamples');
const examplesContent = examples.map(
(example, i) =>
`<div class="${
isCorrect[i] ? 'answer-correct' : 'answer-wrong'}">` +
`${example}` +
`</div>`);
examplesDiv.innerHTML = examplesContent.join('\n');
}
}