# main()
#
# From tensorflow_managed_spot_training_checkpointing/source_dir/cifar10_keras_main.py [0:0]


def main(args):
    """Train, evaluate, and save the CIFAR-10 Keras model with checkpointing.

    Resumes from the newest checkpoint in ``args.checkpoint_path`` if one
    exists (supports spot-instance interruption/restart), otherwise builds a
    fresh model. After training, evaluates on the eval split and saves the
    final model to ``args.model_output_dir``.

    Args:
        args: Parsed CLI namespace. Fields read here: checkpoint_path,
            learning_rate, weight_decay, optimizer, momentum, batch_size,
            epochs, model_output_dir.
    """
    # makedirs with exist_ok avoids the isdir/mkdir race and also creates
    # any missing parent directories (plain os.mkdir would fail on both).
    if os.path.isdir(args.checkpoint_path):
        logging.info("Checkpointing directory {} exists".format(args.checkpoint_path))
    else:
        logging.info("Creating Checkpointing directory {}".format(args.checkpoint_path))
        os.makedirs(args.checkpoint_path, exist_ok=True)

    logging.info("getting data")
    train_dataset = train_input_fn()
    eval_dataset = eval_input_fn()
    validation_dataset = validation_input_fn()

    logging.info("configuring model")

    # Resume from checkpoints when any exist; otherwise start from scratch.
    # NOTE(review): an empty-or-missing check on directory contents assumes
    # every file in the directory is a loadable checkpoint — confirm no other
    # files are ever written there.
    if not os.listdir(args.checkpoint_path):
        model = keras_model_fn(args.learning_rate, args.weight_decay, args.optimizer, args.momentum)
        initial_epoch_number = 0
    else:
        model, initial_epoch_number = load_model_from_checkpoints(args.checkpoint_path)

    logging.info("Checkpointing to: {}".format(args.checkpoint_path))

    callbacks = []
    callbacks.append(keras.callbacks.ReduceLROnPlateau(patience=10, verbose=1))
    # {epoch} is expanded by Keras at save time, producing one .h5 per epoch.
    callbacks.append(ModelCheckpoint(os.path.join(args.checkpoint_path, 'checkpoint-{epoch}.h5')))

    logging.info("Starting training from epoch: {}".format(initial_epoch_number + 1))

    # size is a placeholder for the number of workers (1 = single-node);
    # steps are divided evenly across workers in distributed runs.
    size = 1
    model.fit(x=train_dataset[0],
              y=train_dataset[1],
              steps_per_epoch=(num_examples_per_epoch('train') // args.batch_size) // size,
              epochs=args.epochs,
              initial_epoch=initial_epoch_number,
              validation_data=validation_dataset,
              validation_steps=(num_examples_per_epoch('validation') // args.batch_size) // size,
              callbacks=callbacks)

    score = model.evaluate(eval_dataset[0],
                           eval_dataset[1],
                           steps=num_examples_per_epoch('eval') // args.batch_size,
                           verbose=0)

    logging.info('Test loss:{}'.format(score[0]))
    logging.info('Test accuracy:{}'.format(score[1]))

    save_model(model, args.model_output_dir)