in run_vr.py [0:0]
def recalibrate(epoch, args, train_loader, test_loader, model, train_dataset, optimizer, criterion):
    """Run one recalibration pass of ``optimizer`` over ``train_loader``.

    The model is switched to train mode when ``args.vr_bn_at_recalibration``
    is truthy and to eval mode otherwise. The optimizer's optional
    ``recalibrate_start`` / ``recalibrate_end`` hooks are invoked when the
    optimizer defines them.

    The per-batch loop only runs when
    ``optimizer.epoch >= optimizer.vr_from_epoch`` and ``args.method`` is
    neither ``"online_svrg"`` nor ``"scsg"``. For each batch,
    ``optimizer.recalibrate`` receives the batch index and a closure that
    zeroes the gradients, runs a forward pass, computes the loss, calls
    ``backward()`` and returns the loss.

    Args:
        epoch: Epoch number; only used in the progress log message.
        args: Options object; reads ``vr_bn_at_recalibration``, ``method``,
            ``cuda`` and ``log_interval``.
        train_loader: Iterable of ``(data, target)`` batches; must support
            ``len()`` and expose ``.dataset`` (both used for progress logs).
        test_loader: Unused here.
        model: Callable model exposing ``train()`` and ``eval()``.
        train_dataset: Unused here.
        optimizer: Object exposing ``epoch``, ``vr_from_epoch``,
            ``zero_grad()`` and ``recalibrate(batch_id, closure=...)``.
        criterion: Callable ``criterion(output, target)`` returning the loss.

    Returns:
        None.
    """

    def _fire_optional_hook(hook_name):
        # The start/end hooks are optional on the optimizer.
        if hasattr(optimizer, hook_name):
            getattr(optimizer, hook_name)()

    def _report_progress(step, batch, began):
        # The estimate scales the elapsed time by 100 / percent_done; at 0%
        # there is nothing to extrapolate from, so "unknown" is reported.
        now = timer()
        percent_done = 100. * step / len(train_loader)
        if percent_done > 0:
            projected_seconds = math.ceil((now - began)*(100/percent_done))
            estimate = str(datetime.timedelta(seconds=projected_seconds))
        else:
            estimate = "unknown"
        logging.info('Recal Epoch: {} [{}/{} ({:.0f}%)] estimate: {}'.format(
            epoch, step * len(batch), len(train_loader.dataset),
            percent_done, estimate))

    switch_mode = model.train if args.vr_bn_at_recalibration else model.eval
    switch_mode()

    logging.info("Recalibration pass starting")
    _fire_optional_hook("recalibrate_start")
    start = timer()

    # Evaluated after the start hook, matching the original statement order.
    pass_enabled = (
        optimizer.epoch >= optimizer.vr_from_epoch
        and args.method not in ("online_svrg", "scsg")
    )
    if pass_enabled:
        for step, (data, target) in enumerate(train_loader):
            if args.cuda:
                data = data.cuda()
                target = target.cuda(non_blocking=True)
            data = Variable(data)
            target = Variable(target)

            def recompute_gradient():
                optimizer.zero_grad()
                loss = criterion(model(data), target)
                loss.backward()
                return loss

            optimizer.recalibrate(step, closure=recompute_gradient)

            if step % args.log_interval == 0:
                _report_progress(step, data, start)

    # NOTE(review): the indentation of this function was lost in the source
    # it was restored from; the end hook is assumed to sit at function level,
    # mirroring the unconditional start hook -- confirm against the caller's
    # expectations.
    _fire_optional_hook("recalibrate_end")
    logging.info("Recalibration finished")