def _update() (DifferentiableAdagrad)

in higher/optim.py


    def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:

        zipped = zip(self.param_groups, grouped_grads)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue

                state = self.state[group_idx][p_idx]

                state['step'] += 1

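                # Standard L2 weight decay: fold weight_decay * p into the gradient.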
                if group['weight_decay'] != 0:
                    if g.data.is_sparse:
                        raise RuntimeError(
                            "weight_decay option is not compatible with sparse "
                            "gradients"
                        )
                    g = _add(g, group['weight_decay'], p)

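                # Decayed per-step learning rate: lr / (1 + (step - 1) * lr_decay).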
                clr = group['lr'] / (
                    1 + (state['step'] - 1) * group['lr_decay']
                )

                if g.is_sparse:
                    # TODO: implement support for sparse gradients.
                    raise NotImplementedError(
                        "sparse gradient support for DifferentiableAdagrad not "
                        "implemented yet."
                    )
                else:
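                    # Dense path: accumulate squared gradients, sum <- sum + g * g.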
                    state['sum'] = sum_ = _addcmul(state['sum'], 1, g, g)
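                    # Mask positions where the accumulator is exactly zero so the
                    # sqrt below does not yield non-finite gradients (_maybe_mask).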
                    mask = sum_ == 0.
                    _maybe_mask(sum_, mask)
                    std = _add(
                        state['sum'].sqrt(),
                        group['eps'] if 'eps' in group else 1e-10
                    )
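                    # Out-of-place update p <- p - clr * g / std, written back into the
                    # param group so later steps differentiate through this one.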
                    group['params'][p_idx] = _addcdiv(p, -clr, g, std)
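
This method is not called directly; it is reached through higher's public API when a
torch.optim.Adagrad instance is made differentiable. The sketch below is illustrative
only: the toy model, data, and hyperparameter values are assumptions, not part of
optim.py.

    import torch
    import torch.nn.functional as F
    import higher

    # Toy model and optimizer (hypothetical values, for illustration only).
    model = torch.nn.Linear(4, 2)
    opt = torch.optim.Adagrad(
        model.parameters(), lr=0.1, lr_decay=0.01, weight_decay=1e-4
    )

    xs = torch.randn(8, 4)
    ys = torch.randint(0, 2, (8,))

    with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
        # Each diffopt.step(loss) computes per-group gradients and passes them to
        # DifferentiableAdagrad._update as grouped_grads, so the whole inner loop
        # stays on the autograd graph.
        for _ in range(3):
            loss = F.cross_entropy(fmodel(xs), ys)
            diffopt.step(loss)

        # Differentiate through the unrolled Adagrad updates.
        meta_loss = F.cross_entropy(fmodel(xs), ys)
        meta_loss.backward()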