in tensorflow_ranking/python/keras/pipeline.py [0:0]
def train_and_validate(self, verbose=0):
  """Builds, compiles and fits the model, then exports it as a SavedModel.

  All work happens inside the scope returned by
  `strategy_utils.strategy_scope(self._strategy)`: the model is built and
  compiled, the train and validation datasets are built, `model.fit` is run,
  and the trained model is exported to `export/latest_model` under the output
  directory. When `export_best_model` is set in the hparams, the latest
  checkpoint found in `<model_dir>/best_checkpoint` is additionally exported
  to `export/best_model_by_metric`.

  Example usage:

  ```python
  context_feature_spec = {}
  example_feature_spec = {
      "example_feature_1": tf.io.FixedLenFeature(
          shape=(1,), dtype=tf.float32, default_value=0.0)
  }
  mask_feature_name = "list_mask"
  label_spec = {
      "utility": tf.io.FixedLenFeature(
          shape=(1,), dtype=tf.float32, default_value=0.0)
  }
  dataset_hparams = DatasetHparams(
      train_input_pattern="train.dat",
      valid_input_pattern="valid.dat",
      train_batch_size=128,
      valid_batch_size=128)
  pipeline_hparams = pipeline.PipelineHparams(
      model_dir="model/",
      num_epochs=2,
      steps_per_epoch=5,
      validation_steps=2,
      learning_rate=0.01,
      loss="softmax_loss")
  model_builder = SimpleModelBuilder(
      context_feature_spec, example_feature_spec, mask_feature_name)
  dataset_builder = SimpleDatasetBuilder(
      context_feature_spec,
      example_feature_spec,
      mask_feature_name,
      label_spec,
      dataset_hparams)
  pipeline = BasicModelFitPipeline(
      model_builder, dataset_builder, pipeline_hparams)
  pipeline.train_and_validate(verbose=1)
  ```

  Args:
    verbose: An int for the verbosity level; forwarded to `model.fit`.

  Raises:
    ValueError: If `export_best_model` is set but no checkpoint is found in
      the `best_checkpoint` directory under `model_dir`.
  """
  hparams = self._hparams
  dist_strategy = self._strategy

  with strategy_utils.strategy_scope(dist_strategy):
    ranking_model = self._model_builder.build()

    # Losses and metrics have to be created inside the strategy scope, so they
    # are produced by the `build_*` member functions here instead of being
    # accepted as ready-made objects. The construction order (loss, metrics,
    # weighted metrics) is kept deliberately.
    optimizer = self._optimizer
    loss = self.build_loss()
    metrics = self.build_metrics()
    if hparams.use_weighted_metrics:
      weighted_metrics = self.build_weighted_metrics()
    else:
      weighted_metrics = None

    ranking_model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=metrics,
        loss_weights=hparams.loss_weights,
        weighted_metrics=weighted_metrics,
        steps_per_execution=hparams.steps_per_execution)

    # Dataset creation and fitting must stay inside the strategy scope until
    # b/173547275 is fixed; otherwise MultiWorkerMirroredStrategy will fail.
    train_data = self._dataset_builder.build_train_dataset()
    valid_data = self._dataset_builder.build_valid_dataset()

    ranking_model.fit(
        x=train_data,
        epochs=hparams.num_epochs,
        steps_per_epoch=hparams.steps_per_epoch,
        validation_steps=hparams.validation_steps,
        validation_data=valid_data,
        callbacks=self.build_callbacks(),
        verbose=verbose)

    # The model as it stands after the final epoch is always exported.
    output_dir = strategy_utils.get_output_filepath(
        hparams.model_dir, dist_strategy)
    latest_export_path = os.path.join(output_dir, "export/latest_model")
    self.export_saved_model(ranking_model, export_to=latest_export_path)

    if not hparams.export_best_model:
      return

    # Note: the checkpoint is looked up under `model_dir` itself, while the
    # export is written under the strategy-specific output directory.
    best_checkpoint_dir = os.path.join(hparams.model_dir, "best_checkpoint")
    best_checkpoint = tf.train.latest_checkpoint(best_checkpoint_dir)
    if not best_checkpoint:
      raise ValueError("Didn't find the best checkpoint.")

    best_export_path = os.path.join(output_dir, "export/best_model_by_metric")
    self.export_saved_model(
        ranking_model,
        export_to=best_export_path,
        checkpoint=best_checkpoint)