in djl/trainCiFar10.py [0:0]
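# NOTE: this excerpt shows only train(); it assumes the usual module-level setup from
# the rest of trainCiFar10.py, e.g. `mxnet as mx`, `gluon`, `autograd as ag`,
# `transforms`, `time`, plus globals such as `net`, `metric`, `kv`, `logger`,
# `batch_size`, `lr_steps`, `update_learning_rate`, `test`, and `save_checkpoint`.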
def train(opt, ctx):
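    # Accept either a single Context or a list of Contexts; the data-splitting code
    # below always expects a list, so a lone Context is wrapped here.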
    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    num_workers = 1
    transform = transforms.Compose([
        # Transpose the image from height*width*num_channels to num_channels*height*width
        # and map values from [0, 255] to [0, 1]
        transforms.ToTensor(),
        # Normalize the image with mean and standard deviation calculated across all images
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])
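    # Training loader. `batch_size` is assumed to be defined at module level (typically
    # opt.batch_size, scaled by the number of devices); last_batch='discard' drops the
    # final partial batch so every step sees a full batch.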
    train_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=True).transform_first(transform),
        batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

    # Set train=False for validation data
    val_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=False).transform_first(transform),
        batch_size=batch_size, shuffle=False, num_workers=num_workers)
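    # `net` is assumed to be built and initialized elsewhere in the script;
    # reset_ctx() just moves its parameters onto the training device(s).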
    net.collect_params().reset_ctx(ctx)
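    # Adam optimizer wrapped in a Trainer. `kv` is assumed to be a KVStore created at
    # module level (e.g. mx.kv.create(opt.kvstore)) for (multi-)device aggregation.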
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            optimizer_params={'learning_rate': opt.lr,
                                              'wd': opt.wd,
                                              'multi_precision': False},
                            kvstore=kv)
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    total_time = 0
    num_epochs = 0
    best_acc = [0]
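    # Epoch loop. `update_learning_rate` and `lr_steps` are assumed to implement a
    # step-decay schedule defined elsewhere in this file.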
    for epoch in range(opt.start_epoch, opt.epochs):
        trainer = update_learning_rate(opt.lr, trainer, epoch, opt.lr_factor, lr_steps)
        tic = time.time()
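        # `metric` is assumed to be a module-level composite metric holding two
        # sub-metrics, since name[0]/name[1] are reported in the logs below.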
        metric.reset()
        btic = time.time()
        for i, batch in enumerate(train_data):
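            # Slice the batch along axis 0 and copy each slice to one device in ctx.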
            data = gluon.utils.split_and_load(batch[0].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
            outputs = []
            Ls = []
            with ag.record():
                for x, y in zip(data, label):
                    z = net(x)
                    L = loss(z, y)
                    # store the loss and do backward after we have done forward
                    # on all GPUs for better speed on multiple GPUs.
                    Ls.append(L)
                    outputs.append(z)
                ag.backward(Ls)
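            # Passing the full batch size to step() normalizes the accumulated
            # gradients by 1/batch_size before the optimizer update.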
            trainer.step(batch[0].shape[0])
            metric.update(label, outputs)
            if opt.log_interval and not (i+1)%opt.log_interval:
                name, acc = metric.get()
                logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f, %s=%f'%(
                    epoch, i, batch_size/(time.time()-btic), name[0], acc[0], name[1], acc[1]))
            btic = time.time()

        epoch_time = time.time()-tic

        # The first epoch is usually much slower than subsequent epochs,
        # so don't factor it into the average
        if num_epochs > 0:
            total_time = total_time + epoch_time
        num_epochs = num_epochs + 1

        name, acc = metric.get()
        logger.info('[Epoch %d] training: %s=%f, %s=%f'%(epoch, name[0], acc[0], name[1], acc[1]))
        logger.info('[Epoch %d] time cost: %f'%(epoch, epoch_time))
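        # Evaluate on the held-out CIFAR-10 test split; `test` is assumed to be
        # defined elsewhere in this file and to return the same metric names.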
        name, val_acc = test(ctx, val_data)
        logger.info('[Epoch %d] validation: %s=%f, %s=%f'%(epoch, name[0], val_acc[0], name[1], val_acc[1]))

        # Save a checkpoint if the model meets the saving criteria
        save_checkpoint(epoch, val_acc[0], best_acc)

    if num_epochs > 1:
        print('Average epoch time: {}'.format(float(total_time)/(num_epochs - 1)))