def step()

in timm/optim/madgrad.py [0:0]


    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Every parameter in ``self.param_groups`` that has a gradient is updated in place,
        together with its per-parameter state (``step``, ``grad_sum_sq``, ``s`` and, when
        ``momentum != 0``, ``x0``). The update itself runs with autograd recording disabled,
        and the tensor stored in ``p.grad`` is never modified.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: If a gradient is sparse while the group uses ``momentum != 0``, or
                uses a non-zero, non-decoupled ``weight_decay``.
        """
        loss = None
        if closure is not None:
            # Run the closure with gradient recording switched on, regardless of the caller's mode.
            with torch.enable_grad():
                loss = closure()

        # Everything below mutates parameters and optimizer state in place. With autograd
        # recording enabled, in-place ops on leaf tensors that require grad raise, so the
        # whole update runs under no_grad (a no-op if recording is already disabled).
        with torch.no_grad():
            for group in self.param_groups:
                eps = group['eps']
                # Adding eps keeps lamb strictly positive even for lr == 0 (given eps > 0).
                lr = group['lr'] + eps
                weight_decay = group['weight_decay']
                momentum = group['momentum']
                # Weight given to the new iterate z when averaging it into p.
                ck = 1 - momentum

                for p in group["params"]:
                    if p.grad is None:
                        continue
                    grad = p.grad
                    if momentum != 0.0 and grad.is_sparse:
                        raise RuntimeError("momentum != 0 is not compatible with sparse gradients")

                    state = self.state[p]
                    if len(state) == 0:
                        # Lazy state initialisation on the first step for this parameter.
                        state['step'] = 0
                        state['grad_sum_sq'] = torch.zeros_like(p)
                        state['s'] = torch.zeros_like(p)
                        if momentum != 0:
                            state['x0'] = torch.clone(p).detach()

                    state['step'] += 1
                    grad_sum_sq = state['grad_sum_sq']
                    s = state['s']
                    # Per-step weight applied to both accumulators; grows with sqrt(step).
                    lamb = lr * math.sqrt(state['step'])

                    # Apply weight decay
                    if weight_decay != 0:
                        if group['decoupled_decay']:
                            # Decoupled decay shrinks the parameter directly, using the raw lr.
                            p.mul_(1.0 - group['lr'] * weight_decay)
                        elif grad.is_sparse:
                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
                        else:
                            # Out-of-place on purpose: p.grad belongs to the caller and must not
                            # be altered by taking a step.
                            grad = grad.add(p, alpha=weight_decay)

                    if grad.is_sparse:
                        grad = grad.coalesce()
                        grad_val = grad._values()

                        # Sparse copies of p and the accumulators, restricted to grad's indices.
                        p_masked = p.sparse_mask(grad)
                        grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                        s_masked = s.sparse_mask(grad)

                        # Compute x_0 from other known quantities
                        rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                        x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)

                        # Dense + sparse op
                        # The dense state and its masked copy are advanced side by side.
                        grad_sq = grad * grad
                        grad_sum_sq.add_(grad_sq, alpha=lamb)
                        grad_sum_sq_masked.add_(grad_sq, alpha=lamb)

                        # In-place pow_ is fine here: the masked copy is not read again afterwards.
                        rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)

                        s.add_(grad, alpha=lamb)
                        s_masked._values().add_(grad_val, alpha=lamb)

                        # update masked copy of p
                        p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
                        # Copy updated masked p to dense p using an add operation:
                        # p_masked becomes (old - new), so subtracting it from p leaves the new values.
                        p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
                        p.add_(p_masked, alpha=-1)
                    else:
                        if momentum == 0:
                            # Compute x_0 from other known quantities
                            rms = grad_sum_sq.pow(1 / 3).add_(eps)
                            x0 = p.addcdiv(s, rms, value=1)
                        else:
                            x0 = state['x0']

                        # Accumulate second moments
                        grad_sum_sq.addcmul_(grad, grad, value=lamb)
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)

                        # Update s
                        s.add_(grad, alpha=lamb)

                        # Step
                        if momentum == 0:
                            p.copy_(x0.addcdiv(s, rms, value=-1))
                        else:
                            z = x0.addcdiv(s, rms, value=-1)

                            # p is a moving average of z
                            p.mul_(1 - ck).add_(z, alpha=ck)

        return loss