optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox.py (278 lines of code) (raw):

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Callable, Dict, List, Optional, Union

import torch
from diffusers import CogVideoXPipeline
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from diffusers.models.autoencoders.autoencoder_kl_cogvideox import CogVideoXCausalConv3d
from diffusers.pipelines.cogvideo.pipeline_cogvideox import retrieve_timesteps
from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.utils import (
    logging,
)
from diffusers.utils.torch_utils import randn_tensor
from transformers import T5EncoderModel, T5Tokenizer

from ....transformers.gaudi_configuration import GaudiConfig
from ...models.attention_processor import CogVideoXAttnProcessorGaudi
from ...models.autoencoders.autoencoder_kl_cogvideox import CogVideoXCausalConv3dforwardGaudi, tiled_decode_gaudi
from ...models.cogvideox_transformer_3d import cogvideoXTransformerForwardGaudi
from ..pipeline_utils import GaudiDiffusionPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Module-level monkey-patches: importing this module replaces these methods on the diffusers
# classes themselves, so every CogVideoXCausalConv3d / AutoencoderKLCogVideoX in the process
# uses the Gaudi implementations from then on.
setattr(CogVideoXCausalConv3d, "forward", CogVideoXCausalConv3dforwardGaudi)
setattr(AutoencoderKLCogVideoX, "tiled_decode", tiled_decode_gaudi)


