def _update()

in higher/optim.py [0:0]


    def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:
        """Apply one RMSprop step to every parameter group.

        For each parameter with a non-``None`` gradient, the running
        statistics stored in ``self.state[group_idx][p_idx]`` are replaced
        with newly computed tensors. The updated parameter is written back
        into ``group['params']`` at the same position. Parameters whose
        gradient is ``None`` are left untouched.

        Args:
            grouped_grads: one iterable of gradients per entry of
                ``self.param_groups``, matched positionally (via ``zip``)
                against that group's ``'params'`` list.
            **kwargs: accepted but not used by this method.

        Raises:
            RuntimeError: if any gradient is sparse.
        """
        for group_idx, (group, group_grads) in enumerate(
            zip(self.param_groups, grouped_grads)
        ):
            params = group['params']
            for p_idx, (param, grad) in enumerate(zip(params, group_grads)):
                # Parameters that received no gradient are skipped entirely.
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError(
                        'RMSprop does not support sparse gradients'
                    )

                state = self.state[group_idx][p_idx]

                # Lazily create the running statistics the first time this
                # parameter receives a gradient; optional buffers exist only
                # when the corresponding feature is enabled for the group.
                if not state:
                    state['step'] = 0
                    state['square_avg'] = _torch.zeros_like(param.data)
                    if group['momentum'] > 0:
                        state['momentum_buffer'] = _torch.zeros_like(param.data)
                    if group['centered']:
                        state['grad_avg'] = _torch.zeros_like(param.data)

                alpha = group['alpha']
                state['step'] += 1

                # Fold the weight-decay term into the gradient. Presumably
                # grad + weight_decay * param, assuming _add(t, a, b) mirrors
                # torch.add(t, b, alpha=a) -- confirm against the helper.
                weight_decay = group['weight_decay']
                if weight_decay != 0:
                    grad = _add(grad, weight_decay, param)

                # Running average of grad**2 with decay `alpha` (assuming
                # _addcmul(t, v, a, b) == t + v * a * b, like torch.addcmul).
                square_avg = _addcmul(
                    state['square_avg'].mul(alpha), 1 - alpha, grad, grad
                )
                state['square_avg'] = square_avg

                # NB: This prevents nans but is not sufficient to recover
                # correct gradients.
                _maybe_mask(square_avg, square_avg == 0.)

                # Denominator: sqrt of the second-moment estimate plus eps.
                # In the centered variant the squared running mean of the
                # gradient is subtracted first.
                eps = group['eps']
                if group['centered']:
                    grad_avg = _add(
                        state['grad_avg'].mul(alpha), 1 - alpha, grad
                    )
                    state['grad_avg'] = grad_avg
                    centered_sq = _addcmul(square_avg, -1, grad_avg, grad_avg)
                    avg = _add(centered_sq.sqrt(), eps)
                else:
                    avg = _add(square_avg.sqrt(), eps)

                # Step: either through a momentum buffer accumulating
                # grad / avg, or directly by -lr * grad / avg.
                lr = group['lr']
                momentum = group['momentum']
                if momentum > 0:
                    buf = _addcdiv(
                        state['momentum_buffer'].mul(momentum), grad, avg
                    )
                    state['momentum_buffer'] = buf
                    new_param = _add(param, -lr, buf)
                else:
                    new_param = _addcdiv(param, -lr, grad, avg)

                # Replace the list entry rather than mutating the tensor.
                params[p_idx] = new_param