def step()

in fairseq/optim/fused_adam.py


    def step(self, closure=None, grads=None, scale=1.0, grad_norms=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradients to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
            scale (float, optional): factor to divide gradient tensor values
                by before applying to weights. (default: 1)
            grad_norms (list, optional): pre-computed norms of the (still
                scaled) gradients, one per parameter group, used together with
                max_grad_norm to clip the update. (default: None)
        """
        loss = None
        if closure is not None:
            loss = closure()

        if grads is None:
            grads_group = [None] * len(self.param_groups)
        # backward compatibility
        # assuming a flat list/generator of gradients means a single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0]) != list:
            grads_group = [grads]
        else:
            grads_group = grads

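        # one (possibly None) gradient norm per parameter group; each norm is
        # expected to be computed on the still-scaled gradients (i.e. norm * scale)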
        if grad_norms is None:
            grad_norms = [None] * len(self.param_groups)

        for group, grads_this_group, grad_norm in zip(
            self.param_groups, grads_group, grad_norms
        ):
            if grads_this_group is None:
                grads_this_group = [None] * len(group["params"])

            # compute combined scale factor for this group
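            # the kernel divides the gradient by combined_scale, so folding the
            # clip factor into the loss scale unscales and clips in a single pass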
            combined_scale = scale
            if group.get("max_grad_norm", 0) > 0:
                # norm is in fact norm*scale
                clip = ((grad_norm / scale) + 1e-6) / group["max_grad_norm"]
                if clip > 1:
                    combined_scale = clip * scale

            bias_correction = 1 if group.get("bias_correction", 1) else 0

            for p, grad in zip(group["params"], grads_this_group):
                # note: p.grad should never be set here; the mixed precision
                # optimizer, which sometimes sends None gradients, relies on
                # this for correct operation
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "FusedAdam does not support sparse gradients, "
                        "please consider SparseAdam instead"
                    )

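                # CPU-resident params are staged on the GPU in fp32 for the fused
                # kernel (and copied back below); for GPU params, out_p aliases
                # p.data so the kernel writes the update back in the original dtype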
                if p.device.type == "cpu":
                    p_data_fp32 = p.data.cuda(non_blocking=True).float()
                    out_p = torch.tensor([], dtype=torch.float)
                else:
                    p_data_fp32 = p.data.float()
                    out_p = p.data

                state = self.state[p]

                # State initialization
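                # with use_fp16_stats the moments are stored in fp16, together
                # with per-tensor scales kept in the state dict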
                dtype = torch.float16 if self.use_fp16_stats else p_data_fp32.dtype
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p_data_fp32, dtype=dtype)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32, dtype=dtype)
                    if self.use_fp16_stats:
                        state["exp_avg_scale"] = 1.0
                        state["exp_avg_sq_scale"] = 1.0
                else:
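                    # existing state: move the moments to the parameter's device
                    # (and to the fp16 stats dtype, if enabled)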
                    device = p_data_fp32.device
                    state["exp_avg"] = state["exp_avg"].to(device, dtype)
                    state["exp_avg_sq"] = state["exp_avg_sq"].to(device, dtype)

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]
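                # dequantize fp16 moments back to fp32 (undoing the stored
                # scales) before handing them to the fused kernel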
                if self.use_fp16_stats:
                    assert exp_avg.dtype == torch.float16
                    exp_avg = exp_avg.float() * state["exp_avg_scale"]
                    exp_avg_sq = exp_avg_sq.float() * state["exp_avg_sq_scale"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

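                # the fused kernel performs the whole Adam update in place: it
                # divides the gradient by combined_scale, updates the moments,
                # and writes the new weights into p_data_fp32 (and into out_p
                # when it is non-empty)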
                with torch.cuda.device(p_data_fp32.device):
                    fused_adam_cuda.adam(
                        p_data_fp32,
                        out_p,
                        exp_avg,
                        exp_avg_sq,
                        grad,
                        group["lr"],
                        beta1,
                        beta2,
                        group["eps"],
                        combined_scale,
                        state["step"],
                        self.eps_mode,
                        bias_correction,
                        group["weight_decay"],
                    )

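                # copy the updated fp32 weights back into the CPU-resident parameter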
                if p.device.type == "cpu":
                    p.data.copy_(p_data_fp32, non_blocking=True)

                if self.use_fp16_stats:

                    def inf_norm(t):
                        return torch.norm(t, float("inf"))

                    # from github.com/openai/jukebox/blob/master/jukebox/utils/fp16.py
                    state["exp_avg_scale"], state["exp_avg_sq_scale"] = (
                        1e-8 + inf_norm(exp_avg) / self.FLOAT16_MAX,
                        1e-8 + inf_norm(exp_avg_sq) / self.FLOAT16_MAX,
                    )
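                    # requantize: dividing by the new scale maps the largest
                    # magnitude to roughly FLOAT16_MAX, so the moments fit back
                    # into fp16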
                    state["exp_avg"], state["exp_avg_sq"] = (
                        (exp_avg / state["exp_avg_scale"]).half(),
                        (exp_avg_sq / state["exp_avg_sq_scale"]).half(),
                    )

        return loss
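
The gradient clipping above never touches the gradients directly: it folds the clip factor into the loss scale that the kernel already divides by. A minimal standalone sketch of that arithmetic (the helper name and the example values are illustrative, not part of fairseq's API):

    import torch

    def combined_grad_scale(scale, grad_norm, max_grad_norm, eps=1e-6):
        # grad_norm is the norm of the *scaled* gradients, i.e. norm * scale,
        # mirroring what step() expects in grad_norms
        combined = scale
        if max_grad_norm > 0:
            clip = ((grad_norm / scale) + eps) / max_grad_norm
            if clip > 1:
                combined = clip * scale
        return combined

    # dividing a raw (scaled) gradient by the combined scale both removes the
    # loss scale and caps the unscaled gradient norm at max_grad_norm
    g = torch.randn(10)
    loss_scale = 128.0
    scaled_g = g * loss_scale
    c = combined_grad_scale(loss_scale, scaled_g.norm().item(), max_grad_norm=0.1)
    print((scaled_g / c).norm())  # approximately max_grad_norm, i.e. the clipped norm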