in higher/optim.py
def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:
    zipped = zip(self.param_groups, grouped_grads)
    for group_idx, (group, grads) in enumerate(zipped):
        for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
            if g is None:
                continue
            if g.is_sparse:
                raise RuntimeError(
                    'Rprop does not support sparse gradients'
                )
            state = self.state[group_idx][p_idx]

            # State initialization
            if len(state) == 0:
                state['step'] = 0
                state['prev'] = _torch.zeros_like(p.data)
                state['step_size'] = g.new().resize_as_(g).fill_(
                    group['lr']
                )
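            # Note: 'prev' holds the previous step's gradient, and
            # 'step_size' is a per-element learning rate seeded with
            # group['lr']; the legacy g.new().resize_as_(g).fill_(...)
            # idiom is equivalent to _torch.full_like(g, group['lr']).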
            etaminus, etaplus = group['etas']
            step_size_min, step_size_max = group['step_sizes']
            step_size = state['step_size']

            state['step'] += 1

            sign = g.mul(state['prev']).sign()
            sign[sign.gt(0)] = etaplus
            sign[sign.lt(0)] = etaminus
            sign[sign.eq(0)] = 1
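            # 'sign' is now a per-element step-size multiplier: etaplus
            # (> 1) where the gradient kept its sign since the previous
            # step, etaminus (< 1) where it flipped, and 1 where either
            # gradient was zero.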
            # Grow or shrink each step size accordingly, then clamp it
            # to the configured [step_size_min, step_size_max] range.
            step_size = step_size.mul(sign).clamp(
                step_size_min, step_size_max
            )
            state['step_size'] = step_size
            # Where the gradient changed sign (dir < 0), zero it so no
            # step is taken along that element this iteration; otherwise
            # (dir >= 0) the gradient is used as-is.
            g = _torch.where(sign.eq(etaminus), _torch.zeros_like(g), g)

            # Update parameters: move each element by its own step size,
            # against the sign of the (masked) gradient.
            group['params'][p_idx] = _addcmul(p, -1, g.sign(), step_size)
            state['prev'] = g.clone()
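
A minimal usage sketch (not part of optim.py), assuming higher's public
innerloop_ctx API; the model, data, and loss below are illustrative
placeholders. diffopt.step(loss) dispatches to the _update method above
while keeping the unrolled inner steps on the autograd graph.

import torch
import higher

model = torch.nn.Linear(4, 2)
opt = torch.optim.Rprop(model.parameters(), lr=0.01)
x, y = torch.randn(8, 4), torch.randn(8, 2)

with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
    for _ in range(3):
        inner_loss = torch.nn.functional.mse_loss(fmodel(x), y)
        diffopt.step(inner_loss)  # differentiable Rprop step
    # The final loss remains differentiable w.r.t. the initial parameters,
    # so an outer optimizer can backpropagate through the unrolled steps.
    meta_loss = torch.nn.functional.mse_loss(fmodel(x), y)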