def initialize()

in scsg.py


    def initialize(self):
        for group in self.param_groups:
            momentum = group['momentum']
            for p in group['params']:

                param_state = self.state[p]

                if 'gavg' not in param_state:
                    # One buffer per parameter: the mega-batch gradient
                    # average (gavg, plus a gavg_debug checking copy), the
                    # full gradient, and the per-example gradient gi.
                    param_state['gavg'] = torch.zeros_like(p.data)
                    param_state['gavg_debug'] = torch.zeros_like(p.data)
                    param_state['full_grad'] = torch.zeros_like(p.data)
                    param_state['gi'] = torch.zeros_like(p.data)

                    if not self.recompute_version:
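                        # The gradient table keeps one stored gradient per
                        # mega-batch example, costing megabatch_size times
                        # the parameter's memory.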
                        gsize = p.data.size()
                        gtbl_size = torch.Size([self.megabatch_size] + list(gsize))
                        param_state['gktbl'] = torch.zeros(gtbl_size, device=p.data.device)

                    # m2 is the running gradient variance accumulator
                    param_state['m2'] = torch.zeros_like(p.data)

                    # Running estimates of the gradient's covariance,
                    # variance, and mean.
                    param_state['grad_running_cov'] = torch.zeros_like(p.data)
                    param_state['grad_running_var'] = torch.zeros_like(p.data)
                    param_state['grad_running_mean'] = torch.zeros_like(p.data)

                if momentum != 0 and 'momentum_buffer' not in param_state:
                    param_state['momentum_buffer'] = torch.zeros_like(p.data)

                if 'tilde_x' not in param_state:
                    # tilde_x holds the reference (snapshot) copy of the
                    # parameters; xk tracks the current iterate.
                    param_state['tilde_x'] = p.data.clone()
                    param_state['xk'] = p.data.clone()


        state = self.model.state_dict()
        # Snapshot the batch norm running_mean/running_var buffers.
        for skey in state.keys():
            if skey.endswith(".running_mean") or skey.endswith(".running_var"):
                self.running_tmp[skey] = state[skey].clone()

        logging.info("running: %s", list(self.running_tmp.keys()))
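
A minimal usage sketch, assuming the class in scsg.py is named SCSG and that its constructor sets up param_groups, model, megabatch_size, recompute_version, and an empty running_tmp dict; the class name and constructor signature here are illustrative assumptions, not the repository's exact API:

    import torch.nn as nn
    from scsg import SCSG  # hypothetical import path / class name

    model = nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10))

    # Hypothetical constructor; the real signature in scsg.py may differ.
    optimizer = SCSG(model.parameters(), model=model,
                     momentum=0.9, megabatch_size=32)

    # Allocate the per-parameter state buffers and snapshot the batch
    # norm running statistics before the first mega-batch.
    optimizer.initialize()

For reference, the running-stat filter at the end of initialize() matches keys like these on a toy model (standard PyTorch batch norm naming):

    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    bn_keys = [k for k in model.state_dict()
               if k.endswith(".running_mean") or k.endswith(".running_var")]
    print(bn_keys)  # ['1.running_mean', '1.running_var']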