optimum/habana/diffusers/schedulers/scheduling_euler_ancestral_discrete.py (127 lines of code) (raw):

# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers.configuration_utils import register_to_config
from diffusers.schedulers import EulerAncestralDiscreteScheduler
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput

from optimum.utils import logging


logger = logging.get_logger(__name__)


class GaudiEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
    """
    Extends [Diffusers' EulerAncestralDiscreteScheduler](https://huggingface.co/docs/diffusers/en/api/schedulers/euler_ancestral)
    to run optimally on Gaudi:
    - All time-dependent parameters are generated at the beginning
    - At each time step, tensors are rolled to update the values of the time-dependent parameters

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
            Diffusion.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
    ):
        super().__init__(
            num_train_timesteps,
            beta_start,
            beta_end,
            beta_schedule,
            trained_betas,
            prediction_type,
            timestep_spacing,
            steps_offset,
            # Forward the flag so it is honoured by the base scheduler instead of being silently dropped.
            rescale_betas_zero_snr=rescale_betas_zero_snr,
        )

        # NOTE(review): not read anywhere in this class; presumably consumed by callers - confirm.
        self._initial_timestep = None
        self.reset_timestep_dependent_params()

    def reset_timestep_dependent_params(self):
        """
        Discard the cached time-dependent parameters so that the next call to `get_params` recomputes them.
        """
        self.are_timestep_dependent_params_set = False
        self.sigma_t_list = []
        self.sigma_up_t_list = []
        self.sigma_down_t_list = []

    def get_params(self, timestep: Union[float, torch.FloatTensor]):
        """
        Initialize the time-dependent parameters, and retrieve the time-dependent parameters at each timestep. The
        tensors are rolled in a separate function at the end of the scheduler step in case parameters are retrieved
        multiple times in a timestep, e.g., when scaling model inputs and in the scheduler step.

        Args:
            timestep (`float`):
                The current discrete timestep in the diffusion chain. Optionally used to initialize parameters in
                cases which start in the middle of the denoising schedule (e.g. for image-to-image)

        Returns:
            `tuple`: `(sigma, sigma_up, sigma_down)` for the current step, i.e. the first entry of each cached tensor.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        if not self.are_timestep_dependent_params_set:
            # Pair every remaining sigma with its successor in the schedule.
            sigmas_from = self.sigmas[self.step_index : -1]
            sigmas_to = self.sigmas[(self.step_index + 1) :]

            # Element-wise over the whole remaining schedule; same formulas as the per-step computation.
            sigmas_up = (sigmas_to**2 * (sigmas_from**2 - sigmas_to**2) / sigmas_from**2) ** 0.5
            sigmas_down = (sigmas_to**2 - sigmas_up**2) ** 0.5

            # Clone so that the cached tensor does not share storage with `self.sigmas`.
            self.sigma_t_list = sigmas_from.clone()
            self.sigma_up_t_list = sigmas_up
            self.sigma_down_t_list = sigmas_down
            self.are_timestep_dependent_params_set = True

        sigma = self.sigma_t_list[0]
        sigma_up = self.sigma_up_t_list[0]
        sigma_down = self.sigma_down_t_list[0]

        return sigma, sigma_up, sigma_down

    def roll_params(self):
        """
        Roll tensors to update the values of the time-dependent parameters at each timestep.

        Raises:
            ValueError: If the time-dependent parameters have not been initialized by `get_params` yet.
        """
        if not self.are_timestep_dependent_params_set:
            raise ValueError("Time-dependent parameters should be set first.")

        # Shift by one so that index 0 holds the values of the next step.
        self.sigma_t_list = torch.roll(self.sigma_t_list, shifts=-1, dims=0)
        self.sigma_up_t_list = torch.roll(self.sigma_up_t_list, shifts=-1, dims=0)
        self.sigma_down_t_list = torch.roll(self.sigma_down_t_list, shifts=-1, dims=0)

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        sigma, _, _ = self.get_params(timestep)
        sample = sample / ((sigma**2 + 1) ** 0.5)
        # Recorded so that `step` can warn when the input was never scaled.
        self.is_scale_input_called = True
        return sample

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is an integer index, or if `prediction_type` is not a known value.
            NotImplementedError: If `prediction_type` is `sample`.
        """
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma, sigma_up, sigma_down = self.get_params(timestep)

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        dt = sigma_down - sigma

        prev_sample = sample + derivative * dt

        device = model_output.device

        # torch.randn is broken on HPU so running it on CPU
        noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator)
        # Move to the model's device whatever it is (no-op on CPU), so the addition below never mixes devices.
        noise = noise.to(device)

        prev_sample = prev_sample + noise * sigma_up

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1
        self.roll_params()

        if not return_dict:
            return (prev_sample,)

        return EulerAncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )