def _update()

in higher/optim.py


    def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:

        # Pair each parameter group with the gradients computed for it.
        zipped = zip(self.param_groups, grouped_grads)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue

                if g.is_sparse:
                    raise RuntimeError('ASGD does not support sparse gradients')
                state = self.state[group_idx][p_idx]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['eta'] = group['lr']
                    state['mu'] = 1
                    state['ax'] = _torch.zeros_like(p.data)

                state['step'] += 1

                # fold weight decay into the gradient (out of place): g <- g + weight_decay * p
                if group['weight_decay'] != 0:
                    g = _add(g, group['weight_decay'], p)

                # decay term: shrink p by (1 - lambd * eta), out of place so the
                # operation stays on the autograd graph
                p = p.mul(1 - group['lambd'] * state['eta'])

                # update parameter: p <- p - eta * g, written back into the group
                group['params'][p_idx] = _add(p, -state['eta'], g)

                # averaging: once mu drops below 1, move ax toward p by a factor
                # of mu; until then ax simply tracks p
                if state['mu'] != 1:
                    state['ax'] = _add(
                        state['ax'],
                        p.sub(state['ax']).mul(state['mu'])
                    )
                else:
                    state['ax'] = p

                # update eta and mu: eta is the decayed learning rate; mu is the
                # averaging coefficient, which stays at 1 until step exceeds t0
                state['eta'] = (
                    group['lr'] / _math.pow(
                        (1 + group['lambd'] * group['lr'] * state['step']),
                        group['alpha']
                    )
                )
                state['mu'] = 1 / max(1, state['step'] - group['t0'])
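
For context, a minimal usage sketch follows. It assumes the torch and higher packages are installed; the model, data, and hyperparameters are illustrative and not part of the library. Wrapping a torch.optim.ASGD instance with higher.innerloop_ctx yields a differentiable optimizer whose step call dispatches to the _update shown above, so gradients can flow back through the unrolled inner loop.

    import torch
    import higher

    # Toy setup (illustrative values only).
    model = torch.nn.Linear(4, 1)
    opt = torch.optim.ASGD(model.parameters(), lr=0.1, lambd=1e-4, alpha=0.75)
    x, y = torch.randn(8, 4), torch.randn(8, 1)

    # innerloop_ctx wraps opt in its differentiable counterpart; every
    # diffopt.step(...) ends up calling the _update method documented above.
    with higher.innerloop_ctx(model, opt, copy_initial_weights=False) as (fmodel, diffopt):
        for _ in range(3):
            inner_loss = ((fmodel(x) - y) ** 2).mean()
            diffopt.step(inner_loss)
        outer_loss = ((fmodel(x) - y) ** 2).mean()
        # Gradients flow through all three unrolled ASGD updates back to
        # model.parameters() because copy_initial_weights=False.
        outer_loss.backward()

With copy_initial_weights=False the unrolled updates connect back to model.parameters(), which is the usual meta-learning pattern; with the default of True, the backward pass stops at the copied fast weights.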