def train_scsg()

in run_vr.py [0:0]


def train_scsg(epoch, args, train_loader, test_loader, model, train_dataset, optimizer, criterion):
    """Train ``model`` for one epoch using the optimizer's two-phase SCSG step protocol.

    Every mini-batch from ``train_loader`` is first handed to
    ``optimizer.step_outer_part`` (with ``idx`` = its position inside the
    current megabatch) and then buffered.  Once ``optimizer.megabatch_size``
    batches are buffered, ``optimizer.step_inner_part`` is called on the first
    ``optimizer.recalibration_interval`` buffered batches.  The buffer is then
    cleared and ``optimizer.recalibration_i`` is reset to 0.

    Args:
        epoch: current epoch number.  It is used in log messages and compared
            with ``args.vr_from_epoch`` to decide whether diagnostics run.
        args: namespace providing ``cuda``, ``log_diagnostics``,
            ``vr_from_epoch`` and ``log_interval``.
        train_loader: iterable of ``(data, target)`` batches.  It must also
            support ``len()`` and expose ``.dataset`` (both used for logging).
        test_loader: not used by this function; kept for signature
            compatibility with the other training routines.
        model: module being trained; ``model.train()`` is called first.  If it
            has ``sampler.reorder``, that is called at the end of the epoch.
        train_dataset: if it has a ``retransform`` method, that method is
            called at the end of the epoch.
        optimizer: must provide ``recalibration_interval``,
            ``megabatch_size``, ``recalibration_i``, ``zero_grad()``,
            ``step_outer_part(closure=, idx=)`` and
            ``step_inner_part(closure=, idx=)``.  ``step_outer_part``'s
            return value is logged via ``.data.item()``.
        criterion: loss function called as ``criterion(output, target)``.

    Raises:
        ValueError: when a full megabatch has been buffered but
            ``recalibration_interval`` exceeds ``megabatch_size``, so there
            are not enough buffered batches for the inner steps.

    Note:
        Batches left in the buffer at the end of the epoch get an outer
        step but no inner steps.  This happens when the number of batches
        is not a multiple of ``megabatch_size``.
        NOTE(review): confirm this is the intended SCSG behaviour.
    """
    logging.info("Train (SCSG version)")
    model.train()

    def make_closure(batch_data, batch_target):
        # Bind the batch explicitly rather than closing over the loop
        # variables.  The closure then stays tied to its own batch even if
        # the optimizer keeps it around past the current iteration.
        def closure():
            optimizer.zero_grad()
            output = model(batch_data)
            loss = criterion(output, batch_target)
            loss.backward()
            return loss
        return closure

    data_buffer = []
    inner_iters = optimizer.recalibration_interval
    megabatch_size = optimizer.megabatch_size
    optimizer.recalibration_i = 0
    logged = False

    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)

        # Outer part: the optimizer sees this batch at its index within the
        # megabatch currently being collected.
        loss = optimizer.step_outer_part(closure=make_closure(data, target), idx=len(data_buffer))
        data_buffer.append((data, target))

        # When the buffer holds a full megabatch, do the actual inner steps.
        if len(data_buffer) == megabatch_size:
            if inner_iters > len(data_buffer):
                raise ValueError(
                    "recalibration_interval ({}) exceeds megabatch_size ({}); "
                    "not enough buffered batches for the inner steps".format(
                        inner_iters, megabatch_size))

            # Use distinct names so the current batch's `data`/`target`
            # (used in the progress log below) are not overwritten.
            for inner_i in range(inner_iters):
                buf_data, buf_target = data_buffer[inner_i]
                optimizer.step_inner_part(closure=make_closure(buf_data, buf_target), idx=inner_i)

            data_buffer = []
            optimizer.recalibration_i = 0

            # Diagnostics run at most once per epoch, right after the first
            # completed megabatch once variance reduction is active.
            if not logged and args.log_diagnostics and epoch >= args.vr_from_epoch:
                scsg_diagnostics(epoch, args, train_loader, optimizer, model, criterion)
                logged = True

        if batch_idx % args.log_interval == 0:
            logging.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data.item()))

    if hasattr(model, "sampler") and hasattr(model.sampler, "reorder"):
        model.sampler.reorder()
    if hasattr(train_dataset, "retransform"):
        logging.info("retransform")
        train_dataset.retransform()