def step()

in bitsandbytes/optim/lars.py [0:0]


    def step(self, closure=None):
        """Performs a single optimization step.

        For every parameter that has a gradient, the update is built as:

        1. ``d_p = grad + weight_decay * p`` (only when ``weight_decay != 0``).
        2. When ``momentum != 0`` a per-parameter ``momentum_buffer`` kept in
           ``self.state`` is initialised to ``d_p`` on first use and afterwards
           updated as ``buf = momentum * buf + (1 - dampening) * d_p``. The
           update is ``buf``, or ``d_p + momentum * buf`` when ``nesterov``
           is set. With ``momentum == 0`` the update is ``d_p`` itself.
        3. When ``max_unorm > 0`` and the update norm exceeds
           ``max_unorm * ||p||``, the update is scaled down so that its norm
           equals ``max_unorm * ||p||``.
        4. ``p -= lr * update_scale * update``.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            was given.

        Raises:
            AssertionError: If ``max_unorm > 0`` and a parameter is not
                ``torch.float32``.
        """
        loss = None
        if closure is not None:
            # The closure typically runs forward/backward, so it needs autograd.
            with torch.enable_grad():
                loss = closure()

        # Parameter updates must not be recorded by autograd; in-place ops on
        # leaf tensors that require grad are rejected otherwise. This is a
        # no-op if grad mode is already disabled by the caller or a decorator.
        with torch.no_grad():
            for group in self.param_groups:
                weight_decay = group['weight_decay']
                momentum = group['momentum']
                dampening = group['dampening']
                nesterov = group['nesterov']
                max_unorm = group['max_unorm']
                lr = group['lr']

                for p in group['params']:
                    if p.grad is None:
                        continue

                    state = self.state[p]
                    d_p = p.grad
                    if weight_decay != 0:
                        # Out-of-place add so p.grad itself is left untouched.
                        d_p = d_p.add(p, alpha=weight_decay)

                    # Without momentum the update is the (decayed) gradient.
                    update = d_p
                    if momentum != 0:
                        buf = state.get('momentum_buffer', None)

                        if buf is None:
                            # First step: seed the buffer with the gradient.
                            buf = torch.clone(d_p).detach()
                            state['momentum_buffer'] = buf
                        else:
                            buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

                        if nesterov:
                            update = d_p + buf * momentum
                        else:
                            update = buf

                    update_scale = 1.0
                    if max_unorm > 0.0:
                        assert p.dtype == torch.float32, \
                            'max_unorm clipping requires float32 parameters'
                        pnorm = torch.norm(p.detach())
                        unorm = torch.norm(update)
                        # Clip the update norm relative to the parameter norm.
                        # unorm is strictly positive inside this branch, so
                        # the division is safe.
                        if unorm > max_unorm * pnorm:
                            update_scale = max_unorm * pnorm / unorm

                    p.add_(update, alpha=-lr * update_scale)

        return loss