in legacy/models/resnet/mxnet/train_imagenet.py [0:0]
def train(ctx):
    """Run the full ImageNet training loop on the given device(s).

    Parameters
    ----------
    ctx : mx.Context or list of mx.Context
        Device(s) to train on; a single context is wrapped into a list.

    Relies on module-level state defined elsewhere in this file: ``net``,
    ``initializer``, ``optimizer``, ``optimizer_params``, ``kv``, ``opt``,
    ``train_data``, ``val_data``, ``batch_fn``, ``acc_top1``, ``batch_size``,
    ``save_frequency``, ``save_dir``, ``model_name`` and the ``test`` helper.
    """
    # Normalize a single device into a list so per-device data splitting works.
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    net.initialize(initializer, ctx=ctx)
    trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params,
                            kvstore=kv)
    L = gluon.loss.SoftmaxCrossEntropyLoss()

    # Tracks the best validation top-1 *error* seen so far (lower is better);
    # 1.0 is the worst possible error rate.
    best_val_score = 1
    for epoch in range(opt.num_epochs):
        tic = time.time()
        if opt.use_rec:
            # Record-file iterators must be rewound manually each epoch.
            train_data.reset()
        acc_top1.reset()
        btic = time.time()

        for i, batch in enumerate(train_data):
            data, label = batch_fn(batch, ctx)
            with ag.record():
                outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                loss = [L(yhat, y) for yhat, y in zip(outputs, label)]
            for l in loss:
                l.backward()
            trainer.step(batch_size)
            acc_top1.update(label, outputs)
            if opt.log_interval and not (i + 1) % opt.log_interval:
                _, top1 = acc_top1.get()
                # Lazy %-args: formatting only happens if INFO is enabled.
                logging.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\tlr=%f\taccuracy=%f',
                             epoch, i,
                             batch_size * opt.log_interval / (time.time() - btic),
                             trainer.learning_rate, top1)
                btic = time.time()

        _, top1 = acc_top1.get()
        # `i` is the last zero-based batch index, so (i + 1) batches were
        # processed (the original used `i`, under-counting by one batch).
        throughput = int(batch_size * (i + 1) / (time.time() - tic))

        err_top1_val, err_top5_val = test(ctx, val_data)

        logging.info('[Epoch %d] Train-accuracy=%f', epoch, top1)
        logging.info('[Epoch %d] Speed: %d samples/sec\tTime cost=%f',
                     epoch, throughput, time.time() - tic)
        logging.info('[Epoch %d] Validation-accuracy=%f', epoch, 1 - err_top1_val)
        logging.info('[Epoch %d] Validation-top_k_accuracy_5=%f', epoch, 1 - err_top5_val)

        # Checkpoint the best model by validation error, but only after a
        # hard-coded 50-epoch warm-up. NOTE(review): this branch does not
        # check save_dir before using it in the path — confirm save_dir is
        # always set when save_frequency is.
        if save_frequency and err_top1_val < best_val_score and epoch > 50:
            best_val_score = err_top1_val
            net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params' %
                                (save_dir, best_val_score, model_name, epoch))
        # Periodic checkpoint every `save_frequency` epochs.
        if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
            net.save_parameters('%s/imagenet-%s-%d.params' %
                                (save_dir, model_name, epoch))

    # Always save the final epoch's parameters (may duplicate the periodic
    # checkpoint when num_epochs is a multiple of save_frequency).
    if save_frequency and save_dir:
        net.save_parameters('%s/imagenet-%s-%d.params' %
                            (save_dir, model_name, opt.num_epochs - 1))