# Training entry point — main()
# (extracted from source_directory/training/training_script.py)

def main(args):
    """Train the model, with either Horovod (multi-instance, multi-GPU)
    or tf.distribute.MirroredStrategy (single-instance, multi-GPU).

    Args:
        args: parsed CLI namespace. Fields read here: use_horovod (bool),
            learning_rate, weight_decay, epochs, batch_size,
            tensorboard_logs_s3uri, model_output_dir.

    Side effects: trains the model, writes TensorBoard logs to ./tf_logs/
    (synced to S3 on the root rank), and saves the model to
    args.model_output_dir (root rank only under Horovod).
    """
    if args.use_horovod:
        ## set up horovod for distributed training (multiple instances with multi-gpu)
        hvd.init()
        size = hvd.size()
        print("Horovod size:", size)
        print("Local horovod rank:", hvd.local_rank())
        print("Global horovod rank:", hvd.rank())

        # Pin each process to a single GPU and enable memory growth so a
        # process doesn't grab all GPU memory up front.
        gpus = tf.config.experimental.list_physical_devices('GPU')
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        if gpus:
            tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

    else:
        ## set up replicas for multiple gpus
        strategy = tf.distribute.MirroredStrategy()
        print('Number of devices: {}'.format(strategy.num_replicas_in_sync))


    ## create and compile the model
    print("Creating model")
    if args.use_horovod:
        model = create_model()
        # Standard Horovod recipe: scale the learning rate by the number of
        # workers to compensate for the larger effective batch size.
        distributed_learning_rate = size * args.learning_rate
        optimizer = Adam(lr=distributed_learning_rate, decay=args.weight_decay)
        optimizer = hvd.DistributedOptimizer(optimizer)
        print("Compiling model")
        model.compile(loss=CategoricalCrossentropy(),
                      optimizer=optimizer,
                      # Needed so the DistributedOptimizer's allreduce hook
                      # is actually invoked by Keras.
                      experimental_run_tf_function=False,
                      metrics=[tf.keras.metrics.CategoricalAccuracy()])

    else:
        # Model variables must be created and compiled inside the strategy
        # scope to be mirrored across replicas.
        with strategy.scope():
            model = create_model()
            optimizer = Adam(lr=args.learning_rate, decay=args.weight_decay)

            ## compile model
            print("Compiling model")
            model.compile(loss=CategoricalCrossentropy(),
                          optimizer=optimizer,
                          experimental_run_tf_function=False,
                          metrics=[tf.keras.metrics.CategoricalAccuracy()],
                         )


    ## set up callbacks
    logging.info("Setting callbacks")
    tfLearningRatePlateau = tf.keras.callbacks.ReduceLROnPlateau(patience=10, verbose=1)
    log_dir = './tf_logs/'
    verbose = 0
    if args.use_horovod:
        callbacks = [
            # Broadcast rank 0's initial variable state so all workers start
            # from identical weights.
            hvd.callbacks.BroadcastGlobalVariablesCallback(0),
            hvd.callbacks.MetricAverageCallback(),
            tfLearningRatePlateau,
        ]

        # Only the root rank logs to TensorBoard / S3 and prints progress,
        # to avoid duplicated output and conflicting writes.
        if hvd.rank() == 0:
            callbacks.append(TensorBoard(log_dir=log_dir))
            callbacks.append(Sync2S3(log_dir=log_dir, s3log_dir=args.tensorboard_logs_s3uri))
            verbose = 2

    else:
        callbacks = [
            tfLearningRatePlateau,
            TensorBoard(log_dir=log_dir),
            Sync2S3(log_dir=log_dir, s3log_dir=args.tensorboard_logs_s3uri),
        ]
        verbose = 2


    ## load the datasets
    print("Loading datasets")
    train_dataset, num_train_batches_per_epoch = load_dataset(
        args.epochs, args.batch_size, 'train')
    validation_dataset, num_validation_batches_per_epoch = load_dataset(
        args.epochs, args.batch_size, 'validation')


    ## start training
    # https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
    print("Starting training")
    model.fit(x=train_dataset,
              steps_per_epoch=num_train_batches_per_epoch,
              epochs=args.epochs,
              validation_data=validation_dataset,
              validation_steps=num_validation_batches_per_epoch,
              # FIX: was hardcoded verbose=2, which ignored the per-rank
              # verbosity computed above and made every Horovod worker
              # print full training logs. Non-root ranks must be silent.
              verbose=verbose,
              callbacks=callbacks,
             )


    ## save model
    # Under Horovod, only the root rank saves to avoid concurrent writes of
    # identical (synchronized) weights.
    if args.use_horovod:
        if hvd.rank() == 0:
            save_model(model, args.model_output_dir)
    else:
        save_model(model, args.model_output_dir)

    return