in gossip_sgd_adpsgd.py
def train(model, criterion, optimizer, batch_meter, data_meter, nn_meter,
          loader, epoch, itr, begin_time):
    losses = Meter(ptag='Loss')
    top1 = Meter(ptag='Prec@1')
    top5 = Meter(ptag='Prec@5')

    # switch to train mode
    model.train()
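
    # NOTE: older torch.utils.data loader iterators expose the batch-sampler
    # iterator as `sample_iter`; advancing it consumes sampler indices without
    # loading the corresponding data, which is what the spoof below relies on.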
    # spoof sampler to continue from checkpoint w/o loading data all over again
    _train_loader = loader.__iter__()
    for i in range(itr):
        try:
            next(_train_loader.sample_iter)
        except Exception:
            # finished epoch but preempted before state was updated
            log.info('Loader spoof error attempt {}/{}'.format(i, len(loader)))
            return

    log.debug('Training (epoch {})'.format(epoch))

    model.enable_gossip()
    batch_time = time.time()
    for i, (batch, target) in enumerate(_train_loader, start=itr):
        target = target.cuda(non_blocking=True)
        # time spent waiting on the data loader since the last iteration
        data_meter.update(time.time() - batch_time)

        # ----------------------------------------------------------- #
        # Forward/Backward pass
        # ----------------------------------------------------------- #
        nn_time = time.time()
        output = model(batch)
        loss = criterion(output, target)
        bilat_freq = 100
        if i == 0:
            update_global_iteration_counter(itr=1,
                                            itr_per_epoch=len(loader))
            update_bilat_learning_rate(model, itr_per_epoch=len(loader))
        elif (i + args.rank) % bilat_freq == 0:
            # offsetting by rank staggers the global counter/learning-rate
            # refresh so peers do not all hit it on the same local iteration
            update_global_iteration_counter(itr=bilat_freq,
                                            itr_per_epoch=len(loader))
            update_bilat_learning_rate(model, itr_per_epoch=len(loader))
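        # local SGD step; with gossip enabled above, parameter mixing with
        # peers is presumed to happen inside the model wrapper and to overlap
        # with this compute (the asynchronous part of AD-PSGD)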
        loss.backward()
        update_learning_rate(optimizer, epoch, itr=i,
                             itr_per_epoch=len(loader))
        optimizer.step()  # optimization update
        optimizer.zero_grad()
        nn_meter.update(time.time() - nn_time)
        # ----------------------------------------------------------- #

        batch_meter.update(time.time() - batch_time)
        batch_time = time.time()

        log_time = time.time()
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), batch.size(0))
        top1.update(prec1.item(), batch.size(0))
        top5.update(prec5.item(), batch.size(0))
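        # periodic in-loop logging keyed to the cluster-wide (global)
        # epoch/iteration counters; a second row keyed to the local counters
        # is appended once after the loop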
        if i % args.print_freq == 0:
            ep = args.global_epoch
            # NB: rebinds the `itr` argument; safe here since the sampler
            # spoof above no longer needs it
            itr = args.global_itr % (len(loader) * args.world_size)
            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{loss.val:.4f},{loss.avg:.4f},'
                      '{top1.val:.3f},{top1.avg:.3f},'
                      '{top5.val:.3f},{top5.avg:.3f},-1'
                      .format(ep=ep, itr=itr,
                              bt=batch_meter,
                              dt=data_meter, nt=nn_meter,
                              loss=losses, top1=top1,
                              top5=top5), file=f)
        log_time = time.time() - log_time
        log.debug(log_time)

    # end-of-epoch log row, keyed to the local epoch/iteration counters
    with open(args.out_fname, '+a') as f:
        print('{ep},{itr},{bt},{nt},{dt},'
              '{loss.val:.4f},{loss.avg:.4f},'
              '{top1.val:.3f},{top1.avg:.3f},'
              '{top5.val:.3f},{top5.avg:.3f},-1'
              .format(ep=epoch, itr=i,
                      bt=batch_meter,
                      dt=data_meter, nt=nn_meter,
                      loss=losses, top1=top1,
                      top5=top5), file=f)
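
The Meter class is not defined in this file. A minimal sketch of the running-average tracker the code above appears to assume, with val/avg matching the {loss.val:.4f}/{loss.avg:.4f} format fields and a __str__ for the meters that are formatted whole ({bt}, {nt}, {dt}); the exact string layout is a guess:

class Meter(object):
    """Running-average tracker (assumed interface; not shown in this file)."""

    def __init__(self, ptag=''):
        self.ptag = ptag
        self.val = 0.0    # most recent value
        self.avg = 0.0    # running (weighted) average
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # `n` weights the update, e.g. by batch size as in train() above
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # meters dropped straight into the CSV row ({bt}, {nt}, {dt});
        # printing both value and average here is an assumption
        return '{:.3f},{:.3f}'.format(self.val, self.avg)

Under this interface, losses.update(loss.item(), batch.size(0)) keeps the average weighted by batch size, so the logged loss.avg is a per-sample mean over the epoch so far.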