in recompute_svrg.py [0:0]
def initialize(self):
    """Create missing per-parameter buffers and snapshot batch-norm statistics.

    For every parameter ``p`` in ``self.param_groups`` the following entries
    of ``self.state[p]`` are created if they are not already present:

    * ``'gavg'``, ``'gi'``, ``'gi_debug'`` -- zero tensors shaped like
      ``p.data`` (``'gavg'`` is stored in double precision).  All three are
      created together, keyed on the absence of ``'gavg'``.
    * ``'momentum_buffer'`` -- zero tensor shaped like ``p.data``; only when
      the group's ``'momentum'`` is non-zero.
    * ``'tilde_x'``, ``'xk'`` -- independent copies of ``p.data``.  Both are
      created together, keyed on the absence of ``'tilde_x'``.

    Existing entries are left untouched, so calling this method again does
    not reset buffers that were already initialised.

    Afterwards, every entry of ``self.model.state_dict()`` whose key ends in
    ``.running_mean`` or ``.running_var`` is cloned into
    ``self.running_tmp`` (overwriting any previous value for that key).
    """
    for group in self.param_groups:
        # The momentum setting is per group, so evaluate it once rather than
        # once per parameter.
        use_momentum = group['momentum'] != 0

        for p in group['params']:
            param_state = self.state[p]

            if 'gavg' not in param_state:
                # NOTE(review): 'gavg' is the only double-precision buffer --
                # presumably to limit rounding error while accumulating;
                # confirm against the code that updates it.
                param_state['gavg'] = p.data.double().clone().zero_()
                param_state['gi'] = p.data.clone().zero_()
                param_state['gi_debug'] = p.data.clone().zero_()

            if use_momentum and 'momentum_buffer' not in param_state:
                param_state['momentum_buffer'] = p.data.clone().zero_()

            if 'tilde_x' not in param_state:
                param_state['tilde_x'] = p.data.clone()
                param_state['xk'] = p.data.clone()

    # Batch norm's activation running_mean/var variables
    running_suffixes = (".running_mean", ".running_var")
    for skey, value in self.model.state_dict().items():
        if skey.endswith(running_suffixes):
            # clone() so the stored snapshot does not alias the model's tensor.
            self.running_tmp[skey] = value.clone()