class GaudiCogVideoXPipeline(GaudiDiffusionPipeline, CogVideoXPipeline):
    r"""
    CogVideoX text-to-video pipeline adapted for Habana Gaudi (HPU).

    Adapted from diffusers' `CogVideoXPipeline`
    (`src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py`).

    On construction, the transformer's `forward` and its blocks' attention processors are replaced with
    Gaudi implementations, and the VAE decoder is wrapped in an HPU graph.
    """

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
        use_habana: bool = False,
        use_hpu_graphs: bool = False,
        gaudi_config: Union[str, GaudiConfig] = None,
        bf16_full_eval: bool = False,
        sdp_on_bf16: bool = False,
    ):
        """Build the pipeline.

        The model components are registered through `CogVideoXPipeline.__init__`; the Gaudi-specific
        arguments (`use_habana`, `use_hpu_graphs`, `gaudi_config`, `bf16_full_eval`, `sdp_on_bf16`) are
        forwarded to `GaudiDiffusionPipeline.__init__`.
        """
        GaudiDiffusionPipeline.__init__(
            self,
            use_habana,
            use_hpu_graphs,
            gaudi_config,
            bf16_full_eval,
            sdp_on_bf16,
        )
        CogVideoXPipeline.__init__(
            self,
            tokenizer,
            text_encoder,
            vae,
            transformer,
            scheduler,
        )
        self.to(self._device)

        # The Gaudi forward is assigned as a plain (unbound) function on the instance, so every call
        # site must pass the module itself as the first positional argument (see `transformer_hpu`
        # and `capture_replay`).
        self.transformer.forward = cogvideoXTransformerForwardGaudi
        for block in self.transformer.transformer_blocks:
            block.attn1.set_processor(CogVideoXAttnProcessorGaudi())

        # Imported here rather than at module top level.
        from habana_frameworks.torch.hpu import wrap_in_hpu_graph

        # NOTE(review): the decoder is wrapped regardless of `use_hpu_graphs` — confirm this is intended.
        self.vae.decoder = wrap_in_hpu_graph(self.vae.decoder)

    def enable_model_cpu_offload(self, *args, **kwargs):
        """Model CPU offload is unsupported on HPU; otherwise defer to the parent implementation.

        Raises:
            NotImplementedError: if the pipeline runs with `use_habana` enabled.
        """
        if self.use_habana:
            raise NotImplementedError("enable_model_cpu_offload() is not implemented for HPU")
        return super().enable_model_cpu_offload(*args, **kwargs)

    def prepare_latents(
        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
    ):
        """Create (or move to `device`) the initial latents and scale them by `scheduler.init_noise_sigma`.

        The latent shape is `(batch_size, (num_frames - 1) // vae_scale_factor_temporal + 1,
        num_channels_latents, height // vae_scale_factor_spatial, width // vae_scale_factor_spatial)`.

        Raises:
            ValueError: if `generator` is a list whose length differs from `batch_size`.
        """
        shape = (
            batch_size,
            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            # torch.randn is broken on HPU so running it on CPU
            rand_device = "cpu" if device.type == "hpu" else device
            rand_device = torch.device(rand_device)
            latents = randn_tensor(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 480,
        width: int = 720,
        num_frames: int = 49,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,
        use_dynamic_cfg: bool = False,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 226,
    ):
        """
        Function invoked when calling the pipeline for generation.

        The whole call runs under `torch.autocast(device_type="hpu", dtype=torch.bfloat16)`, enabled when
        `self.gaudi_config.use_torch_autocast` is set.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the video generation. If not defined, one has to pass
                `prompt_embeds` instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored unless
                `guidance_scale` is greater than `1`).
            height (`int`, *optional*, defaults to 480):
                The height in pixels of the generated video. A falsy value falls back to
                `self.transformer.config.sample_size * self.vae_scale_factor_spatial`.
            width (`int`, *optional*, defaults to 720):
                The width in pixels of the generated video. A falsy value falls back to
                `self.transformer.config.sample_size * self.vae_scale_factor_spatial`.
            num_frames (`int`, *optional*, defaults to 49):
                Number of frames to generate; at most 49 are currently supported. The latent video has
                `(num_frames - 1) // self.vae_scale_factor_temporal + 1` frames, so `num_frames - 1` is expected
                to be a multiple of `self.vae_scale_factor_temporal`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps`
                argument in their `set_timesteps` method. If not defined, the default behavior when
                `num_inference_steps` is passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 6):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of
                [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled when
                `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked
                to the text `prompt`, usually at the expense of lower quality.
            use_dynamic_cfg (`bool`, *optional*, defaults to `False`):
                If `True`, the effective guidance scale is recomputed at every step from `guidance_scale` and the
                current timestep with a cosine schedule.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Currently ignored: the pipeline always overrides it to 1.
            eta (`float`, *optional*, defaults to 0.0):
                Passed to `prepare_extra_step_kwargs`, whose result is forwarded to the scheduler's `step`.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic. When running on HPU, the initial noise is sampled on CPU.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated video, passed to `self.video_processor.postprocess_video`.
                Pass `"latent"` to skip VAE decoding and return the final latents.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] instead
                of a plain tuple.
            callback_on_step_end (`Callable`, `PipelineCallback` or `MultiPipelineCallbacks`, *optional*):
                A function called at the end of each denoising step with
                `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
                `callback_kwargs` will include all tensors listed in `callback_on_step_end_tensor_inputs`. It may
                return replacements for `latents`, `prompt_embeds` and `negative_prompt_embeds`. When a
                `PipelineCallback`/`MultiPipelineCallbacks` is passed, its `tensor_inputs` are used instead of
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the
                list will be passed as `callback_kwargs` argument. You will only be able to include variables
                listed in the `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `226`):
                Maximum sequence length in encoded prompt. Must be consistent with
                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.

        Returns:
            [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
            [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple` whose first element is the generated video(s).

        Raises:
            ValueError: if `num_frames` is greater than 49, or if the inputs fail `check_inputs`.
        """
        with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast):
            if num_frames > 49:
                raise ValueError(
                    "The number of frames must be less than or equal to 49 for now due to static positional"
                    " embeddings. This will be updated in the future to remove this limitation."
                )

            if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
                callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

            # 0. Default height and width from the transformer config (only if a falsy value was passed)
            height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
            width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
            # Multiple videos per prompt are not supported: the argument is always overridden.
            num_videos_per_prompt = 1

            # 1. Check inputs. Raise error if not correct
            self.check_inputs(
                prompt,
                height,
                width,
                negative_prompt,
                callback_on_step_end_tensor_inputs,
                prompt_embeds,
                negative_prompt_embeds,
            )
            self._guidance_scale = guidance_scale
            self._interrupt = False

            # 2. Define call parameters
            if prompt is not None and isinstance(prompt, str):
                batch_size = 1
            elif prompt is not None and isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                batch_size = prompt_embeds.shape[0]

            device = self._execution_device

            # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
            # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
            # corresponds to doing no classifier free guidance.
            do_classifier_free_guidance = guidance_scale > 1.0

            # 3. Encode input prompt
            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
                prompt,
                negative_prompt,
                do_classifier_free_guidance,
                num_videos_per_prompt=num_videos_per_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                max_sequence_length=max_sequence_length,
                device=device,
            )
            if do_classifier_free_guidance:
                # Unconditional half first; must match the `chunk(2)` order in the guidance step below.
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

            # 4. Prepare timesteps
            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
            self._num_timesteps = len(timesteps)

            # 5. Prepare latent variables
            latent_channels = self.transformer.config.in_channels
            latents = self.prepare_latents(
                batch_size * num_videos_per_prompt,
                latent_channels,
                num_frames,
                height,
                width,
                prompt_embeds.dtype,
                device,
                generator,
                latents,
            )

            # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
            extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

            image_rotary_emb = (
                self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
                if self.transformer.config.use_rotary_positional_embeddings
                else None
            )

            # 7. Denoising loop
            num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

            import habana_frameworks.torch.core as htcore

            with self.progress_bar(total=num_inference_steps) as progress_bar:
                # for DPM-solver++
                old_pred_original_sample = None
                for i, t in enumerate(timesteps):
                    if self.interrupt:
                        continue

                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                    timestep = t.expand(latent_model_input.shape[0])

                    noise_pred = self.transformer_hpu(
                        latent_model_input=latent_model_input,
                        prompt_embeds=prompt_embeds,
                        timestep=timestep,
                        image_rotary_emb=image_rotary_emb,
                    )
                    noise_pred = noise_pred.float()

                    # perform guidance
                    if use_dynamic_cfg:
                        self._guidance_scale = 1 + guidance_scale * (
                            (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0))
                            / 2
                        )
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[
                            0
                        ]
                    else:
                        # The DPM scheduler also needs the previous step's predicted original sample
                        # and the previous timestep.
                        latents, old_pred_original_sample = self.scheduler.step(
                            noise_pred,
                            old_pred_original_sample,
                            t,
                            timesteps[i - 1] if i > 0 else None,
                            latents,
                            **extra_step_kwargs,
                            return_dict=False,
                        )
                    latents = latents.to(prompt_embeds.dtype)
                    if not self.use_hpu_graphs:
                        htcore.mark_step()

                    # call the callback, if provided
                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                        latents = callback_outputs.pop("latents", latents)
                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()
                    if not self.use_hpu_graphs:
                        htcore.mark_step()

            if not output_type == "latent":
                video = self.decode_latents(latents)
                video = self.video_processor.postprocess_video(video=video, output_type=output_type)
            else:
                video = latents

            # Offload all models
            self.maybe_free_model_hooks()

            if not return_dict:
                return (video,)

            return CogVideoXPipelineOutput(frames=video)

    @torch.no_grad()
    def transformer_hpu(self, latent_model_input, prompt_embeds, timestep, image_rotary_emb):
        """Run the transformer for one denoising step and return its first output (the noise prediction).

        Uses `capture_replay` (cached HPU graphs) when `self.use_hpu_graphs` is set, otherwise calls the
        transformer eagerly.
        """
        if self.use_hpu_graphs:
            return self.capture_replay(latent_model_input, prompt_embeds, timestep, image_rotary_emb)
        # `self.transformer.forward` is an unbound function assigned in `__init__`, so the module is
        # passed explicitly as the first positional argument.
        return self.transformer(
            self.transformer,
            hidden_states=latent_model_input,
            encoder_hidden_states=prompt_embeds,
            timestep=timestep,
            image_rotary_emb=image_rotary_emb,
            return_dict=False,
        )[0]

    @torch.no_grad()
    def capture_replay(self, latent_model_input, prompt_embeds, timestep, image_rotary_emb):
        """Run the transformer through an HPU graph cached in `self.cache`, keyed by `input_hash(inputs)`.

        On a cache miss, the transformer call is captured into a new `HPUGraph` on `self.hpu_stream`, the
        captured input/output tensors are stored alongside it, and the captured output is returned. On a cache
        hit, the new inputs are copied into the captured input tensors, the graph is replayed, the default stream
        is synchronized and the cached output tensor is returned.

        NOTE(review): on replay the returned tensor is the graph's cached output buffer, so it is presumably
        overwritten by the next replay — callers should consume it before the next step.
        """
        inputs = [latent_model_input.clone(), prompt_embeds.clone(), timestep.clone(), image_rotary_emb, False]
        h = self.ht.hpu.graphs.input_hash(inputs)
        cached = self.cache.get(h)

        if cached is None:
            # Capture the graph and cache it
            with self.ht.hpu.stream(self.hpu_stream):
                graph = self.ht.hpu.HPUGraph()
                graph.capture_begin()
                outputs = self.transformer(
                    self.transformer,
                    hidden_states=inputs[0],
                    encoder_hidden_states=inputs[1],
                    timestep=inputs[2],
                    image_rotary_emb=inputs[3],
                    return_dict=inputs[4],
                )[0]
                graph.capture_end()
                graph_inputs = inputs
                graph_outputs = outputs
                self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph)
            return outputs

        # Replay the cached graph with updated inputs
        self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs)
        cached.graph.replay()
        self.ht.core.hpu.default_stream().synchronize()

        return cached.graph_outputs