in timm/optim/adan.py [0:0]
def step(self, closure=None):
    """Performs a single optimization step."""
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()
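
    # Gradients are expected to come either from an earlier backward() call or
    # from the closure above, which is re-evaluated with grad enabled.
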
    try:
        has_scalar_maximum = 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
    except:
        has_scalar_maximum = False
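    # The caution option's multi-tensor path relies on the Scalar overload of
    # `_foreach_maximum_`; older PyTorch builds lack it, which forces the
    # per-tensor fallback chosen below.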

    for group in self.param_groups:
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avg_sqs = []
        exp_avg_diffs = []
        neg_pre_grads = []

        beta1, beta2, beta3 = group['betas']
        # assume the same step across the group for now to simplify things;
        # a per-parameter step could be supported by making it a tensor, or by passing a list into the kernel
        if 'step' in group:
            group['step'] += 1
        else:
            group['step'] = 1
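
        # corrections for the zero-initialized EMAs; the third is passed to the
        # kernel as its square root (`bias_correction3_sqrt`)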
        bias_correction1 = 1.0 - beta1 ** group['step']
        bias_correction2 = 1.0 - beta2 ** group['step']
        bias_correction3 = 1.0 - beta3 ** group['step']

        for p in group['params']:
            if p.grad is None:
                continue
            params_with_grad.append(p)
            grads.append(p.grad)

            state = self.state[p]
            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['exp_avg_diff'] = torch.zeros_like(p)

            # holds -g_{t-1}; (re)seeded with the current gradient on the first
            # step so the initial gradient-difference term is zero
            if 'neg_pre_grad' not in state or group['step'] == 1:
                state['neg_pre_grad'] = -p.grad.clone()

            exp_avgs.append(state['exp_avg'])
            exp_avg_sqs.append(state['exp_avg_sq'])
            exp_avg_diffs.append(state['exp_avg_diff'])
            neg_pre_grads.append(state['neg_pre_grad'])

        if not params_with_grad:
            continue

        if group['foreach'] is None:
            use_foreach = not group['caution'] or has_scalar_maximum
        else:
            use_foreach = group['foreach']

        if use_foreach:
            func = _multi_tensor_adan
        else:
            func = _single_tensor_adan

        func(
            params_with_grad,
            grads,
            exp_avgs=exp_avgs,
            exp_avg_sqs=exp_avg_sqs,
            exp_avg_diffs=exp_avg_diffs,
            neg_pre_grads=neg_pre_grads,
            beta1=beta1,
            beta2=beta2,
            beta3=beta3,
            bias_correction1=bias_correction1,
            bias_correction2=bias_correction2,
            bias_correction3_sqrt=math.sqrt(bias_correction3),
            lr=group['lr'],
            weight_decay=group['weight_decay'],
            eps=group['eps'],
            no_prox=group['no_prox'],
            caution=group['caution'],
        )

    return loss
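

# ---------------------------------------------------------------------------
# Reference sketch (not timm code): roughly what `_single_tensor_adan` /
# `_multi_tensor_adan` are expected to do per parameter. It follows the Adan
# recursion using the `beta * state + (1 - beta) * new` EMA convention implied
# by the bias corrections computed in `step()` above; the function name and
# in-place details are illustrative, and the `caution` option (sign-agreement
# masking of the momentum) is omitted for brevity.
# ---------------------------------------------------------------------------
def _adan_update_sketch(
        param, grad, exp_avg, exp_avg_sq, exp_avg_diff, neg_pre_grad,
        *, beta1, beta2, beta3, bias_correction1, bias_correction2,
        bias_correction3_sqrt, lr, weight_decay, eps, no_prox):
    diff = grad + neg_pre_grad                       # g_t - g_{t-1}
    update = grad + beta2 * diff                     # gradient estimate feeding n_t

    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)                   # m_t
    exp_avg_diff.mul_(beta2).add_(diff, alpha=1 - beta2)              # v_t
    exp_avg_sq.mul_(beta3).addcmul_(update, update, value=1 - beta3)  # n_t

    denom = (exp_avg_sq.sqrt() / bias_correction3_sqrt).add_(eps)
    step_size = lr / bias_correction1
    step_size_diff = lr * beta2 / bias_correction2

    if no_prox:
        # decoupled (AdamW-style) weight decay applied before the update
        param.mul_(1 - lr * weight_decay)
        param.addcdiv_(exp_avg, denom, value=-step_size)
        param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
    else:
        # proximal form: decay folded in after the update
        param.addcdiv_(exp_avg, denom, value=-step_size)
        param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
        param.div_(1 + lr * weight_decay)

    neg_pre_grad.zero_().add_(grad, alpha=-1.0)      # store -g_t for the next step


# Usage sketch (assumes the `Adan` class defined in this module accepts the
# standard torch.optim-style constructor arguments shown; values illustrative):
#
#   optimizer = Adan(model.parameters(), lr=1e-3, weight_decay=0.02)
#   loss = criterion(model(x), y)
#   loss.backward()
#   optimizer.step()
#   optimizer.zero_grad()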