async train()

in speech-commands/src/browser_fft_recognizer.ts [1007:1064]


  /**
   * Train the transfer-learning model on the word examples collected so far.
   *
   * If the transfer model does not exist yet, it is first created from the
   * base model. The base model's second-to-last dense layer is frozen for this
   * initial training phase. The model is then compiled with categorical
   * cross-entropy loss and an accuracy metric. Training is delegated to
   * `trainOnDataset()` when the collected recordings are longer than the
   * configured threshold. Otherwise it is delegated to `trainOnTensors()`.
   *
   * @param config Optional training configuration. May be omitted (or null),
   *   in which case an empty config is used and every field takes its default.
   *   `config.optimizer` defaults to `'sgd'`, and
   *   `config.fitDatasetDurationMillisThreshold` defaults to 60e3 ms.
   * @returns The history produced by the delegate training method.
   *   NOTE(review): presumably the tuple form is returned when a fine-tuning
   *   phase also runs. Confirm against `trainOnTensors()` /
   *   `trainOnDataset()`.
   * @throws Error (via `tf.util.assert`) if no examples have been collected,
   *   if only one word label has been collected, or if
   *   `config.fineTuningEpochs` is given but is not a non-negative integer.
   */
  async train(config?: TransferLearnConfig):
      Promise<tfl.History|[tfl.History, tfl.History]> {
    tf.util.assert(
        this.words != null && this.words.length > 0,
        () =>
            `Cannot train transfer-learning model '${this.name}' because no ` +
            `transfer learning example has been collected.`);
    tf.util.assert(
        this.words.length > 1,
        () => `Cannot train transfer-learning model '${
                  this.name}' because only ` +
            `1 word label ('${JSON.stringify(this.words)}') ` +
            `has been collected for transfer learning. Requires at least 2.`);

    // Normalize a missing config before any of its fields are read.
    // Otherwise, calling `train()` with no argument (allowed by the optional
    // parameter) would throw a TypeError on `config.fineTuningEpochs` below.
    if (config == null) {
      config = {};
    }

    if (config.fineTuningEpochs != null) {
      tf.util.assert(
          config.fineTuningEpochs >= 0 &&
              Number.isInteger(config.fineTuningEpochs),
          () => `If specified, fineTuningEpochs must be a non-negative ` +
              `integer, but received ${config.fineTuningEpochs}`);
    }

    if (this.model == null) {
      this.createTransferModelFromBaseModel();
    }

    // This layer needs to be frozen for the initial phase of the
    // transfer learning. During subsequent fine-tuning (if any), it will
    // be unfrozen.
    this.secondLastBaseDenseLayer.trainable = false;

    // Compile model for training.
    this.model.compile({
      loss: 'categoricalCrossentropy',
      optimizer: config.optimizer || 'sgd',
      metrics: ['acc']
    });

    // Use `tf.data.Dataset` objects for training if the total duration of
    // the recordings exceeds the threshold (60 seconds unless overridden by
    // `config.fitDatasetDurationMillisThreshold`). Otherwise, use `tf.Tensor`
    // objects.
    const datasetDurationMillisThreshold =
        config.fitDatasetDurationMillisThreshold == null ?
        60e3 :
        config.fitDatasetDurationMillisThreshold;
    if (this.dataset.durationMillis() > datasetDurationMillisThreshold) {
      console.log(
          `Detected large dataset: total duration = ` +
          `${this.dataset.durationMillis()} ms > ` +
          `${datasetDurationMillisThreshold} ms. ` +
          `Training transfer model using fitDataset() instead of fit()`);
      return this.trainOnDataset(config);
    } else {
      return this.trainOnTensors(config);
    }
  }