# 07_training/serverlessml/flowers/classifier/train.py
# Imports used below; MODEL_IMG_SIZE and the helper functions
# (create_preproc_dataset, create_model, training_plot, export_model)
# are defined elsewhere in this module.
import os

import hypertune
import numpy as np
import tensorflow as tf


def train_and_evaluate(strategy, opts):
    # Calculate the input image dimensions: images are prepared at a larger
    # size so that a center crop brings them down to the model image size.
    IMG_HEIGHT = IMG_WIDTH = round(MODEL_IMG_SIZE / opts['crop_ratio'])
    print('Will pad input images to {}x{}, then crop them to {}x{}'.format(
        IMG_HEIGHT, IMG_WIDTH, MODEL_IMG_SIZE, MODEL_IMG_SIZE
    ))
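    # A worked example of the arithmetic above (values are illustrative,
    # not from this repo's config): with MODEL_IMG_SIZE = 224 and
    # opts['crop_ratio'] = 0.7, round(224 / 0.7) = 320, so images are
    # prepared at 320x320 and then center-cropped to the 224x224 the
    # model expects.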
    IMG_CHANNELS = 3

    # Build training and evaluation datasets from sharded files matching
    # the 'train' and 'valid' patterns under the input directory.
    train_dataset = create_preproc_dataset(
        os.path.join(opts['input_topdir'], 'train' + opts['pattern']),
        IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS
    ).batch(opts['batch_size'])
    eval_dataset = create_preproc_dataset(
        os.path.join(opts['input_topdir'], 'valid' + opts['pattern']),
        IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS
    ).batch(opts['batch_size'])
    # If the number of training examples per (virtual) epoch is specified,
    # repeat the training dataset indefinitely and cap each epoch at the
    # corresponding number of steps.
    num_steps_per_epoch = None
    if opts['num_training_examples'] > 0:
        train_dataset = train_dataset.repeat()
        num_steps_per_epoch = opts['num_training_examples'] // opts['batch_size']
        print('Will train for {} steps per epoch'.format(num_steps_per_epoch))
    # Checkpoint and early-stopping callbacks.
    model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(opts['outdir'], 'chkpts'),
        monitor='val_accuracy', mode='max',
        save_best_only=True)
    early_stopping_cb = tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy', mode='max',
        patience=2)
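    # Note on the callbacks (a description, not code from the original):
    # with save_best_only=True, ModelCheckpoint overwrites the checkpoint
    # only when val_accuracy improves on the best value seen so far, and
    # EarlyStopping halts training once val_accuracy fails to improve for
    # two consecutive epochs.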
    # Create and compile the model inside the distribution strategy's scope
    # so that variables are placed/replicated according to the strategy.
    with strategy.scope():
        model = create_model(opts, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=opts['lrate']),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=False),
            metrics=['accuracy']
        )
    # model.summary() prints the architecture itself and returns None,
    # so wrapping it in print() would only emit an extra "None".
    model.summary()
    history = model.fit(train_dataset,
                        validation_data=eval_dataset,
                        epochs=opts['num_epochs'],
                        steps_per_epoch=num_steps_per_epoch,
                        callbacks=[model_checkpoint_cb, early_stopping_cb]
                        )
    # Plot the loss/accuracy curves, then export the trained model.
    training_plot(['loss', 'accuracy'], history,
                  os.path.join(opts['outdir'], 'training_plot.png'))
    export_model(model,
                 opts['outdir'],
                 IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
    # Report the hyperparameter-tuning metric: the best validation accuracy
    # encountered, with the number of epochs actually run as the global step.
    hpt = hypertune.HyperTune()
    accuracy = np.max(history.history['val_accuracy'])  # highest encountered
    nepochs = len(history.history['val_accuracy'])
    hpt.report_hyperparameter_tuning_metric(
        hyperparameter_metric_tag='accuracy',
        metric_value=accuracy,
        global_step=nepochs)
    print('Reported hparam metric name=accuracy value={}'.format(accuracy))
    return model
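
# A minimal usage sketch (not part of the original file): the strategy and
# option names mirror those consumed above, but the concrete values are
# assumptions for illustration only.
#
#     strategy = tf.distribute.MirroredStrategy()
#     opts = {
#         'input_topdir': 'gs://my-bucket/flowers/tfrecords',  # hypothetical path
#         'pattern': '-*.tfrec',       # hypothetical shard pattern
#         'batch_size': 32,
#         'lrate': 0.001,
#         'num_epochs': 10,
#         'num_training_examples': 0,  # 0 disables virtual epochs
#         'crop_ratio': 0.8,
#         'outdir': 'gs://my-bucket/flowers/output',  # hypothetical path
#     }
#     model = train_and_evaluate(strategy, opts)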