in source_directory/training/training_script.py [0:0]
def main(args):
    """Build, train and save the model, optionally distributed with Horovod.

    When ``args.use_horovod`` is truthy, Horovod is initialised, each process
    is pinned to the GPU indexed by its local rank, and the optimizer is
    wrapped in ``hvd.DistributedOptimizer``. Otherwise a
    ``tf.distribute.MirroredStrategy`` is used for the model and optimizer.

    Args:
        args: Parsed options object. The attributes read here are
            ``use_horovod``, ``learning_rate``, ``weight_decay``,
            ``tensorboard_logs_s3uri``, ``epochs``, ``batch_size`` and
            ``model_output_dir``.

    Returns:
        None.
    """
    if args.use_horovod:
        ## set up horovod for distributed training (multiple instances with multi-gpu)
        hvd.init()
        size = hvd.size()
        print("Horovod size:", size)
        print("Local horovod rank:", hvd.local_rank())
        print("Global horovod rank:", hvd.rank())
        gpus = tf.config.experimental.list_physical_devices('GPU')
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        if gpus:
            # One process per GPU: expose only the GPU matching the local rank.
            tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
    else:
        ## set up replicas for multiple gpus
        strategy = tf.distribute.MirroredStrategy()
        print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    ## create and compile the model
    print("Creating model")
    if args.use_horovod:
        model = create_model()
        # Scale the base learning rate by the number of Horovod workers.
        distributed_learning_rate = size * args.learning_rate
        # NOTE(review): `decay` is fed from args.weight_decay; confirm that the
        # optimizer's `decay` argument has the semantics the option name implies.
        optimizer = Adam(lr=distributed_learning_rate, decay=args.weight_decay)
        optimizer = hvd.DistributedOptimizer(optimizer)
        print("Compiling model")
        model.compile(
            loss=CategoricalCrossentropy(),
            optimizer=optimizer,
            experimental_run_tf_function=False,
            metrics=[tf.keras.metrics.CategoricalAccuracy()],
        )
    else:
        with strategy.scope():
            model = create_model()
            optimizer = Adam(lr=args.learning_rate, decay=args.weight_decay)
            ## compile model
            print("Compiling model")
            model.compile(
                loss=CategoricalCrossentropy(),
                optimizer=optimizer,
                experimental_run_tf_function=False,
                metrics=[tf.keras.metrics.CategoricalAccuracy()],
            )

    ## set up callbacks
    logging.info("Setting callbacks")
    tfLearningRatePlateau = tf.keras.callbacks.ReduceLROnPlateau(patience=10, verbose=1)
    log_dir = './tf_logs/'
    # Workers stay silent unless explicitly enabled below.
    verbose = 0
    if args.use_horovod:
        callbacks = [
            hvd.callbacks.BroadcastGlobalVariablesCallback(0),
            hvd.callbacks.MetricAverageCallback(),
            tfLearningRatePlateau,
        ]
        # Only the global rank-0 worker writes TensorBoard logs, syncs them
        # to S3 and prints per-epoch progress.
        if hvd.rank() == 0:
            callbacks.append(TensorBoard(log_dir=log_dir))
            callbacks.append(Sync2S3(log_dir=log_dir, s3log_dir=args.tensorboard_logs_s3uri))
            verbose = 2
    else:
        callbacks = [
            tfLearningRatePlateau,
            TensorBoard(log_dir=log_dir),
            Sync2S3(log_dir=log_dir, s3log_dir=args.tensorboard_logs_s3uri),
        ]
        verbose = 2

    ## load the datasets
    print("Loading datasets")
    train_dataset, num_train_batches_per_epoch = load_dataset(
        args.epochs, args.batch_size, 'train')
    validation_dataset, num_validation_batches_per_epoch = load_dataset(
        args.epochs, args.batch_size, 'validation')

    ## start training
    # https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
    print("Starting training")
    model.fit(
        x=train_dataset,
        steps_per_epoch=num_train_batches_per_epoch,
        epochs=args.epochs,
        validation_data=validation_dataset,
        validation_steps=num_validation_batches_per_epoch,
        # Use the per-worker verbosity chosen above instead of a fixed value,
        # so non-zero Horovod ranks do not duplicate the training log.
        verbose=verbose,
        callbacks=callbacks,
    )

    ## save model
    # With Horovod only the rank-0 worker saves; otherwise always save.
    if not args.use_horovod or hvd.rank() == 0:
        save_model(model, args.model_output_dir)