in higher/optim.py [0:0]
def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:
    zipped = zip(self.param_groups, grouped_grads)
    for group_idx, (group, grads) in enumerate(zipped):
        for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
            if g is None:
                continue

            state = self.state[group_idx][p_idx]
            state['step'] += 1

            # L2 weight decay: fold weight_decay * p into the gradient.
            if group['weight_decay'] != 0:
                if g.data.is_sparse:
                    raise RuntimeError(
                        "weight_decay option is not compatible with sparse "
                        "gradients"
                    )
                g = _add(g, group['weight_decay'], p)

            # Adagrad learning-rate decay: lr / (1 + (step - 1) * lr_decay).
            clr = group['lr'] / (
                1 + (state['step'] - 1) * group['lr_decay']
            )

            if g.is_sparse:
                # TODO: implement support for sparse gradients.
                raise NotImplementedError(
                    "sparse gradient support for DifferentiableAdagrad not "
                    "implemented yet."
                )
            else:
                # Accumulate squared gradients: sum += g * g.
                state['sum'] = sum_ = _addcmul(state['sum'], 1, g, g)
                # Guard against division by zero where the accumulator is
                # exactly zero.
                mask = sum_ == 0.
                _maybe_mask(sum_, mask)
                std = _add(
                    state['sum'].sqrt(),
                    group['eps'] if 'eps' in group else 1e-10
                )
                # Differentiable parameter update: p <- p - clr * g / std.
                group['params'][p_idx] = _addcdiv(p, -clr, g, std)
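
A minimal usage sketch (an assumption for illustration, not code from optim.py): the update above is normally reached through higher's public entry points, higher.innerloop_ctx and DifferentiableOptimizer.step; the toy model, data, and loss here are placeholders.

    import torch
    import higher

    model = torch.nn.Linear(4, 1)
    opt = torch.optim.Adagrad(model.parameters(), lr=0.1)
    x, y = torch.randn(8, 4), torch.randn(8, 1)

    with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
        for _ in range(3):
            inner_loss = torch.nn.functional.mse_loss(fmodel(x), y)
            # diffopt.step computes gradients and invokes _update on the
            # differentiable Adagrad, producing new differentiable params.
            diffopt.step(inner_loss)
        # The outer loss can backpropagate through the unrolled inner steps.
        outer_loss = torch.nn.functional.mse_loss(fmodel(x), y)
        outer_loss.backward()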