// Excerpt: async function main() — from sentiment/train.js [lines 160:244].

/**
 * Entry point of the training script.
 *
 * Parses command-line arguments, selects the tfjs-node backend (CPU or
 * GPU), loads the IMDb data, builds and trains the requested model type,
 * evaluates it on the test split, and — when the corresponding flags are
 * set — saves the model artifacts, a metadata JSON file, and the
 * embedding matrix/label files.
 *
 * Relies on helpers defined elsewhere in this file (`parseArguments`,
 * `loadData`, `buildModel`, `loadMetadataTemplate`,
 * `writeEmbeddingMatrixAndLabels`) and on the `fs`, `path` and `shelljs`
 * modules imported at the top of the file.
 */
async function main() {
  const args = parseArguments();

  // Pick the tfjs-node flavor requested on the command line. The GPU
  // package requires CUDA; the CPU package works everywhere.
  console.log(args.gpu ? 'Using GPU for training' : 'Using CPU for training');
  const tfn = args.gpu ? require('@tensorflow/tfjs-node-gpu') :
                         require('@tensorflow/tfjs-node');

  console.log('Loading data...');
  const multihot = args.modelType === 'multihot';
  const data = await loadData(args.numWords, args.maxLen, multihot);

  console.log('Building model...');
  const model = buildModel(
      args.modelType, args.maxLen, args.numWords, args.embeddingSize);
  model.compile({
    loss: 'binaryCrossentropy',
    optimizer: args.optimizer,
    metrics: ['acc']
  });
  model.summary();

  console.log('Training model...');
  // Attach a TensorBoard callback only when a log directory was given.
  const callbacks = args.logDir == null ?
      null :
      tfn.node.tensorBoard(args.logDir, {updateFreq: args.logUpdateFreq});
  await model.fit(data.xTrain, data.yTrain, {
    epochs: args.epochs,
    batchSize: args.batchSize,
    validationSplit: args.validationSplit,
    callbacks
  });

  console.log('Evaluating model...');
  const [testLoss, testAcc] =
      model.evaluate(data.xTest, data.yTest, {batchSize: args.batchSize});
  console.log(`Evaluation loss: ${(await testLoss.data())[0].toFixed(4)}`);
  console.log(`Evaluation accuracy: ${(await testAcc.data())[0].toFixed(4)}`);

  // Optionally save the trained model plus a metadata file describing
  // the hyperparameters used for this run.
  let metadata;
  if (args.modelSaveDir != null && args.modelSaveDir.length > 0) {
    if (multihot) {
      // The multihot variant has no serializable embedding layout.
      console.warn(
          'Skipping saving of multihot model, which is not supported.');
    } else {
      // Make sure the destination directory exists before writing.
      shelljs.mkdir('-p', args.modelSaveDir);

      console.log('Loading metadata template...');
      metadata = await loadMetadataTemplate();

      // Record this run's hyperparameters in the metadata, then write it.
      metadata.epochs = args.epochs;
      metadata.embedding_size = args.embeddingSize;
      metadata.max_len = args.maxLen;
      metadata.model_type = args.modelType;
      metadata.batch_size = args.batchSize;
      metadata.vocabulary_size = args.numWords;
      const metadataPath = path.join(args.modelSaveDir, 'metadata.json');
      fs.writeFileSync(metadataPath, JSON.stringify(metadata));
      console.log(`Saved metadata to ${metadataPath}`);

      // Persist the model topology and weights.
      await model.save(`file://${args.modelSaveDir}`);
      console.log(`Saved model to ${args.modelSaveDir}`);
    }
  }

  // Optionally dump the embedding matrix and vocabulary labels, e.g. for
  // visualization with the TensorBoard Embedding Projector.
  if (args.embeddingFilesPrefix != null &&
      args.embeddingFilesPrefix.length > 0) {
    // The metadata may not have been loaded above (e.g. no save dir).
    if (metadata == null) {
      metadata = await loadMetadataTemplate();
    }
    await writeEmbeddingMatrixAndLabels(
        model, args.embeddingFilesPrefix, metadata.word_index,
        metadata.index_from);
  }
}