in tf-horovod-inference-pipeline/train.py [0:0]
def main(args):
    """Train the Keras model, evaluate it, and save it, optionally under Horovod MPI.

    Builds datasets via train_input_fn/eval_input_fn/validation_input_fn,
    constructs the model with keras_model_fn, and fits it with checkpoint and
    TensorBoard callbacks. When SageMaker MPI is enabled in args.fw_params,
    Horovod is initialized, one GPU is pinned per local rank, per-epoch step
    counts are divided by the world size, and only rank 0 writes
    checkpoints/logs and saves the final model.

    Args:
        args: parsed argument namespace; reads tensorboard_dir, fw_params,
            learning_rate, weight_decay, optimizer, momentum, batch_size,
            epochs, output_dir and model_output_dir.

    Returns:
        The result of save_model on the saving worker, otherwise None
        (non-zero Horovod ranks do not save).
    """
    mpi = False
    # Bug fix: previously `hvd` was left unbound when 'sagemaker_mpi_enabled'
    # was absent from fw_params, causing a NameError at keras_model_fn below.
    hvd = None

    # SageMaker stages code under source/sourcedir.tar.gz; redirect TensorBoard
    # output to the model directory in that case so logs land somewhere useful.
    if 'sourcedir.tar.gz' in args.tensorboard_dir:
        tensorboard_dir = re.sub('source/sourcedir.tar.gz', 'model', args.tensorboard_dir)
    else:
        tensorboard_dir = args.tensorboard_dir
    logging.info("Writing TensorBoard logs to {}".format(tensorboard_dir))

    # .get() treats a missing key the same as a falsy value (single-process run).
    if args.fw_params.get('sagemaker_mpi_enabled'):
        import horovod.tensorflow.keras as hvd
        mpi = True
        # Horovod: initialize Horovod.
        hvd.init()
        # Horovod: pin GPU to be used to process local rank (one GPU per process)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = str(hvd.local_rank())
        K.set_session(tf.Session(config=config))
    logging.info("Running with MPI={}".format(mpi))

    logging.info("getting data")
    train_dataset = train_input_fn()
    eval_dataset = eval_input_fn()
    validation_dataset = validation_input_fn()

    logging.info("configuring model")
    model = keras_model_fn(args.learning_rate, args.weight_decay, args.optimizer,
                           args.momentum, mpi, hvd)

    callbacks = []
    if mpi:
        # Horovod: broadcast initial state from rank 0, average metrics across
        # workers, and warm up / adapt the learning rate for distributed runs.
        callbacks.append(hvd.callbacks.BroadcastGlobalVariablesCallback(0))
        callbacks.append(hvd.callbacks.MetricAverageCallback())
        callbacks.append(hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5, verbose=1))
        callbacks.append(tf.keras.callbacks.ReduceLROnPlateau(patience=10, verbose=1))
    # Checkpoints and TensorBoard logs: written by the single process when not
    # under MPI, and only by rank 0 under MPI so workers don't clobber each other.
    if not mpi or hvd.rank() == 0:
        callbacks.append(ModelCheckpoint(args.output_dir + '/checkpoint-{epoch}.h5'))
        callbacks.append(CustomTensorBoardCallback(log_dir=tensorboard_dir))

    logging.info("Starting training")
    # Each worker processes 1/size of the per-epoch examples under MPI.
    size = hvd.size() if mpi else 1
    model.fit(x=train_dataset[0], y=train_dataset[1],
              steps_per_epoch=(num_examples_per_epoch('train') // args.batch_size) // size,
              epochs=args.epochs,
              validation_data=validation_dataset,
              validation_steps=(num_examples_per_epoch('validation') // args.batch_size) // size,
              callbacks=callbacks)

    score = model.evaluate(eval_dataset[0], eval_dataset[1],
                           steps=num_examples_per_epoch('eval') // args.batch_size,
                           verbose=0)
    logging.info('Test loss:{}'.format(score[0]))
    logging.info('Test accuracy:{}'.format(score[1]))

    # Horovod: Save model only on worker 0 (i.e. master)
    if not mpi or hvd.rank() == 0:
        return save_model(model, args.model_output_dir)