examples/mxnet_imagenet_resnet50.py
def train_module():
    # Create input symbol
    data = mx.sym.var('data')
    if args.dtype == 'float16':
        # Cast both the input and the network itself to float16
        data = mx.sym.Cast(data=data, dtype=np.float16)
        net.cast(np.float16)

    # Create output symbol
    out = net(data)
    if args.dtype == 'float16':
        # Cast back to float32 so the loss is computed in full precision
        out = mx.sym.Cast(data=out, dtype=np.float32)
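    # SoftmaxOutput attaches the cross-entropy loss and implicitly creates a
    # 'softmax_label' input, which matches Module's default label_names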
    softmax = mx.sym.SoftmaxOutput(out, name='softmax')
    # Create model
    mod = mx.mod.Module(softmax, context=context)

    # Initialize parameters: copy them from the pretrained Gluon network if
    # requested, otherwise let the initializer create them from scratch
    if args.use_pretrained:
        arg_params = {}
        for x in net.collect_params().values():
            x.reset_ctx(mx.cpu())
            arg_params[x.name] = x.data()
    else:
        arg_params = None
    aux_params = None
    mod.bind(data_shapes=train_data.provide_data,
             label_shapes=train_data.provide_label)
    mod.init_params(initializer, arg_params=arg_params, aux_params=aux_params)
    # Horovod: fetch the freshly initialized parameters and broadcast rank 0's
    # values so that every worker starts from identical weights
    (arg_params, aux_params) = mod.get_params()
    if arg_params is not None:
        hvd.broadcast_parameters(arg_params, root_rank=0)
    if aux_params is not None:
        hvd.broadcast_parameters(aux_params, root_rank=0)
    mod.set_params(arg_params=arg_params, aux_params=aux_params)
    # Create optimizer
    # Note that when using the Module API, we need to specify rescale_grad,
    # since we create the optimizer first and then wrap it with
    # DistributedOptimizer. For the Gluon API, this is handled in the
    # Trainer.step() function, so there is no need to specify rescale_grad
    # (see the train_gluon() function above).
    optimizer_params = {'wd': args.wd,
                        'momentum': args.momentum,
                        'rescale_grad': 1.0 / batch_size,
                        'lr_scheduler': lr_sched}
    if args.dtype == 'float16':
        # Keep a float32 master copy of the weights when training in float16
        optimizer_params['multi_precision'] = True
    opt = mx.optimizer.create('sgd', **optimizer_params)
    # Horovod: wrap optimizer with DistributedOptimizer
    dist_opt = hvd.DistributedOptimizer(
        opt, gradient_predivide_factor=args.gradient_predivide_factor)
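    # Per Horovod's documentation, gradient_predivide_factor splits the
    # averaging around the allreduce sum: gradients are scaled by 1.0/factor
    # before the sum and by factor/size after it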
    # Setup validation data and callbacks used during training
    eval_data = None
    if args.eval_epoch:
        eval_data = val_data
    batch_callback = None
    if args.log_interval > 0 and rank == 0:
        # Report aggregate throughput across all workers, but only on rank 0
        batch_callback = mx.callback.Speedometer(batch_size * num_workers,
                                                 args.log_interval)
    epoch_callback = None
    if args.save_frequency > 0:
        epoch_callback = mx.callback.do_checkpoint(
            '%s-%d' % (args.model, rank),
            period=args.save_frequency)
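    # Note that the checkpoint prefix includes the rank, so each worker writes
    # its own checkpoint files rather than contending for a shared path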
    # Train model
    mod.fit(train_data,
            eval_data=eval_data,
            num_epoch=args.num_epochs,
            kvstore=None,  # Horovod handles gradient exchange, not a KVStore
            batch_end_callback=batch_callback,
            epoch_end_callback=epoch_callback,
            optimizer=dist_opt)
    # Evaluate performance if not using synthetic data
    if args.use_rec:
        acc_top1 = mx.metric.Accuracy()
        acc_top5 = mx.metric.TopKAccuracy(5)
        res = mod.score(val_data, [acc_top1, acc_top5])
        for name, val in res:
            logging.info('Epoch[%d] Rank[%d] Validation-%s=%f',
                         args.num_epochs - 1, rank, name, val)
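
# Minimal driver sketch (an assumption, not part of this excerpt):
# train_module() reads module-level globals such as args, net, context, rank,
# num_workers, train_data, val_data, lr_sched, initializer, and batch_size.
# A typical Horovod setup providing the distributed ones would look like:
#
#     import horovod.mxnet as hvd
#     import mxnet as mx
#
#     hvd.init()
#     rank = hvd.rank()
#     num_workers = hvd.size()
#     context = mx.cpu() if args.no_cuda else mx.gpu(hvd.local_rank())
#
#     train_module()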