scsg.py [178:199]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            self.store_running_mean()

            ## Store current xk, replace with x_tilde
            for group in self.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    xk = param_state['xk']
                    xk.zero_().add_(p.data)
                    p.data.zero_().add_(param_state['tilde_x'])

            ## Compute gradient at x_tilde
            closure()

            ## Store x_tilde gradient in gi, and revert to xk
            for group in self.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    xk = param_state['xk']
                    param_state['gi'].zero_().add_(p.grad.data)
                    p.data.zero_().add_(xk)

            self.restore_running_mean()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



scsg.py [342:364]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        self.store_running_mean()

        ## Store current xk, replace with x_tilde
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                xk = param_state['xk']
                xk.zero_().add_(p.data)
                p.data.zero_().add_(param_state['tilde_x'])

        ## Compute gradient at x_tilde
        closure()

        ## Store x_tilde gradient in gi, and revert to xk
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                xk = param_state['xk']
                param_state['gi'].zero_().add_(p.grad.data)
                p.data.zero_().add_(xk)

        # Restore running_mean/var
        self.restore_running_mean()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



