# run_vr.py
import logging

from torch.autograd import Variable
def train_scsg(epoch, args, train_loader, test_loader, model, train_dataset, optimizer, criterion):
    logging.info("Train (SCSG version)")
    model.train()
    data_buffer = []
    inner_iters = optimizer.recalibration_interval
    megabatch_size = optimizer.megabatch_size
    optimizer.recalibration_i = 0
    logged = False

    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        # Outer phase: accumulate this minibatch's gradient into the
        # megabatch gradient at the current anchor point.
        def outer_closure():
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            return loss

        loss = optimizer.step_outer_part(closure=outer_closure, idx=len(data_buffer))
        data_buffer.append((data, target))
        # When the data buffer is full, run the inner variance-reduced steps
        # on the buffered minibatches.
        if len(data_buffer) == megabatch_size:
            for inner_i in range(inner_iters):
                data, target = data_buffer[inner_i]

                # Safe despite Python's late-binding closures: eval_closure is
                # invoked immediately below, before data/target are rebound.
                def eval_closure():
                    optimizer.zero_grad()
                    output = model(data)
                    loss = criterion(output, target)
                    loss.backward()
                    return loss

                optimizer.step_inner_part(closure=eval_closure, idx=inner_i)

            data_buffer = []
            optimizer.recalibration_i = 0
            # Run the SCSG diagnostics once per epoch, right after a
            # recalibration has completed.
            if not logged and args.log_diagnostics and epoch >= args.vr_from_epoch:
                scsg_diagnostics(epoch, args, train_loader, optimizer, model, criterion)
                logged = True

        if batch_idx % args.log_interval == 0:
            logging.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    # End-of-epoch housekeeping: reorder the sampler and re-apply dataset
    # transforms when the model/dataset support them.
    if hasattr(model, "sampler") and hasattr(model.sampler, "reorder"):
        model.sampler.reorder()

    if hasattr(train_dataset, "retransform"):
        logging.info("retransform")
        train_dataset.retransform()
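

# The loop above only makes sense given an optimizer that exposes
# recalibration_interval, megabatch_size, recalibration_i, step_outer_part,
# and step_inner_part. Below is a minimal sketch of such an SCSG-style
# optimizer, written only to illustrate the assumed protocol: the outer
# phase averages minibatch gradients at the anchor point w0 into a control
# variate mu, and the inner phase takes SVRG-style corrected steps
# g_i(w) - g_i(w0) + mean(mu). The class name, the bookkeeping, and the
# plain SGD update are illustrative assumptions, not this repo's actual
# implementation. Note the loop also implicitly assumes
# recalibration_interval <= megabatch_size, since it indexes data_buffer
# with inner_i.
import torch
from torch.optim.optimizer import Optimizer


class SCSGSketch(Optimizer):
    def __init__(self, params, lr=0.1, megabatch_size=10, recalibration_interval=10):
        super().__init__(params, dict(lr=lr))
        self.megabatch_size = megabatch_size
        self.recalibration_interval = recalibration_interval
        self.recalibration_i = 0
        self.mu = None          # running megabatch gradient (control variate)
        self.anchor_grads = {}  # per-minibatch gradients at the anchor point

    def _flat_grads(self):
        return [p.grad.detach().clone()
                for group in self.param_groups
                for p in group['params'] if p.grad is not None]

    def step_outer_part(self, closure, idx):
        # Outer phase: the weights have not moved yet, so these gradients
        # are evaluated at the anchor point w0. Average them into mu and
        # store them so the inner phase can subtract g_i(w0).
        if self.recalibration_i == 0:
            self.mu, self.anchor_grads = None, {}
        loss = closure()
        grads = self._flat_grads()
        self.anchor_grads[idx] = grads
        if self.mu is None:
            self.mu = [g.clone() for g in grads]
        else:
            for m, g in zip(self.mu, grads):
                m.add_(g)
        self.recalibration_i += 1
        return loss

    def step_inner_part(self, closure, idx):
        # Inner phase: variance-reduced update g_i(w) - g_i(w0) + mean(mu),
        # applied as a plain SGD step.
        loss = closure()
        anchor = self.anchor_grads[idx]
        n = max(self.recalibration_i, 1)
        i = 0
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                vr_grad = p.grad.detach() - anchor[i] + self.mu[i] / n
                p.data.add_(vr_grad, alpha=-group['lr'])
                i += 1
        return loss


# Usage would mirror train_scsg above, e.g. (hypothetical values):
#   optimizer = SCSGSketch(model.parameters(), lr=0.1,
#                          megabatch_size=10, recalibration_interval=10)
#   train_scsg(epoch, args, train_loader, test_loader, model,
#              train_dataset, optimizer, criterion)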