def main()

in src/train.py [0:0]


def _evaluate_per_slide(model, test_dataset):
    """Score every test tile, aggregate the scores per slide, and report accuracy.

    Each element of ``test_dataset`` is indexed as ``(image, label, slide)``;
    ``slide`` is decoded from bytes and used as the grouping key. The label of
    the first tile seen for a slide is taken as that slide's true label. A
    slide's predicted class is the argmax of the mean of its tiles' score
    vectors.

    Returns:
        The per-slide accuracy computed by ``accuracy_score``.
    """
    print("-------------------------Evaluation begins ----------------------------------------------------")

    # Accumulate per-slide predictions
    pred_dict = {}
    for i, element in enumerate(test_dataset):
        if (i + 1) % 1000 == 0:
            print("Computing scores for tile {}...".format(i + 1))
            logging.info("Computing scores for slide {}...".format(i + 1))

        image = element[0].numpy()
        label = element[1].numpy()
        slide = element[2].numpy().decode()

        if slide not in pred_dict:
            pred_dict[slide] = {'y_true': label, 'y_pred': []}
        # predict() expects a batch dimension; take the single result back out.
        pred = model.predict(np.expand_dims(image, axis=0))[0]
        pred_dict[slide]['y_pred'].append(pred)

    # Aggregate per-slide predictions
    y_true = []
    y_pred = []
    for key, value in pred_dict.items():
        slide_true = value['y_true']
        mean_pred_scores = np.mean(np.vstack(value['y_pred']), axis=0)
        mean_pred_class = np.argmax(mean_pred_scores)

        y_true.append(slide_true)
        y_pred.append(mean_pred_class)

        message = 'Slide {}: True Label = {}, Prediction = {}'.format(key, slide_true, mean_pred_class)
        print(message)
        logging.info(message)

    acc = accuracy_score(y_true, y_pred)
    message = 'Per-Slide Test accuracy: {}'.format(acc)
    print(message)
    logging.info(message)
    return acc


def main(args):
    """Train the model, evaluate it per slide, and save it.

    Horovod is used only when ``args.fw_params`` contains a truthy
    ``'sagemaker_mpi_enabled'`` entry; otherwise training runs in a single
    process and ``None`` is passed wherever the Horovod module is expected.

    Reads from ``args``: ``fw_params``, ``output_dir``, ``learning_rate``,
    ``num_train``, ``num_val``, ``batch_size``, ``epochs`` and
    ``model_output_dir``. Also reads the ``SM_CURRENT_HOST`` environment
    variable (``KeyError`` if it is not set).

    Checkpointing, evaluation and the final model save happen only on rank 0
    when running under Horovod, and always in single-process mode.
    """
    # Bind hvd up front so it is defined on every path, including the case
    # where the flag is present but falsy.
    hvd = None
    mpi = False
    if 'sagemaker_mpi_enabled' in args.fw_params and args.fw_params['sagemaker_mpi_enabled']:
        import horovod.keras as hvd
        mpi = True
        # Horovod: initialize Horovod.
        hvd.init()

        # Horovod: pin GPU to be used to process local rank (one GPU per process)
        gpus = tf.config.experimental.list_physical_devices('GPU')
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        if gpus:
            tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

    # A single-process run is treated as rank 0 of a world of size 1.
    rank = hvd.rank() if mpi else 0
    size = hvd.size() if mpi else 1
    is_chief = rank == 0

    callbacks = []
    if mpi:
        callbacks.append(hvd.callbacks.BroadcastGlobalVariablesCallback(0))
        callbacks.append(hvd.callbacks.MetricAverageCallback())
    if is_chief:
        # Only one process writes checkpoints.
        callbacks.append(ModelCheckpoint(args.output_dir + '/checkpoint-{epoch}.ckpt',
                                         save_weights_only=True,
                                         verbose=2))

    current_host = os.environ['SM_CURRENT_HOST']
    print("The current horovod rank is ", rank)
    print("the current host is ", current_host)

    print("Training dataset being loaded -----------------")
    train_dataset = train_input_fn(hvd, mpi)

    print("valid dataset being loaded -----------------")
    valid_dataset = valid_input_fn(hvd, mpi)

    print("Test dataset being loaded -----------------")
    test_dataset = test_input_fn()

    logging.info("configuring model")
    model = model_def(args.learning_rate, mpi, hvd)

    logging.info("Starting training")
    print("the size is ", size)

    # Fit the model; the per-epoch step counts are divided across the workers.
    model.fit(train_dataset,
              steps_per_epoch=((args.num_train // args.batch_size) // size),
              epochs=args.epochs,
              validation_data=valid_dataset,
              validation_steps=((args.num_val // args.batch_size) // size),
              callbacks=callbacks,
              verbose=2)

    if is_chief:
        # Evaluate the model at rank 0
        _evaluate_per_slide(model, test_dataset)

        model_path = '{}/00000001'.format(args.model_output_dir)
        model.save(model_path)
        if not mpi:
            # NOTE(review): the single-process path has always saved a second
            # copy directly into model_output_dir; kept as-is. Confirm whether
            # the Horovod path should do the same or this copy can be dropped.
            model.save(args.model_output_dir)