in scsg.py
def initialize(self):
    for group in self.param_groups:
        momentum = group['momentum']
        for p in group['params']:
            param_state = self.state[p]
            if 'gavg' not in param_state:
                # Zero-initialized accumulators: the megabatch gradient
                # average (plus a debug copy), the full gradient, and
                # the current gradient gi.
                param_state['gavg'] = p.data.clone().zero_()
                param_state['gavg_debug'] = p.data.clone().zero_()
                param_state['full_grad'] = p.data.clone().zero_()
                param_state['gi'] = p.data.clone().zero_()
                if not self.recompute_version:
                    # Table with one row per megabatch example, used when
                    # snapshot gradients are stored rather than recomputed.
                    gsize = p.data.size()
                    gtbl_size = torch.Size([self.megabatch_size] + list(gsize))
                    # Allocate the table alongside the parameter.
                    param_state['gktbl'] = torch.zeros(gtbl_size, device=p.data.device)
                # m2 is the running gradient variance accumulator.
                param_state['m2'] = p.data.clone().zero_()
                param_state['grad_running_cov'] = p.data.clone().zero_()
                param_state['grad_running_var'] = p.data.clone().zero_()
                param_state['grad_running_mean'] = p.data.clone().zero_()
            if momentum != 0 and 'momentum_buffer' not in param_state:
                param_state['momentum_buffer'] = p.data.clone().zero_()
            if 'tilde_x' not in param_state:
                # Snapshot point tilde_x and current iterate xk, both
                # starting at the current parameter values.
                param_state['tilde_x'] = p.data.clone()
                param_state['xk'] = p.data.clone()

    # Snapshot batch norm's running_mean/running_var activation statistics.
    state = self.model.state_dict()
    for skey in state.keys():
        if skey.endswith(".running_mean") or skey.endswith(".running_var"):
            self.running_tmp[skey] = state[skey].clone()
    logging.info("Tracked batch norm buffers: {}".format(list(self.running_tmp.keys())))
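For orientation, here is a minimal sketch, not the repository's actual step method, of how an SVRG/SCSG-style update consumes these buffers: the gradient at the current iterate is corrected by the same minibatch's gradient at the snapshot tilde_x plus the megabatch average. The `scsg_step` helper and the `gi_tilde` state key below are hypothetical names for illustration only.

def scsg_step(optimizer):
    # Hypothetical sketch: apply the variance-reduced direction
    # g_i(xk) - g_i(tilde_x) + gavg to every parameter. Assumes 'gi'
    # holds the current-point gradient and 'gi_tilde' (a hypothetical
    # key) the snapshot-point gradient for the same minibatch.
    for group in optimizer.param_groups:
        lr = group['lr']
        for p in group['params']:
            st = optimizer.state[p]
            vr = st['gi'] - st['gi_tilde'] + st['gavg']
            p.data.add_(vr, alpha=-lr)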
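The gktbl buffer gives each megabatch example a row for caching its gradient at tilde_x, which is what lets the non-recompute path (`recompute_version == False`) skip a second forward/backward pass at the snapshot point. A hedged sketch of the implied bookkeeping; `store_grad` and `lookup_grad` are hypothetical helpers, and the row-per-example indexing is an assumption from the table's shape:

def store_grad(param_state, i, grad):
    # Cache example i's snapshot gradient in row i of the table.
    param_state['gktbl'][i].copy_(grad)

def lookup_grad(param_state, i):
    # Retrieve the cached gradient instead of recomputing it.
    return param_state['gktbl'][i]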
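The running_tmp copies of batch norm's running_mean/running_var buffers support a save/restore pattern, presumably so that extra forward passes (such as evaluating the megabatch gradient at tilde_x) do not perturb the tracked statistics. A sketch under that assumption; both method names are hypothetical, and it relies on state_dict() returning the model's live buffer tensors:

def save_bn_stats(self):
    # Copy the model's current batch norm statistics into running_tmp.
    state = self.model.state_dict()
    for skey in self.running_tmp:
        self.running_tmp[skey].copy_(state[skey])

def restore_bn_stats(self):
    # Write the saved statistics back into the model's buffers.
    state = self.model.state_dict()
    for skey in self.running_tmp:
        state[skey].copy_(self.running_tmp[skey])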