in higher/optim.py
def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:
    zipped = zip(self.param_groups, grouped_grads)

    for group_idx, (group, grads) in enumerate(zipped):
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']

        for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
            if g is None:
                continue

            # Fold L2 regularization into the gradient: g <- g + weight_decay * p.
            if weight_decay != 0:
                g = _add(g, weight_decay, p)

            if momentum != 0:
                param_state = self.state[group_idx][p_idx]
                if 'momentum_buffer' not in param_state:
                    # First step for this parameter: seed the buffer with the raw gradient.
                    buf = param_state['momentum_buffer'] = g
                else:
                    # buf <- momentum * buf + (1 - dampening) * g
                    buf = param_state['momentum_buffer']
                    buf = _add(buf.mul(momentum), 1 - dampening, g)
                    param_state['momentum_buffer'] = buf
                if nesterov:
                    g = _add(g, momentum, buf)
                else:
                    g = buf

            # Out-of-place SGD step: p <- p - lr * g, written back into the group
            # so the new parameter stays on the autograd graph.
            group['params'][p_idx] = _add(p, -group['lr'], g)
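
For readers outside the module, _add is a small module-level helper defined elsewhere in higher/optim.py. A minimal sketch consistent with its usage above (the library's exact signature and type annotations may differ):

def _add(tensor, a1, a2=None):
    # _add(t, x)    -> t + x
    # _add(t, s, x) -> t + s * x  (out-of-place, autograd-friendly axpy)
    if a2 is None:
        return tensor + a1
    return tensor + a1 * a2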
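
Downstream, _update is not called directly; it runs inside DifferentiableOptimizer.step, which higher exposes via innerloop_ctx. A minimal sketch of a MAML-style inner loop that exercises this code path (the toy model, data, and hyperparameters are hypothetical):

import torch
import higher

# Hypothetical toy model, base optimizer, and data, purely for illustration.
model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
x, y = torch.randn(8, 4), torch.randn(8, 1)

# copy_initial_weights=False lets meta-gradients reach model.parameters().
with higher.innerloop_ctx(model, opt, copy_initial_weights=False) as (fmodel, diffopt):
    for _ in range(3):
        inner_loss = torch.nn.functional.mse_loss(fmodel(x), y)
        # step() computes grads for the patched parameters and invokes
        # _update() above, producing new params that stay on the graph.
        diffopt.step(inner_loss)

    # Backpropagate through the unrolled inner-loop updates.
    meta_loss = torch.nn.functional.mse_loss(fmodel(x), y)
    meta_loss.backward()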