def train()

in legacy/models/resnet/mxnet/train_imagenet.py [0:0]


def train(ctx):
    """Run the full training loop: train, validate, log and checkpoint per epoch.

    Besides ``ctx``, this function reads names that are not defined inside it
    and must therefore be provided by the surrounding module: ``net``,
    ``initializer``, ``optimizer``, ``optimizer_params``, ``kv``, ``opt``,
    ``train_data``, ``val_data``, ``batch_fn``, ``acc_top1``, ``batch_size``,
    ``test``, ``save_frequency``, ``save_dir`` and ``model_name``.

    Args:
        ctx: A single ``mx.Context`` or a collection of contexts. A single
            context is wrapped in a list before use.

    Side effects:
        Initializes ``net``, updates its parameters, emits progress through
        ``logging.info`` and writes ``.params`` files under ``save_dir``
        when ``save_frequency`` and ``save_dir`` are both truthy.
    """
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    net.initialize(initializer, ctx=ctx)

    trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params, kvstore=kv)
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    # Lowest validation top-1 error seen so far; starts at 1 so that any
    # strictly lower error replaces it.
    best_val_score = 1
    # "Best" checkpoints are only written for epochs strictly after this one.
    # NOTE(review): presumably to skip the noisy early epochs - confirm.
    best_ckpt_min_epoch = 50

    for epoch in range(opt.num_epochs):
        tic = time.time()
        if opt.use_rec:
            train_data.reset()
        acc_top1.reset()
        btic = time.time()
        # Counted explicitly so the epoch throughput covers every batch and
        # stays defined even when train_data yields nothing.
        num_batches = 0

        for i, batch in enumerate(train_data):
            # NOTE(review): data/label look like per-context lists - confirm
            # against batch_fn.
            data, label = batch_fn(batch, ctx)
            with ag.record():
                outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                losses = [loss_fn(yhat, y) for yhat, y in zip(outputs, label)]
            for loss in losses:
                loss.backward()
            trainer.step(batch_size)
            num_batches = i + 1

            acc_top1.update(label, outputs)
            if opt.log_interval and num_batches % opt.log_interval == 0:
                _, top1 = acc_top1.get()
                logging.info(
                    'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\tlr=%f\taccuracy=%f',
                    epoch, i, batch_size * opt.log_interval / (time.time() - btic),
                    trainer.learning_rate, top1)
                btic = time.time()

        _, top1 = acc_top1.get()
        # Throughput is measured before validation so it reflects training only.
        train_elapsed = time.time() - tic
        throughput = int(batch_size * num_batches / train_elapsed) if train_elapsed > 0 else 0

        # NOTE(review): test() is treated as returning (top-1 error, top-5 error),
        # since both values are logged as ``1 - value`` accuracies - confirm.
        err_top1_val, err_top5_val = test(ctx, val_data)

        logging.info('[Epoch %d] Train-accuracy=%f', epoch, top1)
        # "Time cost" is taken here, after validation, so it includes test().
        logging.info('[Epoch %d] Speed: %d samples/sec\tTime cost=%f',
                     epoch, throughput, time.time() - tic)
        logging.info('[Epoch %d] Validation-accuracy=%f', epoch, 1 - err_top1_val)
        logging.info('[Epoch %d] Validation-top_k_accuracy_5=%f', epoch, 1 - err_top5_val)

        checkpointing = save_frequency and save_dir

        # Guarded by save_dir like the other saves below; without it an empty
        # save_dir would produce a path starting at the filesystem root.
        if checkpointing and err_top1_val < best_val_score and epoch > best_ckpt_min_epoch:
            best_val_score = err_top1_val
            net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'
                                % (save_dir, best_val_score, model_name, epoch))

        if checkpointing and (epoch + 1) % save_frequency == 0:
            net.save_parameters('%s/imagenet-%s-%d.params' % (save_dir, model_name, epoch))

    # Always persist the final weights, even when the last epoch did not fall
    # on a save_frequency boundary.
    if save_frequency and save_dir:
        net.save_parameters('%s/imagenet-%s-%d.params'
                            % (save_dir, model_name, opt.num_epochs - 1))