in timm/optim/madgrad.py [0:0]
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model and returns the loss.

    Returns:
        The value returned by ``closure`` if one was given, otherwise ``None``.

    Raises:
        RuntimeError: if a parameter has a sparse gradient while ``momentum != 0``,
            or while non-decoupled ``weight_decay != 0`` is configured.

    Note:
        ``p.grad`` is never modified by this method; coupled weight decay is
        applied to a temporary copy of the gradient.
    """
    loss = None
    if closure is not None:
        # Autograd is re-enabled for the closure (the step itself may be
        # running under no_grad).
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        eps = group['eps']
        # eps is added to the base lr used for the per-step weight `lamb`;
        # decoupled decay below deliberately uses the raw group['lr'].
        lr = group['lr'] + eps
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        ck = 1 - momentum

        for p in group["params"]:
            if p.grad is None:
                continue
            grad = p.grad
            if momentum != 0.0 and grad.is_sparse:
                raise RuntimeError("momentum != 0 is not compatible with sparse gradients")

            state = self.state[p]
            if not state:
                # Lazy per-parameter state initialisation on the first step.
                state['step'] = 0
                state['grad_sum_sq'] = torch.zeros_like(p)
                state['s'] = torch.zeros_like(p)
                if momentum != 0:
                    # With momentum, the anchor point x0 is stored explicitly;
                    # without it, x0 is recovered from p, s and rms each step.
                    state['x0'] = torch.clone(p).detach()

            state['step'] += 1
            grad_sum_sq = state['grad_sum_sq']
            s = state['s']
            # Per-step weight, growing with sqrt(step).
            lamb = lr * math.sqrt(state['step'])

            # Apply weight decay
            if weight_decay != 0:
                if group['decoupled_decay']:
                    # Decoupled decay shrinks the parameter directly.
                    p.mul_(1.0 - group['lr'] * weight_decay)
                else:
                    if grad.is_sparse:
                        raise RuntimeError("weight_decay option is not compatible with sparse gradients")
                    # Out-of-place so the caller's p.grad is not altered.
                    grad = grad.add(p, alpha=weight_decay)

            if grad.is_sparse:
                # Only entries present in the sparse gradient are updated. The
                # *_masked tensors are sparse copies of the dense state restricted
                # to those entries.
                grad = grad.coalesce()
                grad_val = grad._values()

                p_masked = p.sparse_mask(grad)
                grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                s_masked = s.sparse_mask(grad)

                # Compute x_0 from other known quantities (inverts p = x0 - s / rms)
                rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)

                # Dense + sparse op: accumulate into both the dense state and its masked copy.
                grad_sq = grad * grad
                grad_sum_sq.add_(grad_sq, alpha=lamb)
                grad_sum_sq_masked.add_(grad_sq, alpha=lamb)

                rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)

                s.add_(grad, alpha=lamb)
                s_masked._values().add_(grad_val, alpha=lamb)

                # update masked copy of p
                p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
                # p_masked now holds (p_old - p_new) on the touched entries, so
                # subtracting it from dense p writes p_new into exactly those entries.
                p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
                p.add_(p_masked, alpha=-1)
            else:
                if momentum == 0:
                    # Compute x_0 from other known quantities: the previous step
                    # set p = x0 - s / rms, so x0 = p + s / rms.
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)
                    x0 = p.addcdiv(s, rms, value=1)
                else:
                    x0 = state['x0']

                # Accumulate second moments (weighted by lamb); rms is their cube root.
                grad_sum_sq.addcmul_(grad, grad, value=lamb)
                rms = grad_sum_sq.pow(1 / 3).add_(eps)

                # Update s (weighted sum of gradients)
                s.add_(grad, alpha=lamb)

                # Step
                if momentum == 0:
                    p.copy_(x0.addcdiv(s, rms, value=-1))
                else:
                    z = x0.addcdiv(s, rms, value=-1)
                    # p is a moving average of z: p = (1 - ck) * p + ck * z
                    p.mul_(1 - ck).add_(z, alpha=ck)

    return loss