in higher/optim.py [0:0]
def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:
    zipped = zip(self.param_groups, grouped_grads)
    for group_idx, (group, grads) in enumerate(zipped):
        for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
            if g is None:
                continue
            if g.is_sparse:
                raise RuntimeError(
                    'RMSprop does not support sparse gradients'
                )
            state = self.state[group_idx][p_idx]

            # State initialization
            if len(state) == 0:
                state['step'] = 0
                state['square_avg'] = _torch.zeros_like(p.data)
                if group['momentum'] > 0:
                    state['momentum_buffer'] = _torch.zeros_like(p.data)
                if group['centered']:
                    state['grad_avg'] = _torch.zeros_like(p.data)

            square_avg = state['square_avg']
            alpha = group['alpha']

            state['step'] += 1

            # L2 penalty folded into the gradient: g <- g + weight_decay * p
            if group['weight_decay'] != 0:
                g = _add(g, group['weight_decay'], p)

            # EMA of squared gradients, computed out of place so the op
            # stays on the autograd tape:
            # square_avg <- alpha * square_avg + (1 - alpha) * g * g
            square_avg = _addcmul(square_avg.mul(alpha), 1 - alpha, g, g)
            state['square_avg'] = square_avg

            # NB: This prevents nans but is not sufficient to recover
            # correct gradients.
            mask = square_avg == 0.
            _maybe_mask(square_avg, mask)

            if group['centered']:
                # Centered variant: track an EMA of the gradient and use the
                # variance estimate sqrt(E[g^2] - E[g]^2) + eps as denominator.
                grad_avg = state['grad_avg']
                grad_avg = _add(grad_avg.mul(alpha), 1 - alpha, g)
                state['grad_avg'] = grad_avg
                eps = group['eps']
                avg = _add(
                    _addcmul(square_avg, -1, grad_avg, grad_avg).sqrt(), eps
                )
            else:
                avg = _add(square_avg.sqrt(), group['eps'])

            if group['momentum'] > 0:
                # Momentum on the preconditioned gradient:
                # buf <- momentum * buf + g / avg;  p <- p - lr * buf
                buf = state['momentum_buffer']
                buf = _addcdiv(buf.mul(group['momentum']), g, avg)
                state['momentum_buffer'] = buf
                p = _add(p, -group['lr'], buf)
            else:
                # Plain RMSprop step: p <- p - lr * g / avg
                p = _addcdiv(p, -group['lr'], g, avg)

            # Write the updated (still-differentiable) parameter back.
            group['params'][p_idx] = p
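
# --- Sketch (not from this excerpt) -----------------------------------------
# Plausible shapes of the private helpers used above, inferred from their
# call sites; the real definitions live elsewhere in higher/optim.py. They
# mirror Tensor.add_ / addcmul_ / addcdiv_ but are out of place: each update
# builds a new tensor on the autograd graph instead of mutating state, which
# is what lets gradients flow back through the optimizer steps.

def _add(tensor, a1, a2=None):
    # _add(t, x) -> t + x;  _add(t, scalar, x) -> t + scalar * x
    value, other = (1., a1) if a2 is None else (a1, a2)
    return tensor + value * other

def _addcmul(tensor, a1, a2, a3=None):
    # _addcmul(t, x, y) -> t + x * y;  _addcmul(t, s, x, y) -> t + s * x * y
    value, t1, t2 = (1., a1, a2) if a3 is None else (a1, a2, a3)
    return tensor + value * (t1 * t2)

def _addcdiv(tensor, a1, a2, a3=None):
    # _addcdiv(t, x, y) -> t + x / y;  _addcdiv(t, s, x, y) -> t + s * x / y
    value, t1, t2 = (1., a1, a2) if a3 is None else (a1, a2, a3)
    return tensor + value * (t1 / t2)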
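
# --- Usage sketch (assumed, not from this excerpt) ---------------------------
# How a differentiable optimizer step like _update above is typically driven,
# via higher.innerloop_ctx; the model, data, and loop counts here are
# illustrative assumptions.

import torch
import higher

model = torch.nn.Linear(4, 1)
opt = torch.optim.RMSprop(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

with higher.innerloop_ctx(
    model, opt, copy_initial_weights=False
) as (fmodel, diffopt):
    for _ in range(3):
        inner_loss = torch.nn.functional.mse_loss(fmodel(x), y)
        diffopt.step(inner_loss)  # functional step; dispatches to _update
    outer_loss = torch.nn.functional.mse_loss(fmodel(x), y)
    # Backpropagates through all three unrolled RMSprop updates, reaching
    # model.parameters() because copy_initial_weights=False.
    outer_loss.backward()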