def step()

in timm/optim/mars.py


    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so re-enable autograd for the
            # closure's forward/backward pass to repopulate p.grad.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('MARS does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]
                # State initialization
                if len(state) <= 1:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Last Gradient
                    state['last_grad'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)

                state['step'] += 1
                step = state['step']
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                last_grad = state['last_grad']
                lr = group['lr']
                wd = group['weight_decay']
                beta1, beta2 = group['betas']
                is_grad_2d = grad.ndim >= 2  # ndim >= 2 routes matrix-shaped params (e.g. linear/conv weights) to the 2d path

                # FIXME add multi-tensor (if usage warrants), make more standard
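                # Roughly, per the MARS paper (the exact kernel may differ), the
                # call below forms a variance-reduced gradient
                #     c_t = g_t + gamma * beta1 / (1 - beta1) * (g_t - g_{t-1}),
                # clips it to unit L2 norm, and feeds it into a bias-corrected
                # AdamW- or Lion-style moment update selected by mars_type.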
                _mars_single_tensor_step(
                    p,
                    grad,
                    exp_avg,
                    exp_avg_sq,
                    lr,
                    wd,
                    beta1,
                    beta2,
                    last_grad,
                    group['eps'],
                    step,
                    group['gamma'],
                    mars_type=group['mars_type'],
                    is_grad_2d=is_grad_2d,
                    optimize_1d=group['optimize_1d'],
                    lr_1d_factor=group['lr_1d_factor'],
                    betas_1d=group['betas_1d'],
                    caution=group['caution'],
                )
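                # 'caution' presumably applies Cautious Optimizers-style masking
                # (zeroing update components whose sign disagrees with grad);
                # an assumption from the flag name, not verified against the kernel.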

                state['last_grad'] = grad  # stash g_t for next step's (g_t - g_{t-1}) term

        return loss
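
Below is a minimal sketch of the AdamW-flavored update that `_mars_single_tensor_step` performs, following the MARS paper rather than this file's exact kernel. The helper name `mars_adamw_step_sketch`, its reduced argument list, and the `Mars` import in the usage lines are illustrative assumptions, not this file's API.

    import torch

    def mars_adamw_step_sketch(p, grad, exp_avg, exp_avg_sq, last_grad,
                               lr, wd, beta1, beta2, eps, step, gamma):
        # Variance-reduced gradient: g_t plus a scaled difference against the
        # previous gradient (the core MARS correction).
        c_t = grad + gamma * (beta1 / (1.0 - beta1)) * (grad - last_grad)
        # Clip the corrected gradient to unit L2 norm.
        c_norm = torch.norm(c_t)
        if c_norm > 1.0:
            c_t = c_t / c_norm
        # Adam moment updates driven by c_t instead of the raw gradient.
        exp_avg.mul_(beta1).add_(c_t, alpha=1.0 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1.0 - beta2)
        bias_c1 = 1.0 - beta1 ** step
        bias_c2 = 1.0 - beta2 ** step
        denom = (exp_avg_sq / bias_c2).sqrt().add_(eps)
        # Decoupled (AdamW-style) weight decay, then the Adam update.
        p.mul_(1.0 - lr * wd)
        p.addcdiv_(exp_avg / bias_c1, denom, value=-lr)

Usage sketch, assuming the class in this file is exported as `Mars`:

    from timm.optim.mars import Mars  # assumed import path

    model = torch.nn.Linear(16, 4)
    opt = Mars(model.parameters(), lr=3e-3)
    loss = model(torch.randn(8, 16)).square().mean()
    loss.backward()   # populates p.grad, which step() reads
    opt.step()        # runs the loop above once over all parameters
    opt.zero_grad()   # default set_to_none=True leaves state['last_grad'] usable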