in 10_mlops/model.py
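# Imports this function depends on (normally declared at the top of model.py);
# constants such as TRAIN_BATCH_SIZE, DEVELOP_MODE, NUM_EXAMPLES, NUM_EPOCHS and
# SKIP_FULL_EVAL, and the helpers read_dataset() and create_model(), are also
# defined elsewhere in this module.
import logging

import hypertune
import tensorflow as tf
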
def train_and_evaluate(train_data_pattern, eval_data_pattern, test_data_pattern, export_dir, output_dir):
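    """Train the flights model and export it for serving.

    Builds tf.data pipelines from the train/eval file patterns, checkpoints
    weights under output_dir, reports val_rmse to the hypertune service after
    every epoch, saves the trained model as a SavedModel in export_dir, and
    (outside development mode) evaluates on the full test dataset.
    """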
    train_batch_size = TRAIN_BATCH_SIZE
    if DEVELOP_MODE:
        eval_batch_size = 100
        steps_per_epoch = 3
        epochs = 2
        num_eval_examples = eval_batch_size * 10
    else:
        eval_batch_size = 100
        steps_per_epoch = NUM_EXAMPLES // train_batch_size
        epochs = NUM_EPOCHS
        num_eval_examples = eval_batch_size * 100

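    # input pipelines; read_dataset() is defined elsewhere in this module,
    # and the eval pipeline is limited to num_eval_examples records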
    train_dataset = read_dataset(train_data_pattern, train_batch_size)
    eval_dataset = read_dataset(eval_data_pattern, eval_batch_size, tf.estimator.ModeKeys.EVAL, num_eval_examples)

    # checkpoint model weights after each epoch
    checkpoint_path = '{}/checkpoints/flights.cpt'.format(output_dir)
    logging.info("Checkpointing to {}".format(checkpoint_path))
    cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                     save_weights_only=True,
                                                     verbose=1)

    # callback to report the hyperparameter tuning metric after each epoch
    METRIC = 'val_rmse'
    hpt = hypertune.HyperTune()

    class HpCallback(tf.keras.callbacks.Callback):
        def on_epoch_end(self, epoch, logs=None):
            if logs and METRIC in logs:
                logging.info("Epoch {}: {} = {}".format(epoch, METRIC, logs[METRIC]))
                hpt.report_hyperparameter_tuning_metric(hyperparameter_metric_tag=METRIC,
                                                        metric_value=logs[METRIC],
                                                        global_step=epoch)

    # train the model
    model = create_model()
    logging.info(f"Training on {train_data_pattern}; eval on {eval_data_pattern}; {epochs} epochs; {steps_per_epoch} steps per epoch")
    history = model.fit(train_dataset,
                        validation_data=eval_dataset,
                        epochs=epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=[cp_callback, HpCallback()])

    # export
    logging.info('Exporting to {}'.format(export_dir))
    tf.saved_model.save(model, export_dir)

    # write out final metric
    final_rmse = history.history[METRIC][-1]
    logging.info("Validation metric {} on {} samples = {}".format(METRIC, num_eval_examples, final_rmse))

    if (not DEVELOP_MODE) and (test_data_pattern is not None) and (not SKIP_FULL_EVAL):
        logging.info("Evaluating over full test dataset")
        test_dataset = read_dataset(test_data_pattern, eval_batch_size, tf.estimator.ModeKeys.EVAL, None)
        final_metrics = model.evaluate(test_dataset)
        logging.info("Final metrics on full test dataset = {}".format(final_metrics))
    else:
        logging.info("Skipping evaluation on full test dataset")