in timm/optim/mars.py [0:0]
@torch.no_grad()
def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError('MARS does not support sparse gradients')
            state = self.state[p]
            # State initialization
            if len(state) <= 1:
                state['step'] = 0
                # Exponential moving average of gradient values
                state['exp_avg'] = torch.zeros_like(p)
                # Gradient from the previous step
                state['last_grad'] = torch.zeros_like(p)
                # Exponential moving average of squared gradient values
                state['exp_avg_sq'] = torch.zeros_like(p)

            state['step'] += 1
            step = state['step']
            exp_avg = state['exp_avg']
            exp_avg_sq = state['exp_avg_sq']
            last_grad = state['last_grad']
            lr = group['lr']
            wd = group['weight_decay']
            beta1, beta2 = group['betas']
            is_grad_2d = grad.ndim >= 2
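            # NOTE roughly (per the MARS paper), the helper below forms a variance-reduced
            # gradient c_t = grad + gamma * (beta1 / (1 - beta1)) * (grad - last_grad),
            # clips it to unit norm, and then applies an AdamW- or Lion-style update
            # depending on mars_type. Parameters with ndim < 2 may instead get a plain
            # update scaled by lr_1d_factor unless optimize_1d is set; see
            # _mars_single_tensor_step for the exact behaviour.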
            # FIXME add multi-tensor (if usage warrants), make more standard
            _mars_single_tensor_step(
                p,
                grad,
                exp_avg,
                exp_avg_sq,
                lr,
                wd,
                beta1,
                beta2,
                last_grad,
                group['eps'],
                step,
                group['gamma'],
                mars_type=group['mars_type'],
                is_grad_2d=is_grad_2d,
                optimize_1d=group['optimize_1d'],
                lr_1d_factor=group['lr_1d_factor'],
                betas_1d=group['betas_1d'],
                caution=group['caution'],
            )
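            # Store the raw gradient (not the corrected one) so it can serve as
            # last_grad, i.e. g_{t-1}, in the next step's MARS correction term.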
            state['last_grad'] = grad

    return loss
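
# Usage sketch (illustrative only): assumes the Mars class defined in this module;
# constructor argument names and defaults may differ between timm versions.
#
#   import torch
#   from timm.optim.mars import Mars
#
#   model = torch.nn.Linear(16, 4)
#   optimizer = Mars(model.parameters(), lr=1e-3, weight_decay=0.1)
#
#   loss = model(torch.randn(8, 16)).sum()
#   loss.backward()
#   optimizer.step()       # runs the update above for every parameter with a grad
#   optimizer.zero_grad()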