in tf-distribution-options/code/train_ps.py [0:0]
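# Excerpt: main() from the parameter-server (PS) training entry point. The
# helpers used below (process_input, get_model, CustomTensorBoardCallback,
# num_examples_per_epoch, save_history, save_model) and the os/re/logging and
# ModelCheckpoint imports are assumed to be defined or imported elsewhere in
# train_ps.py.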
def main(args):
    # Redirect TensorBoard logs out of the source archive path
    # (sourcedir.tar.gz) and into the model directory instead.
    if 'sourcedir.tar.gz' in args.tensorboard_dir:
        tensorboard_dir = re.sub('source/sourcedir.tar.gz', 'model', args.tensorboard_dir)
    else:
        tensorboard_dir = args.tensorboard_dir
    logging.info("Writing TensorBoard logs to {}".format(tensorboard_dir))
logging.info("getting data")
train_dataset = process_input(args.epochs, args.batch_size, args.train, 'train', args.data_config)
eval_dataset = process_input(args.epochs, args.batch_size, args.eval, 'eval', args.data_config)
validation_dataset = process_input(args.epochs, args.batch_size, args.validation, 'validation', args.data_config)
logging.info("configuring model")
logging.info("Hosts: "+ os.environ.get('SM_HOSTS'))
size = len(args.hosts)
#Deal with this
model = get_model(args.learning_rate, args.weight_decay, args.optimizer, args.momentum, size)

    # Checkpoints and TensorBoard logs are written only on the primary host.
    callbacks = []
    if args.current_host == args.hosts[0]:
        callbacks.append(ModelCheckpoint(args.output_data_dir + '/checkpoint-{epoch}.h5'))
        callbacks.append(CustomTensorBoardCallback(log_dir=tensorboard_dir))
logging.info("Starting training")
history = model.fit(x=train_dataset[0],
y=train_dataset[1],
steps_per_epoch=(num_examples_per_epoch('train') // args.batch_size) // size,
epochs=args.epochs,
validation_data=validation_dataset,
validation_steps=(num_examples_per_epoch('validation') // args.batch_size) // size, callbacks=callbacks)

    score = model.evaluate(eval_dataset[0],
                           eval_dataset[1],
                           steps=num_examples_per_epoch('eval') // args.batch_size,
                           verbose=0)
    logging.info('Test loss:{}'.format(score[0]))
    logging.info('Test accuracy:{}'.format(score[1]))

    # PS: Save model and history only on worker 0
    if args.current_host == args.hosts[0]:
        save_history(args.model_dir + "/ps_history.p", history)
        save_model(model, args.model_dir)
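
# A minimal sketch of the __main__ entry point this function expects, assuming
# the standard SageMaker script-mode environment variables; the flag names and
# defaults below are illustrative and may differ from the actual train_ps.py.
if __name__ == '__main__':
    import argparse
    import json
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--learning-rate', type=float, default=0.01)
    parser.add_argument('--weight-decay', type=float, default=2e-4)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--optimizer', type=str, default='sgd')
    parser.add_argument('--data-config', type=json.loads, default=os.environ.get('SM_INPUT_DATA_CONFIG'))
    parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN'))
    parser.add_argument('--eval', type=str, default=os.environ.get('SM_CHANNEL_EVAL'))
    parser.add_argument('--validation', type=str, default=os.environ.get('SM_CHANNEL_VALIDATION'))
    parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
    parser.add_argument('--output-data-dir', type=str, default=os.environ.get('SM_OUTPUT_DATA_DIR'))
    parser.add_argument('--tensorboard-dir', type=str, default=os.environ.get('SM_MODULE_DIR'))
    parser.add_argument('--hosts', type=json.loads, default=os.environ.get('SM_HOSTS'))
    parser.add_argument('--current-host', type=str, default=os.environ.get('SM_CURRENT_HOST'))

    main(parser.parse_args())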