in gossip_sgd.py
def train(model, criterion, optimizer, batch_meter, data_meter, nn_meter,
          loader, epoch, itr, begin_time, num_itr_ignore):
    losses = Meter(ptag='Loss')
    top1 = Meter(ptag='Prec@1')  # Meter tracks val / running avg (see sketch below)
    top5 = Meter(ptag='Prec@5')

    # switch to train mode
    model.train()

    # spoof the sampler to continue from a checkpoint w/o loading the data
    # all over again: advance the sampler iterator past the `itr` batches
    # already consumed before preemption
    _train_loader = iter(loader)
    for i in range(itr):
        try:
            next(_train_loader.sample_iter)
        except Exception:
            # finished the epoch but was preempted before state was updated
            log.info('Loader spoof error attempt {}/{}'.format(i, len(loader)))
            return
    log.debug('Training (epoch {})'.format(epoch))
    batch_time = time.time()
    for i, (batch, target) in enumerate(_train_loader, start=itr):
        target = target.cuda(non_blocking=True)
        # create a one-hot vector from the target (1000 ImageNet classes)
        kl_target = torch.zeros(target.shape[0], 1000, device='cuda').scatter_(
            1, target.view(-1, 1), 1)
        if num_itr_ignore == 0:  # warm-up iterations are excluded from timing
            data_meter.update(time.time() - batch_time)
        # ----------------------------------------------------------- #
        # Forward/Backward pass
        # ----------------------------------------------------------- #
        nn_time = time.time()
        output = model(batch)
        loss = criterion(output, kl_target)
        loss.backward()
        if i % 100 == 0:
            update_learning_rate(optimizer, epoch, itr=i,
                                 itr_per_epoch=len(loader))
        optimizer.step()  # optimization update
        optimizer.zero_grad()
        if not args.overlap and not args.all_reduce:
            log.debug('Transferring params')
            model.transfer_params()
        if num_itr_ignore == 0:
            nn_meter.update(time.time() - nn_time)
        # ----------------------------------------------------------- #
        if num_itr_ignore == 0:
            batch_meter.update(time.time() - batch_time)
        batch_time = time.time()
        log_time = time.time()
        # measure accuracy and record loss (accuracy helper sketched below)
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), batch.size(0))
        top1.update(prec1.item(), batch.size(0))
        top5.update(prec5.item(), batch.size(0))
        if i % args.print_freq == 0:
            with open(args.out_fname, 'a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{loss.val:.4f},{loss.avg:.4f},'
                      '{top1.val:.3f},{top1.avg:.3f},'
                      '{top5.val:.3f},{top5.avg:.3f},-1'
                      .format(ep=epoch, itr=i,
                              bt=batch_meter,
                              dt=data_meter, nt=nn_meter,
                              loss=losses, top1=top1,
                              top5=top5), file=f)
        if num_itr_ignore > 0:
            num_itr_ignore -= 1
        log_time = time.time() - log_time
        log.debug(log_time)
        if (args.num_iterations_per_training_epoch != -1 and
                i + 1 == args.num_iterations_per_training_epoch):
            break
    # write one final log line at the end of the epoch
    with open(args.out_fname, 'a') as f:
        print('{ep},{itr},{bt},{nt},{dt},'
              '{loss.val:.4f},{loss.avg:.4f},'
              '{top1.val:.3f},{top1.avg:.3f},'
              '{top5.val:.3f},{top5.avg:.3f},-1'
              .format(ep=epoch, itr=i,
                      bt=batch_meter,
                      dt=data_meter, nt=nn_meter,
                      loss=losses, top1=top1,
                      top5=top5), file=f)
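
Meter is defined elsewhere in the repo. A minimal sketch of the behavior train() relies on (a .val / .avg pair updated with a weighted count, plus a __str__ consumed by the {bt}/{nt}/{dt} format fields) might look like this; the field names and the CSV rendering are assumptions, not the repo's actual implementation:

class Meter(object):
    """Running-average meter (hypothetical sketch of the repo's Meter)."""

    def __init__(self, ptag=''):
        self.ptag = ptag  # tag used when printing, e.g. 'Loss'
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of all observed values
        self.count = 0    # total weight observed so far
        self.avg = 0.0    # running average: sum / count

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # assumed CSV-friendly rendering for the log lines above
        return '{:.3f},{:.3f}'.format(self.val, self.avg)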
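
The sampler-spoofing loop in train() reaches into a private iterator attribute (sample_iter), which ties it to a particular torch.utils.data version. A version-agnostic sketch of the same fast-forward idea, which simply consumes and discards already-seen batches, is shown below; it is slower because the skipped batches are actually loaded, but it depends on no DataLoader internals:

import itertools

def fast_forward(loader, itr):
    """Advance a fresh DataLoader iterator past `itr` already-seen batches."""
    it = iter(loader)
    # the itertools 'consume' recipe: islice discards the first `itr` items
    next(itertools.islice(it, itr, itr), None)
    return it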
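
For reference, the scatter_-based one-hot construction in train() is equivalent to torch.nn.functional.one_hot followed by a cast to float; a small CPU-only check:

import torch
import torch.nn.functional as F

target = torch.randint(0, 1000, (8,))  # fake labels for 1000 classes
one_hot_scatter = torch.zeros(target.shape[0], 1000).scatter_(
    1, target.view(-1, 1), 1)
one_hot_builtin = F.one_hot(target, num_classes=1000).float()
assert torch.equal(one_hot_scatter, one_hot_builtin)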
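
update_learning_rate is likewise defined elsewhere in the repo; the warmup length and decay milestones below are placeholders, not the repo's actual schedule. A sketch of a typical large-batch schedule (linear warmup over the first epochs, then step decay) consistent with the call signature used above:

def update_learning_rate(optimizer, epoch, itr=None, itr_per_epoch=None,
                         warmup_epochs=5):
    """Linear warmup then step decay (hypothetical schedule)."""
    base_lr = args.lr
    if epoch < warmup_epochs and itr is not None:
        # ramp linearly from ~0 up to base_lr over the warmup period
        progress = (epoch * itr_per_epoch + itr + 1) / float(
            warmup_epochs * itr_per_epoch)
        lr = base_lr * progress
    else:
        # decay by 10x at fixed epoch milestones (placeholders)
        lr = base_lr * (0.1 ** sum(epoch >= m for m in (30, 60, 80)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr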
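
accuracy is also imported from elsewhere; given the call accuracy(output, target, topk=(1, 5)), it is presumably along the lines of the standard top-k helper from the PyTorch ImageNet example, sketched here:

import torch

def accuracy(output, target, topk=(1,)):
    """Compute top-k precision (in %) for each requested k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # indices of the maxk highest-scoring classes, shape (maxk, B)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res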