def train_and_evaluate()

in 10_mlops/model.py


def train_and_evaluate(train_data_pattern, eval_data_pattern, test_data_pattern, export_dir, output_dir):
    train_batch_size = TRAIN_BATCH_SIZE
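    # in DEVELOP_MODE, run only a few tiny steps/epochs as a quick smoke test;
    # otherwise derive steps_per_epoch from the full training-set size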
    if DEVELOP_MODE:
        eval_batch_size = 100
        steps_per_epoch = 3
        epochs = 2
        num_eval_examples = eval_batch_size * 10
    else:
        eval_batch_size = 100
        steps_per_epoch = NUM_EXAMPLES // train_batch_size
        epochs = NUM_EPOCHS
        num_eval_examples = eval_batch_size * 100

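    # build the training and evaluation input pipelines (read_dataset is defined elsewhere in model.py)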
    train_dataset = read_dataset(train_data_pattern, train_batch_size)
    eval_dataset = read_dataset(eval_data_pattern, eval_batch_size, tf.estimator.ModeKeys.EVAL, num_eval_examples)

    # checkpoint model weights after every epoch
    checkpoint_path = '{}/checkpoints/flights.cpt'.format(output_dir)
    logging.info("Checkpointing to {}".format(checkpoint_path))
    cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                     save_weights_only=True,
                                                     verbose=1)

    # callback to write out the hyperparameter tuning metric
    METRIC = 'val_rmse'
    hpt = hypertune.HyperTune()

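    # HpCallback reports the epoch-end validation RMSE through cloudml-hypertune
    # so the hyperparameter tuning service can read each trial's metric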
    class HpCallback(tf.keras.callbacks.Callback):
        def on_epoch_end(self, epoch, logs=None):
            if logs and METRIC in logs:
                logging.info("Epoch {}: {} = {}".format(epoch, METRIC, logs[METRIC]))
                hpt.report_hyperparameter_tuning_metric(hyperparameter_metric_tag=METRIC,
                                                        metric_value=logs[METRIC],
                                                        global_step=epoch)

    # train the model
    model = create_model()
    logging.info(f"Training on {train_data_pattern}; eval on {eval_data_pattern}; {epochs} epochs; {steps_per_epoch}")
    history = model.fit(train_dataset,
                        validation_data=eval_dataset,
                        epochs=epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=[cp_callback, HpCallback()])

    # export the trained model as a SavedModel
    logging.info('Exporting to {}'.format(export_dir))
    tf.saved_model.save(model, export_dir)

    # write out final metric
    final_rmse = history.history[METRIC][-1]
    logging.info("Validation metric {} on {} samples = {}".format(METRIC, num_eval_examples, final_rmse))

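    # evaluate over the full test dataset only for a real (non-develop) run that was given a test pattern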
    if (not DEVELOP_MODE) and (test_data_pattern is not None) and (not SKIP_FULL_EVAL):
        logging.info("Evaluating over full test dataset")
        test_dataset = read_dataset(test_data_pattern, eval_batch_size, tf.estimator.ModeKeys.EVAL, None)
        final_metrics = model.evaluate(test_dataset)
        logging.info("Final metrics on full test dataset = {}".format(final_metrics))
    else:
        logging.info("Skipping evaluation on full test dataset")