optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py (314 lines of code) (raw):

# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from dataclasses import dataclass
from math import ceil
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL
import torch
from diffusers.image_processor import PipelineImageInput
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines import StableDiffusionLDM3DPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import BaseOutput
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from optimum.utils import logging

from ....transformers.gaudi_configuration import GaudiConfig
from ....utils import speed_metrics, warmup_inference_steps_time_adjustment
from ..pipeline_utils import GaudiDiffusionPipeline
from .pipeline_stable_diffusion import GaudiStableDiffusionPipeline


logger = logging.get_logger(__name__)


@dataclass
class GaudiStableDiffusionLDM3DPipelineOutput(BaseOutput):
    """
    Output of [`GaudiStableDiffusionLDM3DPipeline`].

    Args:
        rgb (`List[PIL.Image.Image]` or `np.ndarray`):
            The generated RGB images, accumulated over all batches.
        depth (`List[PIL.Image.Image]` or `np.ndarray`):
            The generated depth maps, in the same order as `rgb`.
        throughput (`float`):
            The `generation_samples_per_second` value reported by `speed_metrics`.
        nsfw_content_detected (`List[bool]`, *optional*):
            One flag per generated image as returned by the safety checker, or `None` when no safety check
            result is available (e.g. `output_type="latent"`).
    """

    rgb: Union[List[PIL.Image.Image], np.ndarray]
    depth: Union[List[PIL.Image.Image], np.ndarray]
    throughput: float
    nsfw_content_detected: Optional[List[bool]]


