in examples/mxnet_imagenet_resnet50.py [0:0]
def train_gluon():
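    """Train the network with Gluon: broadcast the initial parameters,
    wrap the optimizer in a Horovod DistributedTrainer, then run the
    epoch loop with periodic logging, evaluation, and checkpointing."""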
    def evaluate(epoch):
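        """Compute top-1/top-5 accuracy on the validation set (rec data only)."""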
        if not args.use_rec:
            return

        val_data.reset()
        acc_top1 = mx.metric.Accuracy()
        acc_top5 = mx.metric.TopKAccuracy(5)
        for _, batch in enumerate(val_data):
            data, label = get_data_label(batch, context)
            output = net(data.astype(args.dtype, copy=False))
            acc_top1.update([label], [output])
            acc_top5.update([label], [output])

        top1_name, top1_acc = acc_top1.get()
        top5_name, top5_acc = acc_top5.get()
        logging.info('Epoch[%d] Rank[%d]\tValidation-%s=%f\tValidation-%s=%f',
                     epoch, rank, top1_name, top1_acc, top5_name, top5_acc)
    # Hybridize and initialize model
    net.hybridize()
    net.initialize(initializer, ctx=context)
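    # hybridize() caches a static computation graph on the first forward
    # pass, speeding up subsequent training and inference for HybridBlocks.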
    # Horovod: fetch and broadcast parameters
    params = net.collect_params()
    if params is not None:
        hvd.broadcast_parameters(params, root_rank=0)
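    # Broadcasting from root rank 0 guarantees every worker starts from
    # identical weights, even though each process initialized them randomly.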
    # Create optimizer
    optimizer_params = {'wd': args.wd,
                        'momentum': args.momentum,
                        'lr_scheduler': lr_sched}
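    # For float16 training, keep a float32 master copy of the weights so
    # small gradient updates are not lost to limited half precision.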
    if args.dtype == 'float16':
        optimizer_params['multi_precision'] = True
    opt = mx.optimizer.create('sgd', **optimizer_params)
    # Horovod: create DistributedTrainer, a subclass of gluon.Trainer
    trainer = hvd.DistributedTrainer(params, opt,
                                     gradient_predivide_factor=args.gradient_predivide_factor)
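    # DistributedTrainer allreduce-averages gradients across workers on each
    # step; gradient_predivide_factor splits that averaging between a scale
    # applied before the allreduce sum and one applied after it.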
    # Create loss function and train metric
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
    metric = mx.metric.Accuracy()
    # Train model
    for epoch in range(args.num_epochs):
        tic = time.time()
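        # RecordIO-based mx.io iterators must be reset explicitly each
        # epoch; otherwise the next epoch's loop sees an exhausted iterator
        # and processes no batches.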
        if args.use_rec:
            train_data.reset()
        metric.reset()

        btic = time.time()
        for nbatch, batch in enumerate(train_data, start=1):
            data, label = get_data_label(batch, context)
            with autograd.record():
                output = net(data.astype(args.dtype, copy=False))
                loss = loss_fn(output, label)
            loss.backward()
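            # step(batch_size) rescales gradients by the per-worker batch
            # size; cross-worker averaging is handled inside the
            # DistributedTrainer's step via allreduce.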
            trainer.step(batch_size)
            metric.update([label], [output])
            if args.log_interval and nbatch % args.log_interval == 0:
                name, acc = metric.get()
                logging.info('Epoch[%d] Rank[%d] Batch[%d]\t%s=%f\tlr=%f',
                             epoch, rank, nbatch, name, acc, trainer.learning_rate)
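                # Rank 0 reports aggregate throughput: every worker pushes
                # batch_size samples per step, so the cluster-wide rate is
                # num_workers times the local rate.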
                if rank == 0:
                    batch_speed = num_workers * batch_size * args.log_interval / (time.time() - btic)
                    logging.info('Epoch[%d] Batch[%d]\tSpeed: %.2f samples/sec',
                                 epoch, nbatch, batch_speed)
                btic = time.time()
        # Report metrics
        elapsed = time.time() - tic
        _, acc = metric.get()
        logging.info('Epoch[%d] Rank[%d] Batch[%d]\tTime cost=%.2f\tTrain-accuracy=%f',
                     epoch, rank, nbatch, elapsed, acc)
        if rank == 0:
            epoch_speed = num_workers * batch_size * nbatch / elapsed
            logging.info('Epoch[%d]\tSpeed: %.2f samples/sec', epoch, epoch_speed)
        # Evaluate performance
        if args.eval_frequency and (epoch + 1) % args.eval_frequency == 0:
            evaluate(epoch)

        # Save model
        if args.save_frequency and (epoch + 1) % args.save_frequency == 0:
            net.export('%s-%d' % (args.model, rank), epoch=epoch)
    # Evaluate performance at the end of training
    evaluate(epoch)
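
# A minimal sketch of the module-level setup this function assumes has
# already run (names and exact calls here are illustrative, not the
# script's verbatim code):
#
#   hvd.init()                                   # start Horovod
#   rank, num_workers = hvd.rank(), hvd.size()
#   context = mx.cpu() if args.no_cuda else mx.gpu(hvd.local_rank())
#   train_data, val_data = ...                   # rec iterators sharded
#                                                # across (rank, num_workers)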