in torch_svrg.py [0:0]
def logging_pass_end(self, batch_idx):
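    """Compute SGD vs. SVRG gradient-estimate variance diagnostics.

    Called at the end of a logging pass: measures the distance of the current
    iterate from the snapshot point and, per mini-batch, the variance of the
    plain gradient step versus the variance-reduced step.
    """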
    m = self.nbatches
    logging.info("logging diagnostics computation")
    gradient_sqs = []
    vr_step_sqs = []
    forth_sqs = []
    dist_sq_acum = 0.0
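
    # Accumulate the squared distance between the current iterate x_k and the
    # SVRG snapshot point x-tilde, summed over all parameter tensors.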
    for group in self.param_groups:
        for p in group['params']:
            param_state = self.state[p]
            tilde_x = param_state['tilde_x']
            iterate_diff = p.data - tilde_x
            dist_sq_acum += iterate_diff.norm()**2  # torch.dot(iterate_diff, iterate_diff)
    dist = math.sqrt(dist_sq_acum)
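
    # For each mini-batch, compare the plain stochastic gradient step with the
    # SVRG variance-reduced step, both evaluated at the current iterate.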
    for batch_id in range(m):
        grad_norm_acum = 0.0
        vr_norm_acum = 0.0
        forth_acum = 0.0
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                gktbl = param_state['gktbl']
                gavg = param_state['gavg'].type_as(p.data).cpu()
                gi = gktbl[batch_id, :].type_as(p.data).cpu()
                # Logging versions are evaluated at the current iterate x_k,
                # whereas gavg/gktbl were computed at the snapshot x-tilde.
                logging_gktbl = param_state['logging_gktbl']
                logging_gavg = param_state['logging_gavg'].type_as(p.data).cpu()
                logging_gi = logging_gktbl[batch_id, :].type_as(p.data).cpu()
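                # SVRG estimate: current-batch gradient minus the stale batch
                # gradient plus the snapshot full gradient. Both the SVRG and
                # plain estimates are centred by subtracting logging_gavg, the
                # full gradient at the current iterate.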
                vr_step = (logging_gi - gi + gavg) - logging_gavg
                gi_step = logging_gi - logging_gavg
                grad_norm_acum += gi_step.pow(2.0).sum().item()
                vr_norm_acum += vr_step.pow(2.0).sum().item()
                forth_acum += gi_step.pow(2.0).sum().item()
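        # Per-batch squared step norms; forth_acum equals grad_norm_acum, and
        # its square feeds the fourth-moment (standard error) estimate below.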
        gradient_sqs.append(grad_norm_acum)
        vr_step_sqs.append(vr_norm_acum)
        forth_sqs.append(forth_acum**2)

    # Compute variance numbers
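    # The fourth moment of the per-batch squared gradient norm gives a
    # standard error for the gradient-variance estimate logged below.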
    gradient_variance = sum(gradient_sqs)/m
    fourth_moment = sum(forth_sqs)/m - gradient_variance**2
    vr_step_variance = sum(vr_step_sqs)/m
    logging.info("gradient variance: {} vr: {}, ratio vr/g: {}".format(
        gradient_variance, vr_step_variance, vr_step_variance/gradient_variance))
    logging.info(f"fourth moment: {fourth_moment} relative std: {math.sqrt(fourth_moment)/gradient_variance} rel SE: {math.sqrt(fourth_moment/m)/gradient_variance}")
    logging.info("self.logging_evals: {}".format(self.logging_evals))
    self.gradient_variances.append(gradient_variance)
    self.vr_step_variances.append(vr_step_variance)
    self.batch_indices.append(batch_idx)
    self.iterate_distances.append(dist)