# main() — distributed-training solution from
# _archiving/contribution/daekeun-ml/tensorflow-in-sagemaker-workshop/training_script/cifar10_keras_dist_solution.py

def main(args):
    """Train the CIFAR-10 Keras model with Horovod data-parallel training.

    Initializes Horovod, pins each worker to its local GPU, builds the model
    (whose optimizer is expected to be wrapped by ``hvd`` inside
    ``keras_model_fn``), trains with per-worker step counts scaled by
    ``hvd.size()``, evaluates, and saves the final model.

    Args:
        args: Parsed CLI namespace; reads ``learning_rate``, ``weight_decay``,
            ``optimizer``, ``momentum``, ``batch_size``, ``epochs``, and
            ``model_output_dir``.

    Returns:
        Whatever ``save_model(model, args.model_output_dir)`` returns.
    """
    # --- Horovod setup (added section) ---
    import horovod.keras as hvd
    hvd.init()
    # Pin this process to its own GPU and grow GPU memory on demand instead
    # of reserving it all up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    K.set_session(tf.Session(config=config))

    logging.info("getting data")
    train_dataset = train_input_fn()
    eval_dataset = eval_input_fn()
    validation_dataset = validation_input_fn()

    logging.info("configuring model")
    # Pass hvd so the model builder can wrap the optimizer for distributed
    # gradient averaging (modified section).
    model = keras_model_fn(args.learning_rate, args.weight_decay,
                           args.optimizer, args.momentum, hvd)

    callbacks = []
    # Sync all workers' initial variables with rank 0 before training starts.
    callbacks.append(hvd.callbacks.BroadcastGlobalVariablesCallback(0))
    # Average metrics across workers at the end of each epoch.
    callbacks.append(hvd.callbacks.MetricAverageCallback())
    # Ramp the learning rate up over the first epochs to stabilize large
    # effective batch sizes.
    callbacks.append(hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5,
                                                              verbose=1))

    # BUG FIX: the original also appended ModelCheckpoint unconditionally
    # before this guard, so every worker wrote checkpoints (racing on the
    # same files) and rank 0 registered the callback twice. Only rank 0
    # should write checkpoints and TensorBoard logs.
    if hvd.rank() == 0:
        callbacks.append(
            ModelCheckpoint(args.model_output_dir + '/checkpoint-{epoch}.h5'))
        callbacks.append(
            TensorBoard(log_dir=args.model_output_dir, update_freq='epoch'))

    logging.info("Starting training")

    # Each worker sees 1/size of the data, so scale per-epoch steps down.
    train_steps = int(
        num_examples_per_epoch('train') // args.batch_size // hvd.size())
    val_steps = int(
        num_examples_per_epoch('validation') // args.batch_size // hvd.size())
    eval_steps = int(
        num_examples_per_epoch('eval') // args.batch_size // hvd.size())

    model.fit(x=train_dataset[0], y=train_dataset[1],
              steps_per_epoch=train_steps,
              epochs=int(args.epochs // hvd.size()),
              validation_data=validation_dataset,
              validation_steps=val_steps, callbacks=callbacks)

    score = model.evaluate(eval_dataset[0], eval_dataset[1], steps=eval_steps,
                           verbose=0)

    logging.info('Test loss:{}'.format(score[0]))
    logging.info('Test accuracy:{}'.format(score[1]))

    return save_model(model, args.model_output_dir)