in bitsandbytes/optim/optimizer.py [0:0]
def step(self, closure=None):
    """Perform a single optimization step.

    On the first call, ``check_overrides()`` and ``to_gpu()`` are invoked once
    and ``self.initialized`` is set. Then, for every parameter in every param
    group that has a gradient, per-parameter state is initialized via
    ``init_state`` if it is still empty, and ``update_step`` is applied.

    Args:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss. It is invoked with gradient tracking enabled.

    Returns:
        The value returned by ``closure``, or ``None`` if no closure was given.
    """
    loss = None
    if closure is not None:
        # The closure presumably runs a forward/backward pass, so gradient
        # tracking is re-enabled even if step() itself runs under no_grad.
        with torch.enable_grad():
            loss = closure()

    if not self.initialized:
        self.check_overrides()
        self.to_gpu()  # needed for fairseq pure fp16 training
        self.initialized = True

    for gindex, group in enumerate(self.param_groups):
        for pindex, p in enumerate(group['params']):
            # Parameters that received no gradient are left untouched.
            if p.grad is None:
                continue
            state = self.state[p]
            # Empty state means this parameter has not been set up yet;
            # initialize it lazily on its first update.
            if not state:
                self.init_state(group, p, gindex, pindex)
            self.update_step(group, p, gindex, pindex)

    return loss