def step()

in recompute_svrg.py
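
Once variance reduction is active (epoch >= vr_from_epoch), each call makes two passes: it swaps the snapshot tilde_x in for the current iterate xk and runs the closure to get this minibatch's gradient at the snapshot (stored in gi), then restores xk and steps along the variance-reduced direction grad(xk) - gi + gavg. A hedged usage sketch follows the listing.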


    def step(self, batch_id, closure):
        """Performs a single optimization step.

        Arguments:
            batch_id (int): Index of the current minibatch.
            closure (callable): A closure that reevaluates the model
                on the current minibatch and returns the loss. It is
                called twice per step once variance reduction is active.
        """

        if self.epoch >= self.vr_from_epoch:
            # Save batch-norm running statistics; the extra forward pass
            # below would otherwise perturb them.
            self.store_running_mean()
            ## Save the current iterate xk and load the snapshot tilde_x
            for group in self.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    xk = param_state['xk']
                    xk.copy_(p.data)                      # save x_k
                    p.data.copy_(param_state['tilde_x'])  # load x_tilde

            # The default is vr_bn_at_recalibration=True, so this branch
            # normally doesn't fire.
            if not self.vr_bn_at_recalibration:
                self.model.eval()  # use running batch-norm statistics
            ## Compute this minibatch's gradient at tilde_x
            closure()

            ## Store the tilde_x gradient in gi, then restore the iterate xk
            for group in self.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    xk = param_state['xk']
                    gi = param_state['gi']
                    gi.copy_(p.grad.data)  # gradient of this batch at x_tilde
                    p.data.copy_(xk)       # restore the iterate x_k

            # Restore the batch-norm running statistics saved above.
            self.restore_running_mean()

        ## Compute this minibatch's gradient at xk
        loss = closure()

        for group in self.param_groups:
            momentum = group['momentum']
            weight_decay = group['weight_decay']
            learning_rate = group['lr']

            for p in group['params']:
                gk = p.grad.data

                param_state = self.state[p]

                gi = param_state['gi']
                gavg = param_state['gavg']

                if momentum != 0:
                    buf = param_state['momentum_buffer']

                # Form the variance-reduced gradient
                # v_k = grad(x_k) - gi + gavg.
                if self.epoch >= self.vr_from_epoch:
                    vr_gradient = gk.clone().sub_(gi).add_(gavg.type_as(gk))
                else:
                    vr_gradient = gk.clone() # Plain SGD steps before variance reduction starts

                if weight_decay != 0:
                    vr_gradient.add_(p.data, alpha=weight_decay)

                if momentum != 0:
                    dampening = 0.0
                    vr_gradient = buf.mul_(momentum).add_(vr_gradient, alpha=1 - dampening)

                # Take the step.
                p.data.add_(vr_gradient, alpha=-learning_rate)


        # track number of minibatches seen
        self.batches_processed += 1

        return loss
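
A minimal sketch of the expected calling convention. The constructor keyword names mirror the group keys read in step() above (lr, momentum, weight_decay); the class name RecomputeSVRG and the model/train_loader objects are illustrative assumptions, not the repository's actual API.

    import torch
    import torch.nn.functional as F

    # Hypothetical construction; argument names mirror the param-group
    # keys read in step() above.
    optimizer = RecomputeSVRG(model.parameters(), lr=0.1, momentum=0.9,
                              weight_decay=1e-4)

    for batch_id, (inputs, targets) in enumerate(train_loader):
        def closure():
            # Reevaluate the model on the current minibatch. step()
            # calls this twice per iteration once variance reduction
            # is active, so gradients must be zeroed here.
            optimizer.zero_grad()
            loss = F.cross_entropy(model(inputs), targets)
            loss.backward()
            return loss

        loss = optimizer.step(batch_id, closure)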