in diagnostics.py [0:0]
def online_svrg_diagnostics(epoch, batch_idx, args, train_loader, optimizer, model, criterion):
    """Run the periodic SVRG diagnostic passes of ``optimizer`` over ``train_loader``.

    The function is a no-op unless all of the following hold:

    * ``epoch`` is 1 or a multiple of 10,
    * ``optimizer.epoch >= optimizer.vr_from_epoch``,
    * ``batch_idx == 0``.

    When active, it performs these calls in order:

    1. ``optimizer.logging_pass_start()``.
    2. ``optimizer.snapshot_pass(i, closure=...)`` for successive minibatches,
       stopping once ``optimizer.megabatch_size`` minibatches have been
       processed (the whole loader is used if that count is never reached).
    3. For each ``interval`` in ``range(optimizer.recalibration_interval)``:
       ``optimizer.full_grad_init()``, then ``optimizer.full_grad_calc`` on
       every minibatch, then ``optimizer.logging_pass`` on every minibatch,
       and finally one ``optimizer.step`` using the last minibatch of the
       logging pass.

    Every closure handed to the optimizer zeroes the gradients, runs
    ``model`` on the minibatch, evaluates ``criterion``, back-propagates and
    returns the loss.

    Args:
        epoch: Current epoch number, used only for the gating condition.
        batch_idx: Index of the current training minibatch; diagnostics only
            run when it is 0.
        args: Object with a ``cuda`` attribute; when truthy, each minibatch is
            moved with ``.cuda()`` before use.
        train_loader: Iterable of ``(data, target)`` minibatches. It is
            iterated several times, so it must be re-iterable. It is assumed
            to yield at least one minibatch; with an empty loader the final
            ``optimizer.step`` has no minibatch to use and raises.
        optimizer: Optimizer exposing the attributes and methods listed above,
            plus ``zero_grad()`` and ``recalibration_i`` (the latter is only
            logged).
        model: Callable mapping ``data`` to an output.
        criterion: Callable mapping ``(output, target)`` to a loss that
            supports ``.backward()``.

    Returns:
        None.
    """
    is_diagnostic_epoch = epoch == 1 or (epoch % 10) == 0
    if not (is_diagnostic_epoch
            and optimizer.epoch >= optimizer.vr_from_epoch
            and batch_idx == 0):
        return

    def minibatch_closures():
        """Yield ``(index, closure)`` for each minibatch of ``train_loader``."""
        for idx, (data, target) in enumerate(train_loader):
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)

            def eval_closure():
                # Forward/backward on the current minibatch; returns the loss.
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                return loss

            yield idx, eval_closure

    mega_batch_size = optimizer.megabatch_size
    recalibration_interval = optimizer.recalibration_interval

    optimizer.logging_pass_start()

    # Compute the snapshot from the first `mega_batch_size` minibatches.
    for inner_batch_idx, eval_closure in minibatch_closures():
        optimizer.snapshot_pass(inner_batch_idx, closure=eval_closure)
        # inner_batch_idx is zero-based, so +1 is the number processed so far.
        if inner_batch_idx + 1 == mega_batch_size:
            break
    logging.info("Snapshot computed")

    for interval in range(recalibration_interval):
        logging.info("Interval: {}, recal_i: {}".format(interval, optimizer.recalibration_i))

        optimizer.full_grad_init()

        # Do a full gradient calculation over every minibatch.
        for inner_batch_idx, eval_closure in minibatch_closures():
            optimizer.full_grad_calc(inner_batch_idx, closure=eval_closure)
        logging.info("Full grad calculation finished")

        # Statistics logging pass over every minibatch.
        for inner_batch_idx, eval_closure in minibatch_closures():
            optimizer.logging_pass(interval, inner_batch_idx, closure=eval_closure)
        logging.info("Logging pass finished")

        # Take a single step at the end to progress in the interval,
        # using whatever minibatch was last in the stats logging pass.
        optimizer.step(inner_batch_idx, closure=eval_closure)