in higher/optim.py [0:0]
def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:
    zipped = zip(self.param_groups, grouped_grads)
    for group_idx, (group, grads) in enumerate(zipped):
        amsgrad = group['amsgrad']
        beta1, beta2 = group['betas']
        weight_decay = group['weight_decay']

        for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
            if g is None:
                continue

            state = self.state[group_idx][p_idx]

            # State initialization
            if len(state) == 0:
                state['step'] = 0
                # Exponential moving average of gradient values
                state['exp_avg'] = _torch.zeros_like(p.data)
                # Exponential moving average of squared gradient values
                state['exp_avg_sq'] = _torch.zeros_like(p.data)
                if amsgrad:
                    # Maintains max of all exp. mov. avg. of sq. grad. vals
                    state['max_exp_avg_sq'] = _torch.zeros_like(p.data)

            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
            if amsgrad:
                max_exp_avg_sq = state['max_exp_avg_sq']

            state['step'] += 1
            bias_correction1 = 1 - beta1**state['step']
            bias_correction2 = 1 - beta2**state['step']
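            # Bias correction: exp_avg and exp_avg_sq start at zero, so their
            # expectations are scaled by (1 - beta**t) after t steps; the
            # factors above are used below to debias the moment estimates.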
            if weight_decay != 0:
                g = g + (weight_decay * p)

            # Decay the first and second moment running average coefficient
            state['exp_avg'] = exp_avg = (exp_avg * beta1) + (1 - beta1) * g
            state['exp_avg_sq'] = exp_avg_sq = (
                (exp_avg_sq * beta2) + (1 - beta2) * g * g
            )

            # Deal with stability issues
            mask = exp_avg_sq == 0.
            _maybe_mask(exp_avg_sq, mask)
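            # (The derivative of sqrt is unbounded at zero, so entries where
            # exp_avg_sq == 0 are masked to guard the higher-order backward
            # pass against NaNs.)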
            if amsgrad:
                # Maintains the max of all 2nd moment running avg. till now
                state['max_exp_avg_sq'] = max_exp_avg_sq = _torch.max(
                    max_exp_avg_sq, exp_avg_sq
                )
                # Use the max. for normalizing running avg. of gradient
                denom = _add(
                    max_exp_avg_sq.sqrt() / _math.sqrt(bias_correction2),
                    group['eps']
                )
            else:
                denom = _add(
                    exp_avg_sq.sqrt() / _math.sqrt(bias_correction2),
                    group['eps']
                )

            step_size = group['lr'] / bias_correction1
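            # Unlike torch.optim.Adam, the parameter is replaced out of place
            # rather than mutated in place, so later losses can backpropagate
            # through this update step.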
            group['params'][p_idx] = _addcdiv(
                p, -step_size, exp_avg, denom
            )
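For orientation, here is a minimal, hedged sketch of how this differentiable Adam update is typically driven from user code: higher.innerloop_ctx wraps a model together with a standard torch.optim.Adam instance, and diffopt.step(loss) computes gradients and routes them into _update above. The model, data, and meta_param names below are illustrative placeholders, not part of optim.py.

import torch
import torch.nn.functional as F
import higher

model = torch.nn.Linear(4, 1)
inner_opt = torch.optim.Adam(model.parameters(), lr=1e-2)

xs, ys = torch.randn(8, 4), torch.randn(8, 1)
meta_param = torch.tensor(1.0, requires_grad=True)  # placeholder meta-parameter, e.g. a loss weight

with higher.innerloop_ctx(model, inner_opt) as (fmodel, diffopt):
    # Inner step: the Adam update above runs out of place, so the patched
    # parameters of fmodel remain differentiable functions of meta_param.
    inner_loss = meta_param * F.mse_loss(fmodel(xs), ys)
    diffopt.step(inner_loss)

    # Outer step: differentiate the post-update loss w.r.t. the meta-parameter.
    meta_loss = F.mse_loss(fmodel(xs), ys)
    meta_grad, = torch.autograd.grad(meta_loss, meta_param)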