in tensorflow_managed_spot_training_checkpointing/source_dir/cifar10_keras_main.py [0:0]
def main(args):
    """Train, evaluate, and save the CIFAR-10 Keras model.

    Resumes from the newest checkpoint in ``args.checkpoint_path`` when one
    exists (supports managed-spot interruption/resume); otherwise builds a
    fresh model. Checkpoints every epoch during training, evaluates on the
    eval split, and saves the final model to ``args.model_output_dir``.

    Args:
        args: Parsed arguments providing checkpoint_path, learning_rate,
            weight_decay, optimizer, momentum, batch_size, epochs, and
            model_output_dir.
    """
    if os.path.isdir(args.checkpoint_path):
        logging.info("Checkpointing directory %s exists", args.checkpoint_path)
    else:
        logging.info("Creating Checkpointing directory %s", args.checkpoint_path)
        # makedirs with exist_ok=True creates intermediate directories and is
        # safe against the check-then-create race; plain os.mkdir fails on
        # nested paths and raises if the directory appears concurrently.
        os.makedirs(args.checkpoint_path, exist_ok=True)

    logging.info("getting data")
    train_dataset = train_input_fn()
    eval_dataset = eval_input_fn()
    validation_dataset = validation_input_fn()

    logging.info("configuring model")
    # Resume from the latest checkpoint if any exist; otherwise start fresh.
    if not os.listdir(args.checkpoint_path):
        model = keras_model_fn(args.learning_rate, args.weight_decay,
                               args.optimizer, args.momentum)
        initial_epoch_number = 0
    else:
        model, initial_epoch_number = load_model_from_checkpoints(args.checkpoint_path)

    logging.info("Checkpointing to: %s", args.checkpoint_path)
    callbacks = [
        keras.callbacks.ReduceLROnPlateau(patience=10, verbose=1),
        # os.path.join avoids doubled/missing separators from manual '+' concat.
        ModelCheckpoint(os.path.join(args.checkpoint_path, 'checkpoint-{epoch}.h5')),
    ]

    logging.info("Starting training from epoch: %d", initial_epoch_number + 1)
    # NOTE(review): size looks like a placeholder for a distributed world
    # size (steps are divided by it); single-worker here — confirm.
    size = 1
    model.fit(x=train_dataset[0],
              y=train_dataset[1],
              steps_per_epoch=(num_examples_per_epoch('train') // args.batch_size) // size,
              epochs=args.epochs,
              initial_epoch=initial_epoch_number,
              validation_data=validation_dataset,
              validation_steps=(num_examples_per_epoch('validation') // args.batch_size) // size,
              callbacks=callbacks)

    score = model.evaluate(eval_dataset[0],
                           eval_dataset[1],
                           steps=num_examples_per_epoch('eval') // args.batch_size,
                           verbose=0)
    logging.info('Test loss:%s', score[0])
    logging.info('Test accuracy:%s', score[1])
    save_model(model, args.model_output_dir)