def _update()

in higher/optim.py

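This is the inner update step of higher's differentiable Adamax optimizer, the
differentiable counterpart of torch.optim.Adamax. It is normally reached
indirectly through DifferentiableOptimizer.step rather than called by hand. A
minimal usage sketch, assuming a small model net and an inner
torch.optim.Adamax optimizer (the names below are illustrative, not taken from
the listing):

    import torch
    import torch.nn.functional as F
    import higher

    net = torch.nn.Linear(4, 1)
    inner_opt = torch.optim.Adamax(net.parameters(), lr=1e-2)
    x, y = torch.randn(8, 4), torch.randn(8, 1)

    with higher.innerloop_ctx(net, inner_opt) as (fnet, diffopt):
        loss = F.mse_loss(fnet(x), y)
        # step() computes per-parameter gradients and hands them to _update,
        # which produces new parameters that remain differentiable.
        diffopt.step(loss)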

    def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:
        # Differentiable Adamax step: mirrors torch.optim.Adamax, but all
        # tensor ops are out of place so the updated parameters stay on the
        # autograd graph.
        zipped = zip(self.param_groups, grouped_grads)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue

                if g.is_sparse:
                    raise RuntimeError(
                        'Adamax does not support sparse gradients'
                    )

                state = self.state[group_idx][p_idx]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = _torch.zeros_like(p.data)
                    state['exp_inf'] = _torch.zeros_like(p.data)

                exp_avg, exp_inf = state['exp_avg'], state['exp_inf']
                beta1, beta2 = group['betas']
                eps = group['eps']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    # L2 weight decay, out of place: g <- g + weight_decay * p
                    g = _add(g, group['weight_decay'], p)

                # Update biased first moment estimate:
                # exp_avg <- beta1 * exp_avg + (1 - beta1) * g
                state['exp_avg'] = exp_avg = _add(
                    exp_avg.mul(beta1), 1 - beta1, g
                )
                # Update the exponentially weighted infinity norm:
                # exp_inf <- max(beta2 * exp_inf, |g| + eps), taken element-wise
                # via a stacked max so no in-place ops are needed.
                norm_buf = _torch.cat(
                    [exp_inf.mul(beta2).unsqueeze(0),
                     _add(g.abs(), eps).unsqueeze(0)], 0
                )
                exp_inf, _ = _torch.max(norm_buf, 0, keepdim=False)
                state['exp_inf'] = exp_inf

                # Only the first moment needs bias correction in Adamax.
                bias_correction = 1 - beta1**state['step']
                clr = group['lr'] / bias_correction

                # Out-of-place step, p <- p - clr * exp_avg / exp_inf; writing
                # into group['params'] keeps the new parameter on the graph.
                group['params'][p_idx] = _addcdiv(p, -clr, exp_avg, exp_inf)
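
For reference, the loop above implements the standard Adamax recurrences
exp_avg <- beta1 * exp_avg + (1 - beta1) * g,
exp_inf <- max(beta2 * exp_inf, |g| + eps),
p <- p - (lr / (1 - beta1^step)) * exp_avg / exp_inf.
A minimal single-tensor sketch of the same rule in plain (non-differentiable)
PyTorch; adamax_step is an illustrative helper, not part of higher:

    import torch

    def adamax_step(p, g, exp_avg, exp_inf, step,
                    lr=2e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        # Same recurrences as the loop above, on bare tensors.
        exp_avg = beta1 * exp_avg + (1 - beta1) * g              # first moment
        exp_inf = torch.maximum(beta2 * exp_inf, g.abs() + eps)  # infinity norm
        clr = lr / (1 - beta1 ** step)                           # bias-corrected lr
        return p - clr * exp_avg / exp_inf, exp_avg, exp_inf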