def _update()

in higher/optim.py [0:0]


    def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:

        zipped = zip(self.param_groups, grouped_grads)
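        # Each parameter group is walked alongside the gradient list computed
        # for it by the caller (normally DifferentiableOptimizer.step(), via
        # torch.autograd.grad).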
        for group_idx, (group, grads) in enumerate(zipped):
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']
            weight_decay = group['weight_decay']

            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):

                if g is None:
                    continue

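                # Per-parameter state lives in nested containers indexed by
                # (group_idx, p_idx) rather than being keyed by the parameter
                # tensor as in stock torch.optim, because the tensors themselves
                # are replaced on every differentiable step.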
                state = self.state[group_idx][p_idx]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = _torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = _torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. mov. avg. of sq. grad. vals
                        state['max_exp_avg_sq'] = _torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']

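                # Advance the step count and form the usual Adam bias
                # corrections for the zero-initialised moment estimates.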
                state['step'] += 1
                bias_correction1 = 1 - beta1**state['step']
                bias_correction2 = 1 - beta2**state['step']

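                # Coupled L2 weight decay (classic Adam, not AdamW), folded into
                # the gradient out-of-place so g stays on the autograd graph.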
                if weight_decay != 0:
                    g = g + (weight_decay * p)

                # Decay the first and second moment running average coefficient
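                # (written back to state as new tensors, with no in-place ops,
                # so the graph through previous steps remains differentiable)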
                state['exp_avg'] = exp_avg = (exp_avg * beta1) + (1 - beta1) * g
                state['exp_avg_sq'] = exp_avg_sq = (
                    (exp_avg_sq * beta2) + (1 - beta2) * g * g
                )

                # Deal with stability issues: where exp_avg_sq is exactly zero,
                # sqrt() below has an unbounded derivative, so those entries are
                # masked to keep higher-order gradients free of NaNs.
                mask = exp_avg_sq == 0.
                _maybe_mask(exp_avg_sq, mask)

                if amsgrad:
                    # Maintains the max of all 2nd moment running avg. till now
                    state['max_exp_avg_sq'] = max_exp_avg_sq = _torch.max(
                        max_exp_avg_sq, exp_avg_sq
                    )
                    # Use the max. for normalizing running avg. of gradient
                    denom = _add(
                        max_exp_avg_sq.sqrt() / _math.sqrt(bias_correction2),
                        group['eps']
                    )
                else:
                    denom = _add(
                        exp_avg_sq.sqrt() / _math.sqrt(bias_correction2),
                        group['eps']
                    )

                step_size = group['lr'] / bias_correction1

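                # Replace the parameter with the fresh tensor
                # p - step_size * exp_avg / denom; assigning a new tensor
                # (rather than updating in place) keeps the step differentiable.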
                group['params'][p_idx] = _addcdiv(
                    p, -step_size, exp_avg, denom
                )
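
For orientation, a minimal usage sketch (not part of optim.py): the update above
is normally reached through higher's public API, where diffopt.step(loss)
differentiates the loss with respect to the current parameters and passes the
result to _update() as grouped_grads. The toy model, data, and loss below are
illustrative only.

    import torch
    import higher

    model = torch.nn.Linear(4, 1)
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)

    with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
        for _ in range(3):
            # Illustrative inner objective.
            loss = fmodel(torch.randn(8, 4)).pow(2).mean()
            # Computes gradients and routes them into _update() above; fmodel's
            # parameters are replaced by new, graph-connected tensors.
            diffopt.step(loss)
        # An outer loss on fmodel can now be differentiated through all three
        # Adam steps, e.g. via torch.autograd.grad.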