in higher/optim.py (DifferentiableAdamax._update)
def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:
    zipped = zip(self.param_groups, grouped_grads)

    for group_idx, (group, grads) in enumerate(zipped):
        for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
            if g is None:
                continue
            if g.is_sparse:
                raise RuntimeError(
                    'Adamax does not support sparse gradients'
                )

            state = self.state[group_idx][p_idx]

            # State initialization
            if len(state) == 0:
                state['step'] = 0
                state['exp_avg'] = _torch.zeros_like(p.data)
                state['exp_inf'] = _torch.zeros_like(p.data)

            exp_avg, exp_inf = state['exp_avg'], state['exp_inf']
            beta1, beta2 = group['betas']
            eps = group['eps']

            state['step'] += 1

            # L2 weight decay: add weight_decay * p to the gradient.
            if group['weight_decay'] != 0:
                g = _add(g, group['weight_decay'], p)

            # Update biased first moment estimate:
            # exp_avg = beta1 * exp_avg + (1 - beta1) * g
            state['exp_avg'] = exp_avg = _add(
                exp_avg.mul(beta1), 1 - beta1, g
            )
            # Update the exponentially weighted infinity norm:
            # exp_inf = max(beta2 * exp_inf, |g| + eps), computed out of place
            # (no in-place mul_/max with out=) so the op stays on the autograd tape.
            state['exp_inf'] = exp_inf = exp_inf.mul(beta2).unsqueeze(0)
            norm_buf = _torch.cat(
                [exp_inf, _add(g.abs(), eps).unsqueeze(0)], 0
            )
            exp_inf, _ = _torch.max(norm_buf, 0, keepdim=False)
            state['exp_inf'] = exp_inf

            # Bias-corrected step size; Adamax only corrects the first moment.
            bias_correction = 1 - beta1 ** state['step']
            clr = group['lr'] / bias_correction

            # Functional update: p_new = p - clr * exp_avg / exp_inf, written back
            # into the group so later steps can differentiate through it.
            group['params'][p_idx] = _addcdiv(p, -clr, exp_avg, exp_inf)
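
For context, a minimal usage sketch (not part of higher/optim.py): it assumes the `higher` and `torch` packages and illustrative model/data/learning-rate choices, and shows how wrapping a torch.optim.Adamax instance with higher.innerloop_ctx yields a differentiable optimizer whose .step() calls route through the _update method above, so the outer loss can be backpropagated through the unrolled inner updates.

import torch
import torch.nn.functional as F
import higher

model = torch.nn.Linear(4, 1)
opt = torch.optim.Adamax(model.parameters(), lr=1e-2)
x, y = torch.randn(8, 4), torch.randn(8, 1)

# copy_initial_weights=False lets meta-gradients flow back to model's own parameters.
with higher.innerloop_ctx(model, opt, copy_initial_weights=False) as (fmodel, diffopt):
    for _ in range(3):
        inner_loss = F.mse_loss(fmodel(x), y)
        diffopt.step(inner_loss)       # each step runs the gradient-taped Adamax update
    outer_loss = F.mse_loss(fmodel(x), y)
    outer_loss.backward()              # differentiates through the inner Adamax steps

print(model.weight.grad is not None)   # meta-gradient w.r.t. the initial weights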