def step()

in torch_svrg.py [0:0]


    def step(self, batch_id, closure):
        """Performs a single optimization step.

        Arguments:
            batch_id (int): Index of the current minibatch, used to look up
                its stored gradient in the gradient table.
            closure (callable): A closure that reevaluates the model
                and returns the loss.
        """
        loss = closure()
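        # Squared-distance accumulators for the diagnostics recorded below.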
        dist_sq_acum = 0.0
        grad_dist_sq_acum = 0.0

        for group in self.param_groups:
            momentum = group['momentum']
            weight_decay = group['weight_decay']
            learning_rate = group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue
                gk = p.grad.data

                param_state = self.state[p]

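                # SVRG state for this parameter: the table of per-minibatch
                # gradients, their average, and the snapshot point at which
                # they were evaluated.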
                gktbl = param_state['gktbl']
                gavg = param_state['gavg'].type_as(p.data)
                tilde_x = param_state['tilde_x']

                if momentum != 0:
                    buf = param_state['momentum_buffer']

                #########

                if self.epoch < self.vr_from_epoch:
                    vr_gradient = gk.clone() # Just do sgd steps
                else:
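                    # Stored gradient of this minibatch, evaluated at the
                    # snapshot point tilde_x.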
                    gi = gktbl[batch_id, :].cuda()

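                    # Variance-reduced gradient: g_i(x) - g_i(tilde_x) + g_avg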
                    vr_gradient = gk.clone().sub_(gi - gavg)

                    # Diagnostics: distance of the iterate from the snapshot
                    # point, and between the stored and current minibatch gradients.
                    iterate_diff = p.data - tilde_x
                    dist_sq_acum += iterate_diff.norm().item()**2
                    grad_diff = gi - gk
                    grad_dist_sq_acum += grad_diff.norm().item()**2

                if weight_decay != 0:
                    vr_gradient.add_(p.data, alpha=weight_decay)

                if momentum != 0:
                    dampening = 0.0
                    vr_gradient = buf.mul_(momentum).add_(vr_gradient, alpha=1 - dampening)

                # Take step.
                p.data.add_(vr_gradient, alpha=-learning_rate)

                # Update running iterate mean:
                param_state['running_x'].mul_(self.running_interp).add_(p.data, alpha=1 - self.running_interp)

        # track number of minibatches seen
        self.batches_processed += 1

        dist = math.sqrt(dist_sq_acum)
        grad_dist = math.sqrt(grad_dist_sq_acum)

        self.inrun_iterate_distances.append(dist)
        self.inrun_grad_distances.append(grad_dist)

        return loss
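
For context, here is a minimal sketch of how step() might be driven from a training loop. The SVRG class name, its constructor arguments, and the model/criterion/train_loader/num_epochs objects are placeholders for illustration, not names taken from torch_svrg.py; a real run would also need a snapshot pass each epoch to populate gktbl, gavg and tilde_x (the method that does this is not shown here).

    # Hypothetical setup: SVRG, model, criterion, train_loader and num_epochs
    # are assumptions, not part of torch_svrg.py.
    optimizer = SVRG(model.parameters(), lr=0.1, momentum=0.9,
                     weight_decay=1e-4, vr_from_epoch=1)

    for epoch in range(num_epochs):
        # A real loop would refresh the snapshot state (gktbl, gavg, tilde_x)
        # here, using whatever recalibration method the optimizer exposes.
        for batch_id, (data, target) in enumerate(train_loader):
            def closure():
                # step() reads p.grad, so the closure must run backward().
                optimizer.zero_grad()
                loss = criterion(model(data), target)
                loss.backward()
                return loss

            loss = optimizer.step(batch_id, closure)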