def train_and_validate()

in tensorflow_ranking/python/keras/pipeline.py [0:0]


  def train_and_validate(self, verbose=0):
    """Trains, validates and exports the model under the pipeline's strategy.

    The model is built and compiled inside the distribution strategy scope,
    fitted on the training dataset while being validated on the validation
    dataset, and then exported to `export/latest_model` under the output
    directory. When `export_best_model` is set in the hparams, the latest
    checkpoint found in `<model_dir>/best_checkpoint` is additionally passed to
    the export that is written to `export/best_model_by_metric`.

    Example usage:

    ```python
    context_feature_spec = {}
    example_feature_spec = {
        "example_feature_1": tf.io.FixedLenFeature(
            shape=(1,), dtype=tf.float32, default_value=0.0)
    }
    mask_feature_name = "list_mask"
    label_spec = {
        "utility": tf.io.FixedLenFeature(
            shape=(1,), dtype=tf.float32, default_value=0.0)
    }
    dataset_hparams = DatasetHparams(
        train_input_pattern="train.dat",
        valid_input_pattern="valid.dat",
        train_batch_size=128,
        valid_batch_size=128)
    pipeline_hparams = pipeline.PipelineHparams(
        model_dir="model/",
        num_epochs=2,
        steps_per_epoch=5,
        validation_steps=2,
        learning_rate=0.01,
        loss="softmax_loss")
    model_builder = SimpleModelBuilder(
        context_feature_spec, example_feature_spec, mask_feature_name)
    dataset_builder = SimpleDatasetBuilder(
        context_feature_spec,
        example_feature_spec,
        mask_feature_name,
        label_spec,
        dataset_hparams)
    pipeline = BasicModelFitPipeline(
        model_builder, dataset_builder, pipeline_hparams)
    pipeline.train_and_validate(verbose=1)
    ```

    Args:
      verbose: An int for the verbosity level, forwarded to `model.fit`.

    Raises:
      ValueError: If `export_best_model` is set but no checkpoint is found in
        the `best_checkpoint` directory under `model_dir`.
    """
    strategy = self._strategy
    with strategy_utils.strategy_scope(strategy):
      ranker = self._model_builder.build()

      # Losses and metrics have to be created inside the strategy scope, which
      # is why the `build_*` member functions are used here rather than
      # objects passed in from the outside.
      optimizer = self._optimizer
      loss = self.build_loss()
      metrics = self.build_metrics()
      loss_weights = self._hparams.loss_weights
      if self._hparams.use_weighted_metrics:
        weighted_metrics = self.build_weighted_metrics()
      else:
        weighted_metrics = None
      ranker.compile(
          optimizer=optimizer,
          loss=loss,
          metrics=metrics,
          loss_weights=loss_weights,
          weighted_metrics=weighted_metrics,
          steps_per_execution=self._hparams.steps_per_execution)

      # Dataset creation and everything below must stay inside strategy.scope
      # until b/173547275 is fixed; otherwise MultiWorkerMirroredStrategy will
      # fail.
      train_ds = self._dataset_builder.build_train_dataset()
      valid_ds = self._dataset_builder.build_valid_dataset()
      ranker.fit(
          x=train_ds,
          epochs=self._hparams.num_epochs,
          steps_per_epoch=self._hparams.steps_per_epoch,
          validation_steps=self._hparams.validation_steps,
          validation_data=valid_ds,
          callbacks=self.build_callbacks(),
          verbose=verbose)

      # The model as it is after the last epoch is always exported.
      export_root = strategy_utils.get_output_filepath(
          self._hparams.model_dir, strategy)
      latest_export_path = os.path.join(export_root, "export/latest_model")
      self.export_saved_model(ranker, export_to=latest_export_path)

      if not self._hparams.export_best_model:
        return

      # Exporting the best model was requested, so a missing checkpoint is an
      # error rather than something to skip silently.
      best_ckpt_dir = os.path.join(self._hparams.model_dir, "best_checkpoint")
      best_ckpt = tf.train.latest_checkpoint(best_ckpt_dir)
      if not best_ckpt:
        raise ValueError("Didn't find the best checkpoint.")

      best_export_path = os.path.join(export_root,
                                      "export/best_model_by_metric")
      self.export_saved_model(
          ranker, export_to=best_export_path, checkpoint=best_ckpt)