export async function fitDataset()

in tfjs-layers/src/engine/training_dataset.ts [301:500]


/**
 * Trains a model on data yielded by a `Dataset` object.
 *
 * Runs `args.epochs` epochs; within each epoch, pulls batches from the
 * dataset iterator, invokes the model's train function on each batch, and
 * fires the standard callback hooks (`onEpochBegin`, `onBatchBegin/End`,
 * `onEpochEnd`, `onTrainBegin/End`). Optionally evaluates validation data
 * (tensor pair or another `Dataset`) at the end of every epoch.
 *
 * @param model The model to train. Typed as `any` to avoid a circular
 *   dependency with training.ts; expected to be a compiled `LayersModel`
 *   (must have an `optimizer`).
 * @param dataset A `Dataset` yielding training batches.
 * @param args Training configuration: `epochs` (required, positive integer),
 *   and optionally `batchesPerEpoch`, `callbacks`, `verbose`, `initialEpoch`,
 *   `classWeight`, and validation settings (`validationData`,
 *   `validationBatches`, `validationBatchSize`).
 * @returns The model's `History` object, with per-epoch losses and metrics
 *   synced to CPU.
 * @throws Error if another fit() call is already in progress, or if any of
 *   the configuration assertions fail.
 */
export async function fitDataset<T>(
    // Type `model` as `any` here to avoid circular dependency w/
    // training.ts.
    // tslint:disable-next-line:no-any
    model: any, dataset: Dataset<T>,
    args: ModelFitDatasetArgs<T>): Promise<History> {
  tfc.util.assert(
      model.optimizer != null,
      () => 'You must compile a model before training/testing. Use ' +
          'LayersModel.compile(modelCompileConfig).');

  // Assert `args` is present BEFORE dereferencing it: reading
  // `args.batchesPerEpoch` off a null/undefined `args` would throw a raw
  // TypeError and the friendly message below would never be reached.
  tfc.util.assert(
      args != null,
      () => `For fitDataset(), the 2nd argument (config) is required, ` +
          `but it is not provided in this call.`);
  const hasBatchesPerEpoch = args.batchesPerEpoch != null;
  tfc.util.assert(
      args.epochs != null && args.epochs > 0 && Number.isInteger(args.epochs),
      () => `For fitDataset(), config.epochs is expected to be a positive ` +
          `integer, but got ${args.epochs}`);
  tfc.util.assert(
      !hasBatchesPerEpoch ||
          (args.batchesPerEpoch > 0 && Number.isInteger(args.batchesPerEpoch)),
      () => `For fitDataset(), config.batchesPerEpoch is expected to be a ` +
          `positive integer if specified, but got ${args.batchesPerEpoch}`);
  tfc.util.assert(
      // tslint:disable-next-line:no-any
      (args as any)['validationSplit'] == null,
      () => '`validationSplit` is not supported by `fitDataset()`. ' +
          'Use validationData instead.');

  if (model.isTraining) {
    throw new Error(
        'Cannot start training because another fit() call is ongoing.');
  }
  model.isTraining = true;

  try {
    const doValidation = args.validationData != null;
    let valXs: tfc.Tensor|tfc.Tensor[];
    let valYs: tfc.Tensor|tfc.Tensor[];
    if (doValidation) {
      if (isDatasetObject(args.validationData)) {
        // Dataset-based validation: `validationBatches` is optional but, if
        // given, must be a positive integer.
        tfc.util.assert(
            args.validationBatches == null ||
                (args.validationBatches > 0 &&
                 Number.isInteger(args.validationBatches)),
            () => `For fitDataset() with dataset-based validation, ` +
                `config.validationBatches is expected not to be provided, ` +
                `or to be a positive integer, ` +
                `but got ${args.validationBatches}`);
      } else {
        // Tensor-based validation: normalize the [xs, ys] (or
        // [xs, ys, sampleWeights]) tuple once, up front.
        const validationData = standardizeTensorValidationData(
            args.validationData as
                    [tfc.Tensor | tfc.Tensor[], tfc.Tensor | tfc.Tensor[]] |
            [
              tfc.Tensor | tfc.Tensor[], tfc.Tensor | tfc.Tensor[],
              tfc.Tensor | tfc.Tensor[]
            ]);
        valXs = validationData.xs;
        valYs = validationData.ys;
      }
    }

    const trainFunction = model.makeTrainFunction();
    const outLabels = model.getDedupedMetricsNames() as string[];

    // With validation, each metric also gets a `val_`-prefixed counterpart
    // reported to callbacks.
    let callbackMetrics: string[];
    if (doValidation) {
      callbackMetrics =
          outLabels.slice().concat(outLabels.map(n => 'val_' + n));
    } else {
      callbackMetrics = outLabels.slice();
    }

    const callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
    const verbose = args.verbose == null ? 1 : args.verbose;
    const {callbackList, history} = configureCallbacks(
        callbacks, verbose, args.epochs, null, null,
        getStepsPerEpoch(dataset, args),
        null,  // Batch size determined by the dataset itself.
        doValidation, callbackMetrics);
    callbackList.setModel(model);
    model.history = history;

    await callbackList.onTrainBegin();
    model.stopTraining_ = false;
    let epoch = args.initialEpoch == null ? 0 : args.initialEpoch;

    let dataIterator = await dataset.iterator();
    while (epoch < args.epochs) {
      const epochLogs: UnresolvedLogs = {};
      await callbackList.onEpochBegin(epoch);
      let stepsDone = 0;
      let batchIndex = 0;
      if (!hasBatchesPerEpoch) {
        // Without `batchesPerEpoch`, each epoch consumes a fresh iterator
        // until exhaustion.
        dataIterator = await dataset.iterator();
      }
      while (hasBatchesPerEpoch ? stepsDone < args.batchesPerEpoch : true) {
        const iteratorOut = await dataIterator.next();

        // If `batchesPerEpoch` is specified, the dataset should not be
        // exhausted until all epochs are done.
        if (hasBatchesPerEpoch && iteratorOut.done) {
          console.warn(
              'You provided `batchesPerEpoch` as ' +
              `${args.batchesPerEpoch}, ` +
              'but your dataset iterator ran out of data after ' +
              `${stepsDone} batches; ` +
              'interrupting training. Make sure that your ' +
              'dataset can generate at least `batchesPerEpoch * epochs` ' +
              'batches (in this case, ' +
              `${args.batchesPerEpoch * args.epochs} batches). ` +
              'You may need to use the repeat() function when building ' +
              'your dataset.');
          break;
        }

        if (iteratorOut.value != null) {
          const {xs, ys} =
              standardizeDataIteratorOutput(model, iteratorOut.value);
          const batchLogs: UnresolvedLogs = {};
          batchLogs['batch'] = batchIndex;
          batchLogs['size'] = xs[0].shape[0];

          await callbackList.onBatchBegin(batchIndex, batchLogs);

          // Per-output sample weights derived from `classWeight`, if any.
          const sampleWeights: tfc.Tensor[] = [];
          if (args.classWeight != null) {
            const standardClassWeights =
                standardizeClassWeights(args.classWeight, model.outputNames);
            for (let i = 0; i < standardClassWeights.length; ++i) {
              sampleWeights.push(await standardizeWeights(
                  ys[i], null, standardClassWeights[i]));
            }
          }

          // Train on batch.
          const ins = xs.concat(ys).concat(sampleWeights);
          const outs = trainFunction(ins);
          tfc.dispose(ins);
          // `keep` the metric outputs so they survive until callbacks have
          // consumed them; disposeTensorsInLogs() releases them below.
          for (let i = 0; i < outLabels.length; ++i) {
            const label = outLabels[i];
            const out = outs[i];
            batchLogs[label] = out;
            tfc.keep(out);
          }

          await callbackList.onBatchEnd(batchIndex, batchLogs);
          disposeTensorsInLogs(batchLogs);

          batchIndex++;
          stepsDone++;
        }

        if (hasBatchesPerEpoch ? stepsDone >= args.batchesPerEpoch :
                                 iteratorOut.done) {
          // Epoch finished. Perform validation.
          if (doValidation) {
            let valOuts: tfc.Scalar[];
            if (isDatasetObject(args.validationData)) {
              valOuts = toList(await model.evaluateDataset(
                  args.validationData, {batches: args.validationBatches}));
            } else {
              valOuts = toList(model.evaluate(valXs, valYs, {
                batchSize: args.validationBatchSize == null ?
                    DEFAULT_VALIDATION_BATCH_SIZE :
                    args.validationBatchSize,
                verbose: 0
              }));
            }
            for (let i = 0; i < model.metricsNames.length; ++i) {
              epochLogs[`val_${model.metricsNames[i]}`] = valOuts[i];
            }
          }
          // Call `break` to exit one epoch loop after validation is done. If
          // config.batchesPerEpoch is specified, an epoch while loop will
          // stop when `stepsDone >= config.batchesPerEpoch`. When
          // config.batchesPerEpoch is not provided, the following `break` is
          // required to exit the while loop after dataset is exhausted.
          break;
        }

        if (model.stopTraining_) {
          break;
        }
      }
      await callbackList.onEpochEnd(epoch, epochLogs);
      epoch++;
      if (model.stopTraining_) {
        break;
      }
    }
    await callbackList.onTrainEnd();
    await model.history.syncData();
    return model.history;
  } finally {
    // Always release the training lock, even if an error was thrown mid-fit.
    model.isTraining = false;
  }
}