def step()

in timm/optim/kron.py


    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        total_momentum_size = 0
        total_momentum_mb = 0
        total_precond_size = 0
        total_precond_mb = 0

        for group in self.param_groups:
            mu_dtype = group.get("mu_dtype")
            precond_dtype = group.get("precond_dtype", torch.float32)
            momentum_into_precond_update = group.get("momentum_into_precond_update", True)
            update_prob = group.get("preconditioner_update_probability", None)

            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]

                flattened = False
                if group['flatten']:
                    grad = safe_flatten(grad, group["flatten_start_dim"], group["flatten_end_dim"])
                    flattened = True

                if len(state) == 0:
                    state["step"] = 0
                    state["update_counter"] = 0
                    state["momentum_buffer"] = torch.zeros_like(grad, dtype=mu_dtype or grad.dtype)
                    # init Q and einsum expressions on first step
                    state["Q"], exprs = _init_Q_exprs(
                        grad,
                        group["precond_init_scale"],
                        group["max_size_triangular"],
                        group["min_ndim_triangular"],
                        group["memory_save_mode"],
                        dtype=precond_dtype,
                    )
                    self._param_exprs[p] = exprs

                    # Accumulate sizes for log
                    momentum_size = state["momentum_buffer"].numel()
                    momentum_mb = momentum_size * state["momentum_buffer"].element_size() / 2**20
                    total_momentum_size += momentum_size
                    total_momentum_mb += momentum_mb

                    precond_size = sum(q.numel() for q in state["Q"])
                    precond_mb = sum(q.numel() * q.element_size() for q in state["Q"]) / 2**20
                    total_precond_size += precond_size
                    total_precond_mb += precond_mb
                elif p not in self._param_exprs:
                    # init only the einsum expressions (called after a state_dict load; Q is restored from the state_dict)
                    exprs = _init_Q_exprs(
                        grad,
                        group["precond_init_scale"],
                        group["max_size_triangular"],
                        group["min_ndim_triangular"],
                        group["memory_save_mode"],
                        dtype=precond_dtype,
                        init_q=False,
                    )
                    self._param_exprs[p] = exprs
                else:
                    # retrieve cached expressions
                    exprs = self._param_exprs[p]

                # Resolve the preconditioner update probability and turn it into a deterministic update schedule
                if update_prob is None:
                    update_prob = precond_update_prob_schedule
                if callable(update_prob):
                    update_prob = update_prob(state["step"])
                state["update_counter"] += 1
                do_update = state["update_counter"] >= 1 / update_prob
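                # e.g. update_prob == 0.1 triggers one full preconditioner update every 10 steps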
                if do_update:
                    state["update_counter"] = 0

                state["step"] += 1

                # Update momentum buffer
                beta = group["momentum"]
                bias_correction = 1 - beta ** state["step"]
                momentum_buffer = state["momentum_buffer"]
                momentum_buffer.mul_(beta).add_(grad, alpha=1 - beta)

                # Restore momentum dtype
                if mu_dtype is not None:
                    momentum_buffer.copy_(momentum_buffer.to(dtype=mu_dtype))
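                # Adam-style bias correction of the EMA, cast to the preconditioner dtype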
                debiased_momentum = (momentum_buffer / bias_correction).to(dtype=precond_dtype)

                # Balance preconditioners roughly every 100 updates
                balance = self.rng.random() < 0.01 and do_update
                if grad.dim() > 1 and balance:
                    self._balance_Q(state["Q"])

                # Update preconditioner
                if do_update:
                    exprA, exprGs, _ = exprs
                    Q = state["Q"]
                    if self.deterministic:
                        torch_rng = torch.Generator(device=debiased_momentum.device)
                        torch_rng.manual_seed(self.rng.randint(0, 2 ** 31))
                    else:
                        torch_rng = None
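                    # Auxiliary random vector drawn for the preconditioner update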
                    V = torch.randn(
                        debiased_momentum.shape,
                        generator=torch_rng,
                        dtype=precond_dtype,
                        device=debiased_momentum.device,
                    )
                    G = debiased_momentum if momentum_into_precond_update else grad

                    A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)

                    terms = self._q_terms(exprGs, A, conjB)

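                    # Apply a normalized update, scaled by precond_lr, to each factor of Q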
                    for q, (term1, term2) in zip(Q, terms):
                        tmp = term1 - term2
                        tmp *= group["precond_lr"]
                        if q.dim() < 2:
                            tmp *= q
                            tmp /= (term1 + term2).norm(float("inf")) + self._tiny
                        else:
                            tmp = torch.triu(tmp)
                            tmp /= _norm_lower_bound(term1 + term2) + self._tiny
                            tmp @= q
                        q.sub_(tmp)

                # Precondition gradients
                pre_grad = self._precond_grad(
                    state["Q"],
                    exprs,
                    debiased_momentum,
                ).to(dtype=p.dtype)

                # RMS of pre_grad should be 1.0, so let's cap at 1.1
                pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
                if flattened:
                    pre_grad = pre_grad.view(p.shape)

                # Apply weight decay
                weight_decay = group["weight_decay"]
                if weight_decay != 0:
                    if group["stochastic_weight_decay"]:
                        weight_decay = 2 * self.rng.random() * weight_decay

                    if group["decoupled_decay"]:
                        if group['corrected_weight_decay']:
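                            # scale decay by lr**2 / initial lr rather than lr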
                            wd_scale = group["lr"] ** 2 / self.defaults['lr']
                        else:
                            wd_scale = group["lr"]
                        p.mul_(1. - wd_scale * weight_decay)
                    else:
                        pre_grad.add_(p, alpha=weight_decay)

                # Update parameters
                p.add_(pre_grad, alpha=-group["lr"])

        if total_momentum_size > 0:
            _logger.info(f"PSGD Momentum buffer size: {total_momentum_size} elements, {total_momentum_mb:.2f} MB")
            _logger.info(f"PSGD Preconditioners size: {total_precond_size} elements, {total_precond_mb:.2f} MB")

        return loss
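
A minimal usage sketch (not part of the source above): it assumes the `Kron` optimizer class defined in this same `timm/optim/kron.py` module and passes only `params` and `lr`, leaving every other option at its default.

    import torch
    from timm.optim.kron import Kron

    model = torch.nn.Linear(16, 4)
    opt = Kron(model.parameters(), lr=1e-3)

    def closure():
        opt.zero_grad()
        out = model(torch.randn(8, 16))
        loss = out.square().mean()
        loss.backward()
        return loss

    # step() re-enables grad around the closure, updates the momentum buffer,
    # refreshes the preconditioners on its schedule, and applies the update.
    loss = opt.step(closure)

With `preconditioner_update_probability` left at `None`, `step()` falls back to `precond_update_prob_schedule`, which is called with `state["step"]`, so the update frequency can change over the course of training.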