in tfjs-layers/src/engine/training_dataset.ts [301:500]
/**
 * Trains a model on data yielded by a `Dataset` object.
 *
 * This is the dataset-based counterpart of `LayersModel.fit()`. The model
 * must already be compiled (i.e., `model.optimizer` must be set). Training
 * runs for `args.epochs` epochs; each epoch either consumes
 * `args.batchesPerEpoch` batches from a single iterator, or (when
 * `batchesPerEpoch` is not given) consumes a fresh iterator until the
 * dataset is exhausted.
 *
 * @param model The model to train. Typed as `any` to avoid a circular
 *   dependency with training.ts (see inline comment on the parameter).
 * @param dataset A `Dataset` whose iterator yields training batches.
 * @param args Required configuration: epochs, optional batchesPerEpoch,
 *   callbacks, validation data, class weights, etc.
 * @returns The model's `History`, with per-epoch loss/metric values synced
 *   to CPU before returning.
 * @throws Error if the model is not compiled, if `args` fails validation,
 *   or if another `fit()` call is ongoing on the same model.
 */
export async function fitDataset<T>(
    // Type `model` as `any` here to avoid circular dependency w/
    // training.ts.
    // tslint:disable-next-line:no-any
    model: any, dataset: Dataset<T>,
    args: ModelFitDatasetArgs<T>): Promise<History> {
  tfc.util.assert(
      model.optimizer != null,
      () => 'You must compile a model before training/testing. Use ' +
          'LayersModel.compile(modelCompileConfig).');
  tfc.util.assert(
      args != null,
      () => `For fitDataset(), the 2nd argument (config) is required, ` +
          `but it is not provided in this call.`);
  // Read properties of `args` only after the null check above. (Previously
  // this was computed before the assertion, so a null `args` surfaced as a
  // TypeError instead of the intended assertion message.)
  const hasBatchesPerEpoch = args.batchesPerEpoch != null;
  tfc.util.assert(
      args.epochs != null && args.epochs > 0 && Number.isInteger(args.epochs),
      () => `For fitDataset(), config.epochs is expected to be a positive ` +
          `integer, but got ${args.epochs}`);
  tfc.util.assert(
      !hasBatchesPerEpoch ||
          (args.batchesPerEpoch > 0 && Number.isInteger(args.batchesPerEpoch)),
      () => `For fitDataset(), config.batchesPerEpoch is expected to be a ` +
          `positive integer if specified, but got ${args.batchesPerEpoch}`);
  tfc.util.assert(
      // tslint:disable-next-line:no-any
      (args as any)['validationSplit'] == null,
      () => '`validationSplit` is not supported by `fitDataset()`. ' +
          'Use validationData instead.');

  // Re-entrancy guard: only one fit() per model at a time.
  if (model.isTraining) {
    throw new Error(
        'Cannot start training because another fit() call is ongoing.');
  }
  model.isTraining = true;

  try {
    const doValidation = args.validationData != null;
    let valXs: tfc.Tensor|tfc.Tensor[];
    let valYs: tfc.Tensor|tfc.Tensor[];
    if (doValidation) {
      if (isDatasetObject(args.validationData)) {
        tfc.util.assert(
            args.validationBatches == null ||
                (args.validationBatches > 0 &&
                 Number.isInteger(args.validationBatches)),
            () => `For fitDataset() with dataset-based validation, ` +
                `config.validationBatches is expected not to be provided, ` +
                `or to be a positive integer, ` +
                `but got ${args.validationBatches}`);
      } else {
        // Tensor-based validation data: standardize to {xs, ys} once here
        // and reuse it for every epoch's evaluation below.
        const validationData = standardizeTensorValidationData(
            args.validationData as
                [tfc.Tensor | tfc.Tensor[], tfc.Tensor | tfc.Tensor[]] |
                [
                  tfc.Tensor | tfc.Tensor[], tfc.Tensor | tfc.Tensor[],
                  tfc.Tensor | tfc.Tensor[]
                ]);
        valXs = validationData.xs;
        valYs = validationData.ys;
      }
    }

    const trainFunction = model.makeTrainFunction();
    const outLabels = model.getDedupedMetricsNames() as string[];

    // When validating, each metric also gets a `val_`-prefixed counterpart
    // reported to the callbacks.
    let callbackMetrics: string[];
    if (doValidation) {
      callbackMetrics =
          outLabels.slice().concat(outLabels.map(n => 'val_' + n));
    } else {
      callbackMetrics = outLabels.slice();
    }

    const callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
    const verbose = args.verbose == null ? 1 : args.verbose;
    const {callbackList, history} = configureCallbacks(
        callbacks, verbose, args.epochs, null, null,
        getStepsPerEpoch(dataset, args),
        null,  // Batch size determined by the dataset itself.
        doValidation, callbackMetrics);
    callbackList.setModel(model);
    model.history = history;

    await callbackList.onTrainBegin();
    model.stopTraining_ = false;
    let epoch = args.initialEpoch == null ? 0 : args.initialEpoch;
    let dataIterator = await dataset.iterator();
    while (epoch < args.epochs) {
      const epochLogs: UnresolvedLogs = {};
      await callbackList.onEpochBegin(epoch);
      let stepsDone = 0;
      let batchIndex = 0;
      if (!hasBatchesPerEpoch) {
        // Without `batchesPerEpoch`, each epoch gets a fresh iterator and
        // runs until the dataset is exhausted.
        dataIterator = await dataset.iterator();
      }
      while (hasBatchesPerEpoch ? stepsDone < args.batchesPerEpoch : true) {
        const iteratorOut = await dataIterator.next();
        // If `batchesPerEpoch` is specified, the dataset should not be
        // exhausted until all epochs are done.
        if (hasBatchesPerEpoch && iteratorOut.done) {
          console.warn(
              'You provided `batchesPerEpoch` as ' +
              `${args.batchesPerEpoch}, ` +
              'but your dataset iterator ran out of data after ' +
              `${stepsDone} batches; ` +
              'interrupting training. Make sure that your ' +
              'dataset can generate at least `batchesPerEpoch * epochs` ' +
              'batches (in this case, ' +
              `${args.batchesPerEpoch * args.epochs} batches). ` +
              'You may need to use the repeat() function when building ' +
              'your dataset.');
          break;
        }

        if (iteratorOut.value != null) {
          const {xs, ys} =
              standardizeDataIteratorOutput(model, iteratorOut.value);
          const batchLogs: UnresolvedLogs = {};
          batchLogs['batch'] = batchIndex;
          batchLogs['size'] = xs[0].shape[0];

          await callbackList.onBatchBegin(batchIndex, batchLogs);

          const sampleWeights: tfc.Tensor[] = [];
          if (args.classWeight != null) {
            const standardClassWeights =
                standardizeClassWeights(args.classWeight, model.outputNames);
            for (let i = 0; i < standardClassWeights.length; ++i) {
              sampleWeights.push(await standardizeWeights(
                  ys[i], null, standardClassWeights[i]));
            }
          }

          // Train on batch.
          const ins = xs.concat(ys).concat(sampleWeights);
          const outs = trainFunction(ins);
          tfc.dispose(ins);
          for (let i = 0; i < outLabels.length; ++i) {
            const label = outLabels[i];
            const out = outs[i];
            batchLogs[label] = out;
            // Keep the per-batch output tensors alive until the batch-end
            // callbacks have read them; they are disposed just below via
            // disposeTensorsInLogs().
            tfc.keep(out);
          }

          await callbackList.onBatchEnd(batchIndex, batchLogs);
          disposeTensorsInLogs(batchLogs);

          batchIndex++;
          stepsDone++;
        }

        if (hasBatchesPerEpoch ? stepsDone >= args.batchesPerEpoch :
                                 iteratorOut.done) {
          // Epoch finished. Perform validation.
          if (doValidation) {
            let valOuts: tfc.Scalar[];
            if (isDatasetObject(args.validationData)) {
              valOuts = toList(await model.evaluateDataset(
                  args.validationData, {batches: args.validationBatches}));
            } else {
              valOuts = toList(model.evaluate(valXs, valYs, {
                batchSize: args.validationBatchSize == null ?
                    DEFAULT_VALIDATION_BATCH_SIZE :
                    args.validationBatchSize,
                verbose: 0
              }));
            }
            for (let i = 0; i < model.metricsNames.length; ++i) {
              epochLogs[`val_${model.metricsNames[i]}`] = valOuts[i];
            }
          }
          // Call `break` to exit one epoch loop after validation is done. If
          // config.batchesPerEpoch is specified, an epoch while loop will
          // stop when `stepsDone >= config.batchesPerEpoch`. When
          // config.batchesPerEpoch is not provided, the following `break` is
          // required to exit the while loop after dataset is exhausted.
          break;
        }

        if (model.stopTraining_) {
          break;
        }
      }
      await callbackList.onEpochEnd(epoch, epochLogs);
      epoch++;
      if (model.stopTraining_) {
        break;
      }
    }
    await callbackList.onTrainEnd();
    await model.history.syncData();
    return model.history;
  } finally {
    // Always clear the re-entrancy flag, even when training throws, so
    // subsequent fit() calls on this model are not blocked.
    model.isTraining = false;
  }
}