def train()

in djl/trainCiFar10.py [0:0]


def train(opt, ctx):
    """Train ``net`` on CIFAR-10, validating and checkpointing after each epoch.

    Args:
        opt: Options object. The attributes read here are ``lr``, ``wd``,
            ``dtype``, ``start_epoch``, ``epochs``, ``lr_factor`` and
            ``log_interval``.
        ctx: Either a single ``mx.Context`` or a collection of contexts. A
            single context is wrapped in a one-element list.

    Returns:
        None. Progress is reported through ``logger`` and, when more than one
        epoch ran, the average epoch time is printed to stdout.

    Note:
        Besides its arguments, this function uses names that are defined
        outside of it: ``net``, ``batch_size``, ``kv``, ``lr_steps``,
        ``metric``, ``logger``, ``test``, ``save_checkpoint`` and
        ``update_learning_rate``.
        NOTE(review): ``metric.get()`` and ``test(...)`` results are indexed
        at positions 0 and 1, so both presumably report at least two
        metrics -- confirm against their definitions.
    """
    devices = [ctx] if isinstance(ctx, mx.Context) else ctx

    loader_workers = 1
    channel_mean = [0.4914, 0.4822, 0.4465]
    channel_std = [0.2023, 0.1994, 0.2010]
    preprocess = transforms.Compose([
        # Transpose the image from height*width*num_channels to
        # num_channels*height*width and map values from [0, 255] to [0, 1].
        transforms.ToTensor(),
        # Normalize the image with mean and standard deviation calculated
        # across all images.
        transforms.Normalize(channel_mean, channel_std),
    ])

    # Training split: shuffled, and the trailing incomplete batch is dropped.
    train_set = gluon.data.vision.CIFAR10(train=True).transform_first(preprocess)
    train_data = gluon.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        last_batch='discard',
        num_workers=loader_workers,
    )

    # Validation split (train=False): kept in order.
    val_set = gluon.data.vision.CIFAR10(train=False).transform_first(preprocess)
    val_data = gluon.data.DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=loader_workers,
    )

    net.collect_params().reset_ctx(devices)
    adam_settings = {
        'learning_rate': opt.lr,
        'wd': opt.wd,
        'multi_precision': False,
    }
    trainer = gluon.Trainer(
        net.collect_params(),
        'adam',
        optimizer_params=adam_settings,
        kvstore=kv,
    )
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    total_time = 0
    num_epochs = 0
    # Passed to save_checkpoint() as a one-element list on every epoch.
    best_acc = [0]
    for epoch in range(opt.start_epoch, opt.epochs):
        trainer = update_learning_rate(opt.lr, trainer, epoch, opt.lr_factor, lr_steps)
        epoch_start = time.time()
        metric.reset()
        batch_start = time.time()
        for step, batch in enumerate(train_data):
            images = batch[0]
            targets = batch[1]
            data = gluon.utils.split_and_load(
                images.astype(opt.dtype), ctx_list=devices, batch_axis=0)
            label = gluon.utils.split_and_load(
                targets.astype(opt.dtype), ctx_list=devices, batch_axis=0)
            outputs = []
            losses = []
            with ag.record():
                for shard, shard_label in zip(data, label):
                    prediction = net(shard)
                    # Collect the losses and run backward only after the
                    # forward pass has been issued on every device, for
                    # better speed on multiple GPUs.
                    losses.append(loss_fn(prediction, shard_label))
                    outputs.append(prediction)
                ag.backward(losses)
            trainer.step(images.shape[0])
            metric.update(label, outputs)

            is_log_step = opt.log_interval and (step + 1) % opt.log_interval == 0
            if is_log_step:
                name, acc = metric.get()
                # batch_start is reset after every batch, so this is the
                # throughput of the most recent batch only.
                elapsed = time.time() - batch_start
                logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f, %s=%f' % (
                    epoch, step, batch_size / elapsed, name[0], acc[0], name[1], acc[1]))
            batch_start = time.time()

        epoch_time = time.time() - epoch_start

        # The first epoch is usually much slower than the subsequent epochs,
        # so it is left out of the running total used for the average.
        if num_epochs > 0:
            total_time += epoch_time
        num_epochs += 1

        name, acc = metric.get()
        logger.info('[Epoch %d] training: %s=%f, %s=%f' % (
            epoch, name[0], acc[0], name[1], acc[1]))
        logger.info('[Epoch %d] time cost: %f' % (epoch, epoch_time))

        name, val_acc = test(devices, val_data)
        logger.info('[Epoch %d] validation: %s=%f, %s=%f' % (
            epoch, name[0], val_acc[0], name[1], val_acc[1]))

        # Hand the first validation score to save_checkpoint(), which decides
        # whether the model gets saved.
        save_checkpoint(epoch, val_acc[0], best_acc)

    if num_epochs > 1:
        # total_time excludes the first epoch, hence the (num_epochs - 1).
        average_epoch_time = float(total_time) / (num_epochs - 1)
        print('Average epoch time: {}'.format(average_epoch_time))