in notebooks/escience_series/mnist.py [0:0]
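# Imports required by this excerpt (assumed to sit at the top of mnist.py,
# outside the snippet shown here).
import logging
import os

import mxnet as mx
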
def train(batch_size, epochs, learning_rate, num_gpus, training_channel, testing_channel,
          hosts, current_host, model_dir):
    (train_labels, train_images) = load_data(training_channel)
    (test_labels, test_images) = load_data(testing_channel)
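
    # The checkpoint directory only exists when the training job is configured to
    # persist checkpoints, so its presence decides whether checkpointing is enabled.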
    CHECKPOINTS_DIR = '/opt/ml/checkpoints'
    checkpoints_enabled = os.path.exists(CHECKPOINTS_DIR)

    # Data parallel training - shard the data so each host
    # only trains on a subset of the total data.
    shard_size = len(train_images) // len(hosts)
    for i, host in enumerate(hosts):
        if host == current_host:
            start = shard_size * i
            end = start + shard_size
            break
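
    # Wrap this host's shard in an iterator that shuffles every epoch; keep the
    # full test set, unshuffled, for evaluation.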
    train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end],
                                   batch_size, shuffle=True)
    val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size)
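
    # Surface MXNet's per-batch and per-epoch progress in the training logs.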
    logging.getLogger().setLevel(logging.DEBUG)
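
    # A single host keeps gradients local; multiple hosts synchronize them through
    # a distributed, synchronous key-value store.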
    kvstore = 'local' if len(hosts) == 1 else 'dist_sync'

    mlp_model = mx.mod.Module(symbol=build_graph(),
                              context=get_training_context(num_gpus))

    checkpoint_callback = None
    if checkpoints_enabled:
        # Create a checkpoint callback that checkpoints the model params and the
        # optimizer state after every epoch at the given path.
        checkpoint_callback = mx.callback.module_checkpoint(mlp_model,
                                                            CHECKPOINTS_DIR + "/mnist",
                                                            period=1,
                                                            save_optimizer_states=True)

    mlp_model.fit(train_iter,
                  eval_data=val_iter,
                  kvstore=kvstore,
                  optimizer='sgd',
                  optimizer_params={'learning_rate': learning_rate},
                  eval_metric='acc',
                  epoch_end_callback=checkpoint_callback,
                  batch_end_callback=mx.callback.Speedometer(batch_size, 100),
                  num_epoch=epochs)
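
    # Every host trains, but only the first host exports the final model so the
    # artifact is written exactly once.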
    if current_host == hosts[0]:
        save(model_dir, mlp_model)
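

# Illustrative only: a minimal sketch of how train() might be driven from a SageMaker
# script-mode entry point. The hyperparameter names and defaults are assumptions, as
# are the channel names behind SM_CHANNEL_TRAIN / SM_CHANNEL_TEST; the SM_* variables
# themselves are the standard SageMaker training environment variables.
if __name__ == '__main__':
    import argparse
    import json

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=100)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--learning-rate', type=float, default=0.1)
    args, _ = parser.parse_known_args()

    train(batch_size=args.batch_size,
          epochs=args.epochs,
          learning_rate=args.learning_rate,
          num_gpus=int(os.environ.get('SM_NUM_GPUS', '0')),
          training_channel=os.environ['SM_CHANNEL_TRAIN'],
          testing_channel=os.environ['SM_CHANNEL_TEST'],
          hosts=json.loads(os.environ['SM_HOSTS']),
          current_host=os.environ['SM_CURRENT_HOST'],
          model_dir=os.environ['SM_MODEL_DIR'])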