def step_inner_part()

in scsg.py [0:0]


    def step_inner_part(self, closure, idx):
        # Sanity check: the recalibration pass over the megabatch must have completed.
        if self.recalibration_i != self.megabatch_size:
            raise Exception("bad self.recalibration_i: {}".format(self.recalibration_i))

        if self.recompute_version:
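            # Recompute g_i = grad_i(x_tilde) with an extra closure evaluation at the
            # snapshot instead of reading it from the stored table gktbl. Running
            # statistics (see store_running_mean/restore_running_mean) are saved and
            # restored so the extra forward pass does not perturb them.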
            self.store_running_mean()

            ## Store current xk, replace with x_tilde
            for group in self.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    xk = param_state['xk']
                    xk.zero_().add_(p.data)
                    p.data.zero_().add_(param_state['tilde_x'])

            ## Compute gradient at x_tilde
            closure()

            ## Store x_tilde gradient in gi, and revert to xk
            for group in self.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    xk = param_state['xk']
                    param_state['gi'].zero_().add_(p.grad.data)
                    p.data.zero_().add_(xk)

            self.restore_running_mean()

            # Debug check (disabled): verify the recomputed gi matches the stored gktbl row.
            if False:
                for group in self.param_groups:
                    for p in group['params']:
                        param_state = self.state[p]
                        gi = param_state['gi']
                        gi_tbl = param_state['gktbl'][idx, :]
                        #pdb.set_trace()
                        if torch.norm(gi-gi_tbl) > 1e-6:
                            print("difference: {}".format( torch.norm(gi-gi_tbl)))
                            pdb.set_trace()
        else:
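            # Table-lookup version: g_i = grad_i(x_tilde) was computed during the
            # recalibration pass and stored row-wise in gktbl, indexed by idx.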
            for group in self.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    param_state['gi'] = param_state['gktbl'][idx, :]

        ## Compute gradient at xk
        loss = closure()

        for group in self.param_groups:
            momentum = group['momentum']
            weight_decay = group['weight_decay']
            learning_rate = group['lr']

            for p in group['params']:
                gk = p.grad.data

                param_state = self.state[p]

                gi = param_state['gi']
                gavg = param_state['gavg']
                #gavg_debug = param_state['gavg_debug']

                if momentum != 0:
                    buf = param_state['momentum_buffer']

                #########

                if self.epoch < self.vr_from_epoch:
                    vr_gradient = gk.clone()  # Plain SGD step; variance reduction not yet active
                else:
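                    # SVRG/SCSG estimator: v = grad_i(x_k) - grad_i(x_tilde) + gavg, where
                    # gavg is the megabatch-average gradient at the snapshot x_tilde.
                    # For a minibatch i drawn from the megabatch this is unbiased for the
                    # megabatch gradient at x_k, and its variance shrinks as x_k
                    # approaches x_tilde.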
                    vr_gradient = gk.clone().sub_(gi).add_(gavg)

                    # Track the running mean and variance of the gradients.
                    grad_running_mean = param_state['grad_running_mean']
                    grad_running_var = param_state['grad_running_var']
                    grad_running_cov = param_state['grad_running_cov']
                    cov_update = (gk - grad_running_mean)*(gi - gavg)
                    grad_running_cov.mul_(self.running_interp).add_(cov_update, alpha=1 - self.running_interp)
                    # Update the variance using the deltas from before and after the
                    # mean update (Welford-style).
                    delta1 = gk - grad_running_mean
                    grad_running_mean.mul_(self.running_interp).add_(gk, alpha=1 - self.running_interp)
                    delta2 = gk - grad_running_mean
                    var_update = delta1*delta2
                    grad_running_var.mul_(self.running_interp).add_(var_update, alpha=1 - self.running_interp)
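                    # Since delta2 == running_interp * delta1, the update above is
                    # equivalent to var <- interp * (var + (1 - interp) * delta1**2),
                    # the exponentially-weighted analogue of Welford's online variance
                    # recursion.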


                    #if torch.norm(gavg-gavg_debug) > 1e-7:
                    #    raise Exception("gavg norm diff: {}".format(torch.norm(gavg-gavg_debug)))

                if weight_decay != 0:
                    vr_gradient.add_(p.data, alpha=weight_decay)

                if momentum != 0:
                    dampening = 0.0
                    vr_gradient = buf.mul_(momentum).add_(vr_gradient, alpha=1 - dampening)

                # Take step.
                p.data.add_(vr_gradient, alpha=-learning_rate)


        # track number of minibatches seen
        #logging.info("interval i: {}".format(self.interval_i))
        self.batches_processed += 1

        if self.batches_processed % 20 == 0 and self.batches_processed > 0:
            running_cov_acum = 0.0
            m2_acum = 0.0
            var_acum = 0.0
            for group in self.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    grad_running_cov = param_state['grad_running_cov']
                    grad_running_var = param_state['grad_running_var']
                    m2 = param_state['m2']
                    running_cov_acum += grad_running_cov.sum()
                    var_acum += grad_running_var.sum()
                    # m2 is stored unnormalized, so divide by the megabatch size
                    m2_norm = m2.div(self.megabatch_size)
                    m2_acum += m2_norm.sum()

            if m2_acum > 0:
                cov_var_ratio = running_cov_acum/m2_acum

                vr_variance = var_acum + m2_acum - 2*running_cov_acum
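                # vr_variance uses Var(g - g_i + gavg) = Var(g) + Var(g_i) - 2*Cov(g, g_i)
                # (gavg treated as constant within the inner loop; m2 plays the role of
                # Var(g_i)), so vr_ratio estimates how much the correction shrinks the
                # gradient variance relative to plain SGD.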
                vr_ratio = vr_variance/var_acum
                corr_coef = running_cov_acum/math.sqrt(var_acum*m2_acum)
                logging.info("VR RATIO: {:.3f}. Raw cov/var: {:.3f}, correlation coef: {:.3f}. Var: {:.3f} m2: {:.3f} cov: {:.3f}".format(
                    vr_ratio, cov_var_ratio, corr_coef, var_acum, m2_acum, running_cov_acum))
        return loss
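
As a point of reference (not part of scsg.py; every name below is hypothetical), a minimal self-contained sketch of the variance-reduced estimator applied above, v = grad_i(x) - grad_i(x_tilde) + gavg, on a toy least-squares problem:

    import torch

    torch.manual_seed(0)
    n, d = 256, 10
    A = torch.randn(n, d)
    b = torch.randn(n)

    def grad_i(x, i):
        # Gradient of the per-example loss 0.5 * (a_i . x - b_i)**2
        return (A[i] @ x - b[i]) * A[i]

    x_tilde = torch.randn(d)                 # snapshot (tilde_x above)
    x = x_tilde + 0.01 * torch.randn(d)      # current iterate, close to the snapshot
    gavg = torch.stack([grad_i(x_tilde, i) for i in range(n)]).mean(0)

    plain = torch.stack([grad_i(x, i) for i in range(n)])
    vr = torch.stack([grad_i(x, i) - grad_i(x_tilde, i) + gavg for i in range(n)])

    # Both estimators average to the same full gradient at x, but the variance-
    # reduced one is far less noisy because x is close to x_tilde.
    print("plain variance:", plain.var(0).sum().item())
    print("VR variance:   ", vr.var(0).sum().item())

The gap between the two printed variances comes from the correlation between the per-example gradients at x and x_tilde, which is exactly what the cov/var ratio logged every 20 batches above measures.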