def step()

in timm/optim/adafactor_bv.py
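
Performs a single optimization step. For each parameter group it collects gradients, lazily initializes per-parameter state on first use (a scalar step counter, a factored or full second-moment estimate, and an optional momentum buffer), then delegates the actual update to the single-tensor or multi-tensor (foreach) implementation.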


    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
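            # Gather per-parameter grads and state tensors so the update runs as one batched call per group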
            params_with_grad = []
            grads = []
            exp_avg_sq_rs = []
            exp_avg_sq_cs = []
            exp_avg_sqs = []
            state_steps = []
            exp_avgs = []  # For momentum

            for p in group['params']:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError("Sparse gradients not supported")

                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]

                if len(state) == 0:
                    # NOTE step on CPU, probably needs some more thought to make capturable
                    state['step'] = torch.tensor(0.0, dtype=_get_scalar_dtype())

                    shape = p.grad.shape
                    factored_dims = _factored_dims(
                        shape,
                        factored=True,
                        min_dim_size_to_factor=self.defaults['min_dim_size_to_factor']
                    )

                    if factored_dims is not None:
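                        # Factored second moment: keep a row and a column accumulator
                        # (each with one dim collapsed to size 1) instead of a full-sized buffer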
                        dc, dr = factored_dims
                        row_shape = list(p.grad.shape)
                        row_shape[dr] = 1
                        col_shape = list(p.grad.shape)
                        col_shape[dc] = 1
                        state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape)
                        state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape)
                    else:
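                        # Too small to factor (or fewer than 2 dims): use a full per-element second-moment buffer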
                        state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)

                    if self.defaults['momentum'] is not None:
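                        # Optional first-moment (momentum) buffer, stored in the configured momentum_dtype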
                        state['exp_avg'] = torch.zeros_like(p.grad, dtype=self.defaults['momentum_dtype'])

                state_steps.append(state['step'])
                exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))
                exp_avg_sq_cs.append(state.get('exp_avg_sq_c', None))
                exp_avg_sqs.append(state.get('exp_avg_sq', None))
                exp_avgs.append(state.get('exp_avg', None))

            if group['foreach']:
                func = _multi_tensor_adafactor  # multi-tensor (foreach) code path
            else:
                func = _single_tensor_adafactor  # per-tensor code path

            func(
                params=params_with_grad,
                grads=grads,
                exp_avg_sq_rs=exp_avg_sq_rs,
                exp_avg_sq_cs=exp_avg_sq_cs,
                exp_avg_sqs=exp_avg_sqs,
                exp_avgs=exp_avgs,
                state_steps=state_steps,
                beta2_decay=group['decay_rate'],
                beta2_cap=group['beta2_cap'],
                min_dim_size_to_factor=group['min_dim_size_to_factor'],
                eps=group['eps'],
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                momentum=group['momentum'],
                momentum_dtype=group['momentum_dtype'],
                clipping_threshold=group['clipping_threshold'],
                unscaled_wd=group['unscaled_wd'],
                caution=group['caution'],
                max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,  # base lr, used only for corrected weight-decay scaling
            )

        return loss
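
As a rough usage sketch (not taken from the file above): the snippet below assumes the optimizer class defined in this module is AdafactorBigVision and that its constructor accepts params and lr as shown; both are assumptions. It runs a single step on a small model and prints the lazily initialized state, so the factored row/column second-moment buffers created above can be inspected.

    import torch
    import torch.nn as nn
    from timm.optim.adafactor_bv import AdafactorBigVision  # assumed export name

    model = nn.Linear(512, 256)
    opt = AdafactorBigVision(model.parameters(), lr=1e-3)  # assumed constructor arguments

    x, y = torch.randn(8, 512), torch.randn(8, 256)
    nn.functional.mse_loss(model(x), y).backward()
    opt.step()  # first call lazily creates the state entries shown in the initialization branch above

    for p in model.parameters():
        print({k: (tuple(v.shape) if torch.is_tensor(v) else v) for k, v in opt.state[p].items()})

The 2-D weight should show the reduced-shape 'exp_avg_sq_r'/'exp_avg_sq_c' accumulators, while the 1-D bias falls back to a full 'exp_avg_sq'.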