class GaudiStableDiffusionLDM3DPipeline(GaudiDiffusionPipeline, StableDiffusionLDM3DPipeline):
    """
    Adapted from: https://github.com/huggingface/diffusers/blob/v0.23.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py#L84
    - Generation is performed by batches
    - Two `mark_step()` were added to add support for lazy mode
    - Added support for HPU graphs
    - Adjusted original Stable Diffusion to match with the LDM3D implementation (input and output being different)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
        image_encoder ([`~transformers.CLIPVisionModelWithProjection`], *optional*):
            Forwarded unchanged to `StableDiffusionLDM3DPipeline.__init__`.
        requires_safety_checker (bool, defaults to `True`):
            Forwarded unchanged to `StableDiffusionLDM3DPipeline.__init__`.
        use_habana (bool, defaults to `False`):
            Whether to use Gaudi (`True`) or CPU (`False`).
        use_hpu_graphs (bool, defaults to `False`):
            Whether to use HPU graphs or not.
        gaudi_config (Union[str, [`GaudiConfig`]], defaults to `None`):
            Gaudi configuration to use. Can be a string to download it from the Hub.
            Or a previously initialized config can be passed.
        bf16_full_eval (bool, defaults to `False`):
            Whether to use full bfloat16 evaluation instead of 32-bit.
            This will be faster and save memory compared to fp32/mixed precision but can harm generated images.
        sdp_on_bf16 (bool, defaults to `False`):
            Whether to allow PyTorch to use reduced precision in the SDPA math backend.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        image_encoder: Optional[CLIPVisionModelWithProjection],
        requires_safety_checker: bool = True,
        use_habana: bool = False,
        use_hpu_graphs: bool = False,
        gaudi_config: Union[str, GaudiConfig] = None,
        bf16_full_eval: bool = False,
        sdp_on_bf16: bool = False,
    ):
        """Initialize the Gaudi-specific base first, then the LDM3D pipeline, and move everything to the device."""
        GaudiDiffusionPipeline.__init__(
            self,
            use_habana,
            use_hpu_graphs,
            gaudi_config,
            bf16_full_eval,
            sdp_on_bf16,
        )

        # Workaround for Synapse 1.11 for full bf16
        if bf16_full_eval:
            unet.conv_in.float()

        StableDiffusionLDM3DPipeline.__init__(
            self,
            vae,
            text_encoder,
            tokenizer,
            unet,
            scheduler,
            safety_checker,
            feature_extractor,
            image_encoder,
            requires_safety_checker,
        )

        self.to(self._device)

    def prepare_latents(self, num_images, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """
        Create (or validate) the initial latents and scale them by the scheduler's `init_noise_sigma`.

        Args:
            num_images: Total number of latents to produce (first dimension of the result).
            num_channels_latents: Number of latent channels (second dimension of the result).
            height, width: Image size in pixels; each is divided by `self.vae_scale_factor`.
            dtype: dtype used when sampling new latents.
            device: Target device of the returned latents.
            generator: A `torch.Generator`, a list with one generator per image, or `None`.
            latents: Optional pre-generated latents; must already have the expected shape.

        Returns:
            The latents on `device`, multiplied by `self.scheduler.init_noise_sigma`.

        Raises:
            ValueError: If a list of generators does not have `num_images` entries, or if the provided
                `latents` do not have the expected shape.
        """
        shape = (num_images, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != num_images:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective number"
                f" of images of {num_images}. Make sure the number of images matches the length of the generators."
            )

        if latents is None:
            # torch.randn is broken on HPU so running it on CPU
            rand_device = "cpu" if device.type == "hpu" else device
            if isinstance(generator, list):
                # One sample per generator (the list length was checked against `num_images` above)
                single_shape = (1,) + shape[1:]
                per_image_latents = [
                    torch.randn(single_shape, generator=gen, device=rand_device, dtype=dtype) for gen in generator
                ]
                latents = torch.cat(per_image_latents, dim=0).to(device)
            else:
                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        batch_size: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            batch_size (`int`, *optional*, defaults to 1):
                The number of samples per batch. The
                `num_prompts * num_images_per_prompt` samples are processed in
                `ceil(num_prompts * num_images_per_prompt / batch_size)` batches.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*):
                Optional image input to work with IP Adapters.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`GaudiStableDiffusionLDM3DPipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            throughput_warmup_steps (`int`, *optional*, defaults to 3, passed through `**kwargs`):
                Number of initial iterations after which the throughput timer is (re)started.

        Returns:
            [`GaudiStableDiffusionLDM3DPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, a [`GaudiStableDiffusionLDM3DPipelineOutput`] is returned, otherwise a
                `tuple` `((rgb, depth), nsfw_content_detected)` is returned where `rgb` and `depth` are the lists of
                all generated images and depth maps and `nsfw_content_detected` is a list of `bool`s indicating
                whether the corresponding generated image contains "not-safe-for-work" (nsfw) content (or `None`).
        """
        with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast):
            # 0. Default height and width to unet
            height = height or self.unet.config.sample_size * self.vae_scale_factor
            width = width or self.unet.config.sample_size * self.vae_scale_factor

            # 1. Check inputs. Raise error if not correct
            self.check_inputs(
                prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
            )

            # 2. Define call parameters
            if prompt is not None and isinstance(prompt, str):
                num_prompts = 1
            elif prompt is not None and isinstance(prompt, list):
                num_prompts = len(prompt)
            else:
                num_prompts = prompt_embeds.shape[0]
            num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size)
            logger.info(
                f"{num_prompts} prompt(s) received, {num_images_per_prompt} generation(s) per prompt,"
                f" {batch_size} sample(s) per batch, {num_batches} total batch(es)."
            )
            if num_batches < 3:
                logger.warning("The first two iterations are slower so it is recommended to feed more batches.")

            device = self._execution_device
            # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
            # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
            # corresponds to doing no classifier free guidance.
            do_classifier_free_guidance = guidance_scale > 1.0

            if ip_adapter_image is not None:
                image_embeds = self.prepare_ip_adapter_image_embeds(
                    ip_adapter_image, device, batch_size * num_images_per_prompt
                )

            # 3. Encode input prompt
            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
                prompt,
                device,
                num_images_per_prompt,
                do_classifier_free_guidance,
                negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                clip_skip=clip_skip,
            )

            # 4. Prepare timesteps
            self.scheduler.set_timesteps(num_inference_steps, device="cpu")
            timesteps = self.scheduler.timesteps.to(device)
            self.scheduler.reset_timestep_dependent_params()

            # 5. Prepare latent variables
            num_channels_latents = self.unet.config.in_channels
            latents = self.prepare_latents(
                num_prompts * num_images_per_prompt,
                num_channels_latents,
                height,
                width,
                prompt_embeds.dtype,
                device,
                generator,
                latents,
            )

            # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
            extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

            # 6.1 Add image embeds for IP-Adapter
            added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

            # 7. Split into batches (HPU-specific step)
            (
                latents_batches,
                text_embeddings_batches,
                num_dummy_samples,
            ) = GaudiStableDiffusionPipeline._split_inputs_into_batches(
                batch_size,
                latents,
                prompt_embeds,
                negative_prompt_embeds,
            )

            outputs = {
                "images": [],
                "depths": [],
                "has_nsfw_concept": [],
            }
            t0 = time.time()
            t1 = t0

            # 8. Denoising loop
            throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
            # With too few batches to warm up on, the warmup is measured on inference steps instead
            use_warmup_inference_steps = (
                num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
            )

            for j in self.progress_bar(range(num_batches)):
                # The throughput is calculated from the 3rd iteration
                # because compilation occurs in the first two iterations
                if j == throughput_warmup_steps:
                    t1 = time.time()
                if use_warmup_inference_steps:
                    t0_inf = time.time()

                # Always consume the first batch, then rotate so the next one comes first
                latents_batch = latents_batches[0]
                latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
                text_embeddings_batch = text_embeddings_batches[0]
                text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0)

                for i in range(len(timesteps)):
                    if use_warmup_inference_steps and i == throughput_warmup_steps:
                        t1_inf = time.time()
                        t1 += t1_inf - t0_inf

                    # Same rotation trick for the timesteps: after a full pass they are back in order
                    timestep = timesteps[0]
                    timesteps = torch.roll(timesteps, shifts=-1, dims=0)

                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = (
                        torch.cat([latents_batch] * 2) if do_classifier_free_guidance else latents_batch
                    )
                    # latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)

                    # predict the noise residual
                    noise_pred = self.unet_hpu(
                        latent_model_input,
                        timestep,
                        text_embeddings_batch,
                        cross_attention_kwargs,
                        added_cond_kwargs,
                    )

                    # perform guidance
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents_batch = self.scheduler.step(
                        noise_pred, timestep, latents_batch, **extra_step_kwargs, return_dict=False
                    )[0]

                    if not self.use_hpu_graphs:
                        self.htcore.mark_step()

                    # call the callback, if provided
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, timestep, latents_batch)

                if use_warmup_inference_steps:
                    t1 = warmup_inference_steps_time_adjustment(
                        t1, t1_inf, num_inference_steps, throughput_warmup_steps
                    )

                if output_type != "latent":
                    # 8. Post-processing
                    image = self.vae.decode(latents_batch / self.vae.config.scaling_factor, return_dict=False)[0]
                else:
                    image = latents_batch
                outputs["images"].append(image)

                if not self.use_hpu_graphs:
                    self.htcore.mark_step()

            speed_metrics_prefix = "generation"
            speed_measures = speed_metrics(
                split=speed_metrics_prefix,
                start_time=t0,
                num_samples=num_batches * batch_size
                if t1 == t0 or use_warmup_inference_steps
                else (num_batches - throughput_warmup_steps) * batch_size,
                num_steps=num_batches * batch_size * num_inference_steps,
                start_time_after_warmup=t1,
            )
            logger.info(f"Speed metrics: {speed_measures}")

            # Remove dummy generations if needed
            if num_dummy_samples > 0:
                outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples]

            # Process generated images: the raw per-batch tensors are replaced by the post-processed
            # per-image results, accumulated over all batches
            raw_batches = outputs["images"]
            outputs["images"] = []
            for image in raw_batches:
                if output_type == "latent":
                    has_nsfw_concept = None
                else:
                    image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

                if has_nsfw_concept is None:
                    do_denormalize = [True] * image.shape[0]
                else:
                    do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

                rgb, depth = self.image_processor.postprocess(
                    image, output_type=output_type, do_denormalize=do_denormalize
                )

                if output_type == "pil":
                    outputs["images"] += rgb
                    outputs["depths"] += depth
                else:
                    outputs["images"] += [*rgb]
                    outputs["depths"] += [*depth]

                if has_nsfw_concept is not None:
                    outputs["has_nsfw_concept"] += has_nsfw_concept
                else:
                    outputs["has_nsfw_concept"] = None

            # Offload all models
            self.maybe_free_model_hooks()

            # Both return paths expose the results accumulated over *all* batches. Returning the loop
            # variables (`rgb`, `depth`, `has_nsfw_concept`) would only cover the last batch.
            if not return_dict:
                return ((outputs["images"], outputs["depths"]), outputs["has_nsfw_concept"])

            return GaudiStableDiffusionLDM3DPipelineOutput(
                rgb=outputs["images"],
                depth=outputs["depths"],
                nsfw_content_detected=outputs["has_nsfw_concept"],
                throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"],
            )

    @torch.no_grad()
    def unet_hpu(self, latent_model_input, timestep, encoder_hidden_states, cross_attention_kwargs, added_cond_kwargs):
        """
        Run the UNet, through a cached HPU graph when `self.use_hpu_graphs` is set, eagerly otherwise.

        Returns:
            The first element of the UNet output.
        """
        if self.use_hpu_graphs:
            # NOTE(review): `cross_attention_kwargs` and `added_cond_kwargs` are not forwarded on this path,
            # so they have no effect when HPU graphs are enabled.
            return self.capture_replay(latent_model_input, timestep, encoder_hidden_states)
        else:
            return self.unet(
                latent_model_input,
                timestep,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

    @torch.no_grad()
    def capture_replay(self, latent_model_input, timestep, encoder_hidden_states):
        """
        Run the UNet through an HPU graph, capturing it on first use and replaying it afterwards.

        Graphs are cached in `self.cache` under the hash computed by `input_hash` for the inputs.

        Returns:
            The first element of the UNet output (the tensor owned by the cached graph on replay).
        """
        inputs = [latent_model_input, timestep, encoder_hidden_states, False]
        h = self.ht.hpu.graphs.input_hash(inputs)
        cached = self.cache.get(h)

        if cached is None:
            # Capture the graph and cache it
            with self.ht.hpu.stream(self.hpu_stream):
                graph = self.ht.hpu.HPUGraph()
                graph.capture_begin()
                # `inputs[3]` is the `return_dict` flag: pass it by keyword, since the 4th positional
                # parameter of the UNet forward is `class_labels`, not `return_dict`.
                outputs = self.unet(inputs[0], inputs[1], inputs[2], return_dict=inputs[3])[0]
                graph.capture_end()
                graph_inputs = inputs
                graph_outputs = outputs
                self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph)
            return outputs

        # Replay the cached graph with updated inputs
        self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs)
        cached.graph.replay()
        self.ht.core.hpu.default_stream().synchronize()

        return cached.graph_outputs