in speech-commands/src/browser_fft_recognizer.ts [1007:1064]
/**
 * Train the transfer-learning model on the examples collected so far.
 *
 * Preconditions (checked with `tf.util.assert`):
 *   - at least one transfer-learning example has been collected, and
 *   - examples for at least 2 distinct word labels exist.
 *
 * If no transfer model exists yet, one is created from the base model. The
 * second-last dense layer of the base model is frozen for the initial
 * training phase. The model is then compiled with categorical cross-entropy
 * loss and the 'acc' metric.
 *
 * Training is delegated to `trainOnDataset()` when the total recorded
 * duration exceeds `config.fitDatasetDurationMillisThreshold` (default
 * 60e3 ms). Otherwise it is delegated to `trainOnTensors()`.
 *
 * @param config Optional training configuration; `null`/`undefined` is
 *   treated as `{}`. If `fineTuningEpochs` is given it must be a
 *   non-negative integer. `optimizer` defaults to 'sgd'.
 * @returns The result of `trainOnDataset()` / `trainOnTensors()`.
 *   NOTE(review): presumably the tuple form is returned when a fine-tuning
 *   phase runs in addition to the initial phase — confirm in those helpers.
 */
async train(config?: TransferLearnConfig):
    Promise<tfl.History|[tfl.History, tfl.History]> {
  // Default the config BEFORE reading any of its fields: `config` is
  // optional, so dereferencing it first (e.g. for fineTuningEpochs) would
  // throw a TypeError when `train()` is called without arguments.
  const trainConfig: TransferLearnConfig = config == null ? {} : config;

  tf.util.assert(
      this.words != null && this.words.length > 0,
      () =>
          `Cannot train transfer-learning model '${this.name}' because no ` +
          `transfer learning example has been collected.`);
  tf.util.assert(
      this.words.length > 1,
      () => `Cannot train transfer-learning model '${
                this.name}' because only ` +
          `1 word label ('${JSON.stringify(this.words)}') ` +
          `has been collected for transfer learning. Requires at least 2.`);
  if (trainConfig.fineTuningEpochs != null) {
    tf.util.assert(
        trainConfig.fineTuningEpochs >= 0 &&
            Number.isInteger(trainConfig.fineTuningEpochs),
        () => `If specified, fineTuningEpochs must be a non-negative ` +
            `integer, but received ${trainConfig.fineTuningEpochs}`);
  }

  if (this.model == null) {
    this.createTransferModelFromBaseModel();
  }

  // This layer needs to be frozen for the initial phase of the
  // transfer learning. During subsequent fine-tuning (if any), it will
  // be unfrozen.
  this.secondLastBaseDenseLayer.trainable = false;

  // Compile model for training.
  this.model.compile({
    loss: 'categoricalCrossentropy',
    optimizer: trainConfig.optimizer || 'sgd',
    metrics: ['acc']
  });

  // Use `tf.data.Dataset` objects for training if the total duration of
  // the recordings exceeds the threshold (60 seconds by default).
  // Otherwise, use `tf.Tensor` objects.
  const datasetDurationMillisThreshold =
      trainConfig.fitDatasetDurationMillisThreshold == null ?
      60e3 :
      trainConfig.fitDatasetDurationMillisThreshold;
  // Query the duration once; it is used both for the decision and the log.
  const datasetDurationMillis = this.dataset.durationMillis();
  if (datasetDurationMillis > datasetDurationMillisThreshold) {
    console.log(
        `Detected large dataset: total duration = ` +
        `${datasetDurationMillis} ms > ` +
        `${datasetDurationMillisThreshold} ms. ` +
        `Training transfer model using fitDataset() instead of fit()`);
    return this.trainOnDataset(trainConfig);
  } else {
    return this.trainOnTensors(trainConfig);
  }
}