in recompute_svrg.py
def step(self, batch_id, closure):
    """Performs a single optimization step.

    Arguments:
        batch_id (int): index of the current minibatch.
        closure (callable): a closure that reevaluates the model and
            returns the loss. Once variance reduction is active it is
            evaluated twice per step, so it should recompute gradients
            on every call.
    """
    if self.epoch >= self.vr_from_epoch:
        self.store_running_mean()

        ## Store current xk, replace with x_tilde
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                xk = param_state['xk']
                xk.zero_().add_(p.data)
                p.data.zero_().add_(param_state['tilde_x'])

        # The default is vr_bn_at_recalibration=True, so this branch
        # normally doesn't fire.
        if not self.vr_bn_at_recalibration:
            self.model.eval()  # put batch-norm layers in eval mode

        ## Compute gradient at x_tilde
        closure()
        ## Store x_tilde gradient in gi, and revert to xk
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                xk = param_state['xk']
                gi = param_state['gi']
                gi.zero_().add_(p.grad.data)
                p.data.zero_().add_(xk)

        # Restore the batch-norm running statistics saved above.
        self.restore_running_mean()

    ## Compute gradient at xk
    loss = closure()
    for group in self.param_groups:
        momentum = group['momentum']
        weight_decay = group['weight_decay']
        learning_rate = group['lr']

        for p in group['params']:
            gk = p.grad.data

            param_state = self.state[p]
            gi = param_state['gi']
            gavg = param_state['gavg']
            if momentum != 0:
                buf = param_state['momentum_buffer']

            if self.epoch >= self.vr_from_epoch:
                # SVRG estimator: gradient at xk, minus the gradient at
                # x_tilde on the same minibatch, plus the stored average
                # gradient.
                vr_gradient = gk.clone().sub_(gi).add_(gavg.type_as(gk))
            else:
                vr_gradient = gk.clone()  # Just do SGD steps

            if weight_decay != 0:
                vr_gradient.add_(p.data, alpha=weight_decay)

            if momentum != 0:
                dampening = 0.0
                vr_gradient = buf.mul_(momentum).add_(vr_gradient, alpha=1 - dampening)

            # Take step.
            p.data.add_(vr_gradient, alpha=-learning_rate)

    # Track the number of minibatches seen.
    self.batches_processed += 1

    return loss
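
A minimal sketch of how this step might be driven from a training loop; the `optimizer`, `model`, `criterion`, and `train_loader` names are assumptions for illustration, not part of this file:

# Hypothetical driver loop (assumed names: model, criterion, train_loader,
# optimizer). The closure zeroes and recomputes gradients on every call,
# since step() may invoke it twice.
for batch_id, (data, target) in enumerate(train_loader):
    def closure():
        model.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        return loss

    loss = optimizer.step(batch_id, closure)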