def recalibrate()

in run_vr.py [0:0]


def recalibrate(epoch, args, train_loader, test_loader, model, train_dataset, optimizer, criterion):
    """Run one recalibration pass of the variance-reduced optimizer over ``train_loader``.

    For every batch, a closure that zeroes gradients, runs ``model``, evaluates
    ``criterion`` and back-propagates is handed to
    ``optimizer.recalibrate(batch_index, closure=...)``. The pass over the data
    is skipped entirely when ``optimizer.epoch < optimizer.vr_from_epoch`` or
    when ``args.method`` is ``"online_svrg"`` or ``"scsg"``. The optional
    ``recalibrate_start`` / ``recalibrate_end`` optimizer hooks run either way.

    Args:
        epoch: Epoch number; only used in progress log messages.
        args: Run configuration. Reads ``vr_bn_at_recalibration``, ``method``,
            ``cuda`` and ``log_interval``.
        train_loader: Iterable of ``(data, target)`` batches. Its ``len()`` and
            ``len(train_loader.dataset)`` are read when a progress line is logged.
        test_loader: Unused here.
        model: Switched to train mode when ``args.vr_bn_at_recalibration`` is
            truthy, otherwise to eval mode. The previous mode is not restored.
        train_dataset: Unused here.
        optimizer: Must expose ``epoch``, ``vr_from_epoch``, ``zero_grad()`` and
            ``recalibrate(batch_id, closure=...)``. It may optionally define
            ``recalibrate_start()`` / ``recalibrate_end()``.
        criterion: Callable ``criterion(output, target)`` returning a loss that
            supports ``.backward()``.
    """
    # NOTE(review): the flag name suggests it decides whether batch-norm layers
    # behave as in training during recalibration -- confirm with the arg parser.
    set_mode = model.train if args.vr_bn_at_recalibration else model.eval
    set_mode()
    logging.info("Recalibration pass starting")

    def _run_hook(hook_name):
        # Optional optimizer hooks: only invoked when the optimizer defines them.
        if hasattr(optimizer, hook_name):
            getattr(optimizer, hook_name)()

    _run_hook("recalibrate_start")
    began = timer()

    vr_active = optimizer.epoch >= optimizer.vr_from_epoch
    if vr_active and args.method not in ("online_svrg", "scsg"):
        for step, (inputs, labels) in enumerate(train_loader):
            if args.cuda:
                inputs = inputs.cuda()
                labels = labels.cuda(non_blocking=True)
            inputs = Variable(inputs)
            labels = Variable(labels)

            def closure():
                # Fresh forward/backward on the current batch; what happens to
                # the resulting gradients is up to optimizer.recalibrate.
                optimizer.zero_grad()
                loss = criterion(model(inputs), labels)
                loss.backward()
                return loss

            optimizer.recalibrate(step, closure=closure)

            if step % args.log_interval != 0:
                continue

            now = timer()
            percent_done = 100.0 * step / len(train_loader)
            # Extrapolates the projected total duration of the pass (not the
            # remaining time) from the elapsed time and the fraction done.
            if percent_done > 0:
                eta_seconds = math.ceil((now - began) * (100 / percent_done))
                estimate = str(datetime.timedelta(seconds=eta_seconds))
            else:
                estimate = "unknown"

            logging.info('Recal Epoch: {} [{}/{} ({:.0f}%)] estimate: {}'.format(
                epoch, step * len(inputs), len(train_loader.dataset),
                percent_done, estimate))

    _run_hook("recalibrate_end")
    logging.info("Recalibration finished")