in scsg.py [0:0]
def step_inner_part(self, closure, idx):
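    """Perform one inner SCSG/SVRG-style update for minibatch ``idx``.

    The step direction is the control variate gk - gi + gavg, where gk is the
    minibatch gradient at the current iterate xk, gi is the same minibatch's
    gradient at the snapshot point tilde_x (recomputed or read from gktbl),
    and gavg is the megabatch gradient estimate stored during recalibration.
    Running mean/variance/covariance statistics of the gradients are also
    maintained, and a variance-reduction diagnostic is logged every 20
    batches.  Returns the loss from the closure evaluated at xk.
    """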
    # Sanity check: the recalibration pass must have covered the full megabatch.
    if self.recalibration_i != self.megabatch_size:
        raise Exception("bad self.recalibration_i: {}".format(self.recalibration_i))
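
    # Two ways to obtain the snapshot gradient gi for this minibatch: either
    # re-evaluate the closure at the snapshot point tilde_x (recompute_version),
    # or read the gradient cached in gktbl during the recalibration pass.  The
    # recompute path saves and restores the running statistics (presumably the
    # batch-norm running averages) so the extra forward pass does not perturb them.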
    if self.recompute_version:
        self.store_running_mean()

        ## Store current xk, replace with x_tilde
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                xk = param_state['xk']
                xk.zero_().add_(p.data)
                p.data.zero_().add_(param_state['tilde_x'])

        ## Compute gradient at x_tilde
        closure()

        ## Store x_tilde gradient in gi, and revert to xk
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                xk = param_state['xk']
                param_state['gi'].zero_().add_(p.grad.data)
                p.data.zero_().add_(xk)

        self.restore_running_mean()
        # Debugging only: verify that the recomputed gi matches the cached gktbl entry.
        if False:
            for group in self.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    gi = param_state['gi']
                    gi_tbl = param_state['gktbl'][idx, :]
                    if torch.norm(gi - gi_tbl) > 1e-6:
                        print("difference: {}".format(torch.norm(gi - gi_tbl)))
                        pdb.set_trace()
    else:
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['gi'] = param_state['gktbl'][idx, :]
    ## Compute gradient at xk
    loss = closure()
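
    # Form the variance-reduced direction for each parameter and take the step.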
    for group in self.param_groups:
        momentum = group['momentum']
        weight_decay = group['weight_decay']
        learning_rate = group['lr']

        for p in group['params']:
            gk = p.grad.data

            param_state = self.state[p]
            gi = param_state['gi']
            gavg = param_state['gavg']

            if momentum != 0:
                buf = param_state['momentum_buffer']
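
            # SCSG/SVRG control variate: correcting the minibatch gradient at xk
            # with the same minibatch's snapshot gradient gi and the megabatch
            # average gavg leaves the gradient estimate unbiased while cancelling
            # much of its variance.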
            if self.epoch < self.vr_from_epoch:
                vr_gradient = gk.clone()  # Plain SGD steps until variance reduction starts.
            else:
                vr_gradient = gk.clone().sub_(gi).add_(gavg)

                # Track the running mean, variance and covariance of the gradients.
                grad_running_mean = param_state['grad_running_mean']
                grad_running_var = param_state['grad_running_var']
                grad_running_cov = param_state['grad_running_cov']

                cov_update = (gk - grad_running_mean) * (gi - gavg)
                grad_running_cov.mul_(self.running_interp).add_(1 - self.running_interp, cov_update)

                # Updating the variance with the deltas from before and after the
                # mean update (Welford-style) keeps the running estimate stable.
                delta1 = gk - grad_running_mean
                grad_running_mean.mul_(self.running_interp).add_(1 - self.running_interp, gk)
                delta2 = gk - grad_running_mean
                var_update = delta1 * delta2
                grad_running_var.mul_(self.running_interp).add_(1 - self.running_interp, var_update)

            if weight_decay != 0:
                vr_gradient.add_(weight_decay, p.data)

            if momentum != 0:
                dampening = 0.0
                vr_gradient = buf.mul_(momentum).add_(1 - dampening, vr_gradient)

            # Take step.
            p.data.add_(-learning_rate, vr_gradient)

    # Track the number of minibatches seen.
    self.batches_processed += 1

    if self.batches_processed % 20 == 0 and self.batches_processed > 0:
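        # Periodic diagnostic: treating m2/megabatch_size as an estimate of
        # Var[gi], the variance of the VR gradient is estimated as
        # Var[gk] + Var[gi] - 2*Cov[gk, gi] and reported relative to the plain
        # SGD gradient variance Var[gk].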
        running_cov_acum = 0.0
        m2_acum = 0.0
        var_acum = 0.0

        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                grad_running_cov = param_state['grad_running_cov']
                grad_running_var = param_state['grad_running_var']
                m2 = param_state['m2']
                running_cov_acum += grad_running_cov.sum()
                var_acum += grad_running_var.sum()

                # m2 is stored unnormalized; divide by the megabatch size to
                # get a variance estimate for the snapshot gradients.
                m2_norm = m2.div(self.megabatch_size)
                m2_acum += m2_norm.sum()

        if m2_acum > 0:
            cov_var_ratio = running_cov_acum / m2_acum
            vr_variance = var_acum + m2_acum - 2 * running_cov_acum
            vr_ratio = vr_variance / var_acum
            corr_coef = running_cov_acum / math.sqrt(var_acum * m2_acum)
            logging.info("VR RATIO: {:.3f}. Raw cov/var: {:.3f}, correlation coef: {:.3f}. Var: {:.3f} m2: {:.3f} cov: {:.3f}".format(
                vr_ratio, cov_var_ratio, corr_coef, var_acum, m2_acum, running_cov_acum))

    return loss