def get_named_beta_schedule()

in shap_e/diffusion/gaussian_diffusion.py [0:0]


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, **extra_args: float):
    """
    Build the beta schedule registered under the given name.

    Every named schedule is defined so that it keeps a similar shape as
    num_diffusion_timesteps grows. New names may be added, but a schedule
    must not be altered or removed once it has been committed, so that
    backwards compatibility is preserved.

    :param schedule_name: one of "linear", "cosine", "inv_parabola",
                          "translated_parabola" or "exp".
    :param num_diffusion_timesteps: the number of diffusion steps; passed on
                                    to the underlying schedule builder.
    :param extra_args: optional tuning values. "power" (default 2.0) is read
                       by "inv_parabola" and "translated_parabola", and
                       "coefficient" (default -12.0) is read by "exp". Keys
                       that the selected schedule does not read are ignored.
    :return: the result of get_beta_schedule() for "linear", otherwise the
             result of betas_for_alpha_bar().
    :raises NotImplementedError: if schedule_name is not a known schedule.
    """
    if schedule_name == "linear":
        # The linear ramp of Ho et al. Its endpoints are 0.0001 and 0.02 at
        # 1000 steps; rescaling them by 1000 / num_diffusion_timesteps lets
        # the same schedule be used with any number of steps.
        step_scale = 1000 / num_diffusion_timesteps
        return get_beta_schedule(
            "linear",
            beta_start=step_scale * 0.0001,
            beta_end=step_scale * 0.02,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )

    # Every remaining schedule is described by a function of t that is
    # handed to betas_for_alpha_bar; pick that function, then build once.
    if schedule_name == "cosine":

        def alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif schedule_name == "inv_parabola":
        power = extra_args.get("power", 2.0)

        def alpha_bar(t):
            return 1 - t**power

    elif schedule_name == "translated_parabola":
        power = extra_args.get("power", 2.0)

        def alpha_bar(t):
            return (1 - t) ** power

    elif schedule_name == "exp":
        rate = extra_args.get("coefficient", -12.0)

        def alpha_bar(t):
            return math.exp(t * rate)

    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

    return betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar)