async function main()

in translation/translation.ts [353:482]


/**
 * Entry point: trains a character-level seq2seq translation model, saves it,
 * then builds separate encoder/decoder sampling models and greedily decodes a
 * few training sentences to stdout.
 *
 * Reads configuration from the module-level `args` (data_path, gpu,
 * latent_dim, batch_size, epochs, logDir, logUpdateFreq, artifacts_dir,
 * num_test_sentences). Returns a promise that resolves when decoding of the
 * sample sentences is complete.
 */
async function main () {
  // Select the native TF.js backend at runtime so the CPU package is not
  // loaded when the GPU one is requested (and vice versa).
  let tfn;
  if (args.gpu) {
    console.log('Using GPU');
    tfn = require('@tensorflow/tfjs-node-gpu');
  } else {
    console.log('Using CPU');
    tfn = require('@tensorflow/tfjs-node');
  }

  const {
    inputTexts,
    maxDecoderSeqLength,
    numEncoderTokens,
    numDecoderTokens,
    targetTokenIndex,
    encoderInputData,
    decoderInputData,
    decoderTargetData,
  } = await readData(args.data_path);

  const {
    encoderInputs,
    encoderStates,
    decoderInputs,
    decoderLstm,
    decoderDense,
    model,
  } = seq2seqModel(numEncoderTokens, numDecoderTokens, args.latent_dim);

  // Run training.
  model.compile({
    optimizer: 'rmsprop',
    loss: 'categoricalCrossentropy',
  });
  model.summary();

  if (args.logDir != null) {
    console.log(
      `To view logs in tensorboard, do:\n` +
      `  tensorboard --logdir ${args.logDir}\n`);
  }

  await model.fit(
    [encoderInputData, decoderInputData], decoderTargetData, {
      batchSize: args.batch_size,
      epochs: args.epochs,
      validationSplit: 0.2,
      callbacks: args.logDir == null ? null :
          tfn.node.tensorBoard(args.logDir, {
            updateFreq: args.logUpdateFreq
          })
    }
  );

  await model.save(`file://${args.artifacts_dir}`);

  // tfjs.converters.save_keras_model(model, FLAGS.artifacts_dir)

  // Next: inference mode (sampling).
  // Here's the drill:
  // 1) encode input and retrieve initial decoder state
  // 2) run one step of decoder with this initial state
  // and a "start of sequence" token as target.
  // Output will be the next target token
  // 3) Repeat with the current target token and current states

  // Define sampling models
  const encoderModel = tf.model({
    inputs: encoderInputs,
    outputs: encoderStates,
    name: 'encoderModel',
  });

  const decoderStateInputH = tf.layers.input({
    shape: [args.latent_dim],
    name: 'decoderStateInputHidden',
  });
  // BUG FIX: `shape` must be a number[] (tfjs `Shape`); a bare number is a
  // type error and was inconsistent with decoderStateInputH above.
  const decoderStateInputC = tf.layers.input({
    shape: [args.latent_dim],
    name: 'decoderStateInputCell',
  });
  const decoderStatesInputs = [decoderStateInputH, decoderStateInputC];
  // Re-apply the trained decoder LSTM, now fed with explicit state inputs so
  // decoding can proceed one step at a time.
  let [decoderOutputs, stateH, stateC] = decoderLstm.apply(
      [decoderInputs, ...decoderStatesInputs]
  ) as tf.SymbolicTensor[];

  const decoderStates = [stateH, stateC];
  decoderOutputs = decoderDense.apply(decoderOutputs) as tf.SymbolicTensor;
  const decoderModel = tf.model({
    inputs: [decoderInputs, ...decoderStatesInputs],
    outputs: [decoderOutputs, ...decoderStates],
    name: 'decoderModel',
  });

  // Reverse-lookup token index to decode sequences back to
  // something readable.
  const reverseTargetCharIndex =
      invertKv(targetTokenIndex) as {[indice: number]: string};

  // '\t' is the start-of-sequence marker in the target texts.
  const targetBeginIndex = targetTokenIndex['\t'];

  for (let seqIndex = 0; seqIndex < args.num_test_sentences; seqIndex++) {
    // Take one sequence (part of the training set)
    // for trying out decoding.
    const inputSeq = encoderInputData.slice(seqIndex, 1);

    // Get expected output
    const targetSeqVoc =
        decoderTargetData.slice(seqIndex, 1).squeeze([0]) as tf.Tensor2D;
    const targetSeqTensor = targetSeqVoc.argMax(-1) as tf.Tensor1D;

    const targetSeqList = await targetSeqTensor.array();

    // One-hot to index
    const targetSeq =
        targetSeqList.map(indice => reverseTargetCharIndex[indice]);

    // Array to string. Use a global regex: a string pattern would strip only
    // the first '\n'; the target may end in padding newlines.
    const targetSeqStr = targetSeq.join('').replace(/\n/g, '');
    const decodedSentence = await decodeSequence(
      inputSeq, encoderModel, decoderModel, numDecoderTokens,
      targetBeginIndex, reverseTargetCharIndex, maxDecoderSeqLength,
    );
    console.log('-');
    console.log('Input sentence:', inputTexts[seqIndex]);
    console.log('Target sentence:', targetSeqStr);
    console.log('Decoded sentence:', decodedSentence);
  }
}