def step()

in jukebox/utils/fp16.py

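The step method of this file's fused Adam optimizer: it runs a fused CUDA Adam update over every parameter group, and for FP16 parameters it stores the Adam moments in half precision together with a per-tensor scale that is undone before, and recomputed after, each update.
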

    def step(self, closure=None, scale=1.0):
        """Performs a single optimization step, dividing gradients by ``scale``.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            scale (float, optional): factor to divide gradient tensor values
                by before applying them to the weights. (default: 1.0)
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            bias_correction = 1 if group["bias_correction"] else 0

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data

                state = self.state[p]

                # Lazily initialize per-parameter optimizer state on first use;
                # without this, state["step"] below would raise a KeyError.
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p.data)
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                    if p.data.dtype == torch.float16:
                        # Half-precision moments carry an extra per-tensor scale.
                        state["scale_exp_avg"] = 1.0
                        state["scale_exp_avg_sq"] = 1.0

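                # FP16 params store their Adam moments in half precision; undo
                # the per-tensor scale to recover float32 moments for the update.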
                if p.data.dtype == torch.float16:
                    exp_avg, exp_avg_sq = (
                        state["exp_avg"].float() * state["scale_exp_avg"],
                        state["exp_avg_sq"].float() * state["scale_exp_avg_sq"],
                    )
                else:
                    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

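                # An empty out_p signals the fused kernel that no separate
                # output/master copy of the parameter is wanted; the kernel
                # applies the full Adam update (gradient descaling by `scale`,
                # bias correction, weight decay) to p.data in place.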
                out_p = torch.tensor([], dtype=torch.float)
                fused_adam_step(
                    p.data,
                    out_p,
                    exp_avg,
                    exp_avg_sq,
                    grad,
                    group["lr"],
                    beta1,
                    beta2,
                    group["eps"],
                    scale,
                    state["step"],
                    self.eps_mode,
                    bias_correction,
                    group["weight_decay"],
                )

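                # Re-quantize the updated moments to FP16: divide by an inf-norm
                # based scale so the largest entry fits the half-precision range
                # (the 1e-8 avoids division by zero), and keep the scale around
                # to undo on the next step.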
                if p.data.dtype == torch.float16:
                    state["scale_exp_avg"] = (
                        1e-8 + float(torch.norm(exp_avg, float("inf"))) / self.FLOAT16_MAX
                    )
                    state["scale_exp_avg_sq"] = (
                        1e-8 + float(torch.norm(exp_avg_sq, float("inf"))) / self.FLOAT16_MAX
                    )
                    state["exp_avg"] = (exp_avg / state["scale_exp_avg"]).half()
                    state["exp_avg_sq"] = (exp_avg_sq / state["scale_exp_avg_sq"]).half()

        return loss
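
A minimal usage sketch. It assumes the enclosing class is this file's FusedAdam optimizer with the standard torch.optim constructor signature; the toy model and the static loss-scale value are illustrative only, not taken from this file.

    import torch
    from jukebox.utils.fp16 import FusedAdam  # assumed export of this file

    model = torch.nn.Linear(16, 16).half().cuda()
    optimizer = FusedAdam(model.parameters(), lr=1e-4)
    loss_scale = 2.0 ** 16  # illustrative static loss scale

    x = torch.randn(8, 16, dtype=torch.float16, device="cuda")
    loss = model(x).float().pow(2).mean()

    optimizer.zero_grad()
    (loss * loss_scale).backward()    # scale up so FP16 grads don't underflow
    optimizer.step(scale=loss_scale)  # step() divides the grads back down

Passing the scale into step() lets the fused kernel unscale the gradients in the same pass as the Adam update, instead of first launching a separate division over every gradient tensor.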