def train_and_evaluate(strategy, opts)

in 07_training/serverlessml/flowers/classifier/train.py [0:0]

import os

import numpy as np
import tensorflow as tf
import hypertune

# create_preproc_dataset, create_model, training_plot, export_model, and
# MODEL_IMG_SIZE are defined elsewhere in this module (not shown here).


def train_and_evaluate(strategy, opts):
    # calculate the size at which to read images, given that they will be
    # center-cropped down to the model's input size
    IMG_HEIGHT = IMG_WIDTH = round(MODEL_IMG_SIZE / opts['crop_ratio'])
    print('Will read input images at {}x{}, then center-crop them to {}x{}'.format(
        IMG_HEIGHT, IMG_WIDTH, MODEL_IMG_SIZE, MODEL_IMG_SIZE
    ))
    IMG_CHANNELS = 3

    train_dataset = create_preproc_dataset(
        os.path.join(opts['input_topdir'], 'train' + opts['pattern']),
        IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS
    ).batch(opts['batch_size'])
    eval_dataset = create_preproc_dataset(
        os.path.join(opts['input_topdir'], 'valid' + opts['pattern']),
        IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS
    ).batch(opts['batch_size'])

    # if the number of training examples per epoch is specified,
    # repeat the training dataset indefinitely and compute a fixed
    # number of steps per epoch
    num_steps_per_epoch = None
    if opts['num_training_examples'] > 0:
        train_dataset = train_dataset.repeat()
        num_steps_per_epoch = opts['num_training_examples'] // opts['batch_size']
        print("Will train for {} steps".format(num_steps_per_epoch))
    
    # checkpoint and early stopping callbacks
    model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(opts['outdir'], 'chkpts'),
        monitor='val_accuracy', mode='max',
        save_best_only=True)
    early_stopping_cb = tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy', mode='max',
        patience=2)
    
    # model training: build and compile the model inside the strategy's
    # scope so that its variables are created (and mirrored) on the
    # devices the strategy manages
    with strategy.scope():
        model = create_model(opts, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=opts['lrate']),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
            metrics=['accuracy'])
    model.summary()  # summary() prints itself; print(model.summary()) would print None
    history = model.fit(train_dataset, 
                        validation_data=eval_dataset,
                        epochs=opts['num_epochs'],
                        steps_per_epoch=num_steps_per_epoch,
                        callbacks=[model_checkpoint_cb, early_stopping_cb]
                       )
    training_plot(['loss', 'accuracy'], history, 
                  os.path.join(opts['outdir'], 'training_plot.png'))
    
    # export the model
    export_model(model, 
                 opts['outdir'],
                 IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
    
    # report the best validation accuracy as the hyperparameter tuning metric
    hpt = hypertune.HyperTune()
    accuracy = np.max(history.history['val_accuracy']) # highest encountered
    nepochs = len(history.history['val_accuracy'])
    hpt.report_hyperparameter_tuning_metric(
        hyperparameter_metric_tag='accuracy',
        metric_value=accuracy,
        global_step=nepochs)
    print("Reported hparam metric name=accuracy value={}".format(accuracy))
    
    return model
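
The crop_ratio arithmetic at the top of the function implies a read-then-center-crop
pipeline. The following standalone sketch illustrates that relationship with stock
TensorFlow image ops; it is an illustration only, not the repo's create_preproc_dataset
(whose body is not shown here), and the MODEL_IMG_SIZE and CROP_RATIO values are assumed.

import tensorflow as tf

MODEL_IMG_SIZE = 224   # assumed value; the real constant lives elsewhere in train.py
CROP_RATIO = 0.8       # example value for opts['crop_ratio']

# matches the function's math: round(224 / 0.8) == 280
READ_SIZE = round(MODEL_IMG_SIZE / CROP_RATIO)

def read_then_center_crop(img):
    # resize the decoded image to the (larger) read size, then take a
    # center crop at the model's expected input size
    img = tf.image.resize(img, [READ_SIZE, READ_SIZE])
    return tf.image.resize_with_crop_or_pad(img, MODEL_IMG_SIZE, MODEL_IMG_SIZE)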
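For context, here is a minimal invocation sketch. The strategy choice and every
opts value below are illustrative assumptions inferred from the keys the function
reads; they are not the repo's actual defaults or entry point.

import tensorflow as tf

def main():
    # mirror across local GPUs when available; otherwise fall back to the
    # default single-device strategy
    if tf.config.list_physical_devices('GPU'):
        strategy = tf.distribute.MirroredStrategy()
    else:
        strategy = tf.distribute.get_strategy()

    opts = {
        'input_topdir': 'gs://YOUR_BUCKET/data',  # hypothetical location
        'pattern': '-*',             # suffix appended to 'train' / 'valid'
        'crop_ratio': 0.8,           # images are read larger, then center-cropped
        'batch_size': 32,
        'num_training_examples': 0,  # 0 => no repeat(); one pass per epoch
        'num_epochs': 10,
        'lrate': 0.001,
        'outdir': './output',
    }
    train_and_evaluate(strategy, opts)

if __name__ == '__main__':
    main()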
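Downstream code can reload the best checkpoint written by the ModelCheckpoint
callback with the standard Keras API. This is a sketch assuming TF 2.x behavior
(an extensionless filepath makes Keras save in SavedModel format); the layout of
export_model's own output is not shown above, so only the checkpoint path is used.

import os
import tensorflow as tf

OUTDIR = './output'  # same value as opts['outdir'] in the sketch above

# reload the best model saved by the ModelCheckpoint callback
# (save_best_only=True, monitored on val_accuracy)
best_model = tf.keras.models.load_model(os.path.join(OUTDIR, 'chkpts'))
best_model.summary()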