in translation/translation.ts [353:482]
async function main () {
  // Pick the tfjs-node binding: the GPU build requires CUDA, the CPU
  // build runs anywhere. The returned module is also needed later for
  // the TensorBoard callback.
  let tfn;
  if (args.gpu) {
    console.log('Using GPU');
    tfn = require('@tensorflow/tfjs-node-gpu');
  } else {
    console.log('Using CPU');
    tfn = require('@tensorflow/tfjs-node');
  }

  // Load and vectorize the parallel corpus. The tensors are one-hot
  // encoded per character (see readData for the exact shapes —
  // presumably [numSamples, maxSeqLength, numTokens]; confirm there).
  const {
    inputTexts,
    maxDecoderSeqLength,
    numEncoderTokens,
    numDecoderTokens,
    targetTokenIndex,
    encoderInputData,
    decoderInputData,
    decoderTargetData,
  } = await readData(args.data_path);

  // Build the training-time seq2seq graph. The returned layer/tensor
  // handles (encoderInputs, decoderLstm, ...) are reused below to wire
  // up the separate inference-time sampling models.
  const {
    encoderInputs,
    encoderStates,
    decoderInputs,
    decoderLstm,
    decoderDense,
    model,
  } = seq2seqModel(numEncoderTokens, numDecoderTokens, args.latent_dim);

  // Run training.
  model.compile({
    optimizer: 'rmsprop',
    loss: 'categoricalCrossentropy',
  });
  model.summary();

  if (args.logDir != null) {
    console.log(
        `To view logs in tensorboard, do:\n` +
        `  tensorboard --logdir ${args.logDir}\n`);
  }

  await model.fit(
      [encoderInputData, decoderInputData], decoderTargetData, {
        batchSize: args.batch_size,
        epochs: args.epochs,
        validationSplit: 0.2,
        // Only attach the TensorBoard callback when a log dir was given.
        callbacks: args.logDir == null ? null :
            tfn.node.tensorBoard(args.logDir, {
              updateFreq: args.logUpdateFreq
            })
      }
  );

  await model.save(`file://${args.artifacts_dir}`);
  // tfjs.converters.save_keras_model(model, FLAGS.artifacts_dir)

  // Next: inference mode (sampling).
  // Here's the drill:
  // 1) encode input and retrieve initial decoder state
  // 2) run one step of decoder with this initial state
  //    and a "start of sequence" token as target.
  //    Output will be the next target token
  // 3) Repeat with the current target token and current states

  // Define sampling models.
  // Encoder: maps an input sequence to the LSTM's final [h, c] states.
  const encoderModel = tf.model({
    inputs: encoderInputs,
    outputs: encoderStates,
    name: 'encoderModel',
  });

  // Decoder: single-step model fed with the previous token plus the
  // previous [h, c] states; emits token probabilities and new states.
  const decoderStateInputH = tf.layers.input({
    shape: [args.latent_dim],
    name: 'decoderStateInputHidden',
  });
  // BUGFIX: shape must be a number[] (Shape), not a bare number —
  // now consistent with decoderStateInputH above.
  const decoderStateInputC = tf.layers.input({
    shape: [args.latent_dim],
    name: 'decoderStateInputCell',
  });
  const decoderStatesInputs = [decoderStateInputH, decoderStateInputC];
  let [decoderOutputs, stateH, stateC] = decoderLstm.apply(
      [decoderInputs, ...decoderStatesInputs]
  ) as tf.SymbolicTensor[];
  const decoderStates = [stateH, stateC];
  decoderOutputs = decoderDense.apply(decoderOutputs) as tf.SymbolicTensor;
  const decoderModel = tf.model({
    inputs: [decoderInputs, ...decoderStatesInputs],
    outputs: [decoderOutputs, ...decoderStates],
    name: 'decoderModel',
  });

  // Reverse-lookup token index to decode sequences back to
  // something readable.
  const reverseTargetCharIndex =
      invertKv(targetTokenIndex) as {[indice: number]: string};

  // '\t' is the start-of-sequence marker used during data preparation.
  const targetBeginIndex = targetTokenIndex['\t'];

  // Decode a few training-set sentences to eyeball translation quality.
  for (let seqIndex = 0; seqIndex < args.num_test_sentences; seqIndex++) {
    // Take one sequence (part of the training set)
    // for trying out decoding.
    const inputSeq = encoderInputData.slice(seqIndex, 1);

    // Get expected output: one-hot -> argmax indices -> characters.
    const targetSeqVoc =
        decoderTargetData.slice(seqIndex, 1).squeeze([0]) as tf.Tensor2D;
    const targetSeqTensor = targetSeqVoc.argMax(-1) as tf.Tensor1D;
    const targetSeqList = await targetSeqTensor.array();
    // One-hot to index
    const targetSeq =
        targetSeqList.map(indice => reverseTargetCharIndex[indice]);
    // Array to string (strip the end-of-sequence newline).
    const targetSeqStr = targetSeq.join('').replace('\n', '');

    const decodedSentence = await decodeSequence(
        inputSeq, encoderModel, decoderModel, numDecoderTokens,
        targetBeginIndex, reverseTargetCharIndex, maxDecoderSeqLength,
    );

    console.log('-');
    console.log('Input sentence:', inputTexts[seqIndex]);
    console.log('Target sentence:', targetSeqStr);
    console.log('Decoded sentence:', decodedSentence);
  }
}