optimum/habana/diffusers/pipelines/controlnet/pipeline_controlnet.py (581 lines of code) (raw):

# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from math import ceil
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    MultiControlNetModel,
    StableDiffusionControlNetPipeline,
    UNet2DConditionModel,
)
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import deprecate
from diffusers.utils.torch_utils import is_compiled_module
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from optimum.utils import logging

from ....transformers.gaudi_configuration import GaudiConfig
from ....utils import HabanaProfile, speed_metrics, warmup_inference_steps_time_adjustment
from ..pipeline_utils import GaudiDiffusionPipeline
from ..stable_diffusion.pipeline_stable_diffusion import (
    GaudiStableDiffusionPipeline,
    GaudiStableDiffusionPipelineOutput,
    retrieve_timesteps,
)


logger = logging.get_logger(__name__)


class GaudiStableDiffusionControlNetPipeline(GaudiDiffusionPipeline, StableDiffusionControlNetPipeline):
    """
    Adapted from: https://github.com/huggingface/diffusers/blob/v0.23.1/src/diffusers/pipelines/controlnet/pipeline_controlnet.py#L94
        - Generation is performed by batches
        - Two `mark_step()` were added to add support for lazy mode
        - Added support for HPU graphs

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer (`~transformers.CLIPTokenizer`):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
            Provides additional conditioning to the `unet` during the denoising process. If you set multiple
            ControlNets as a list, the outputs from each ControlNet are added together to create one combined
            additional conditioning.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
        image_encoder ([`~transformers.CLIPVisionModelWithProjection`], *optional*):
            Optional image encoder forwarded to `StableDiffusionControlNetPipeline`; presumably only needed when
            IP-Adapter images (`ip_adapter_image`) are used.
        requires_safety_checker (bool, defaults to `True`):
            Forwarded to `StableDiffusionControlNetPipeline`.
        use_habana (bool, defaults to `False`):
            Whether to use Gaudi (`True`) or CPU (`False`).
        use_hpu_graphs (bool, defaults to `False`):
            Whether to use HPU graphs or not.
        gaudi_config (Union[str, [`GaudiConfig`]], defaults to `None`):
            Gaudi configuration to use. Can be a string to download it from the Hub.
            Or a previously initialized config can be passed.
        bf16_full_eval (bool, defaults to `False`):
            Whether to use full bfloat16 evaluation instead of 32-bit.
            This will be faster and save memory compared to fp32/mixed precision but can harm generated images.
        sdp_on_bf16 (bool, defaults to `False`):
            Whether to allow PyTorch to use reduced precision in the SDPA math backend.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]],
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection = None,
        requires_safety_checker: bool = True,
        use_habana: bool = False,
        use_hpu_graphs: bool = False,
        gaudi_config: Union[str, GaudiConfig] = None,
        bf16_full_eval: bool = False,
        sdp_on_bf16: bool = False,
    ):
        # Gaudi-specific setup first (device, HPU graphs, Gaudi config), then the upstream pipeline modules.
        GaudiDiffusionPipeline.__init__(
            self,
            use_habana,
            use_hpu_graphs,
            gaudi_config,
            bf16_full_eval,
            sdp_on_bf16,
        )

        StableDiffusionControlNetPipeline.__init__(
            self,
            vae,
            text_encoder,
            tokenizer,
            unet,
            controlnet,
            scheduler,
            safety_checker,
            feature_extractor,
            image_encoder,
            requires_safety_checker,
        )

        self.to(self._device)

    def prepare_latents(self, num_images, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """
        Create (or validate) the initial noisy latents for `num_images` images.

        Args:
            num_images (`int`): Total number of images (prompts * images per prompt).
            num_channels_latents (`int`): Number of latent channels expected by the UNet.
            height, width (`int`): Output image size in pixels; divided by `self.vae_scale_factor`.
            dtype (`torch.dtype`): Dtype of the returned latents.
            device (`torch.device`): Device on which the latents are returned.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*): RNG(s); a list must contain
                exactly `num_images` generators (one per image).
            latents (`torch.Tensor`, *optional*): Pre-generated latents; must have the expected shape.

        Returns:
            `torch.Tensor`: latents scaled by `self.scheduler.init_noise_sigma`.

        Raises:
            ValueError: if the generator list length or the provided latents shape does not match.
        """
        shape = (num_images, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != num_images:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective number"
                f" of images of {num_images}. Make sure the number of images matches the length of the generators."
            )

        if latents is None:
            # torch.randn is broken on HPU so running it on CPU
            rand_device = "cpu" if device.type == "hpu" else device
            if isinstance(generator, list):
                # One sample per generator, then concatenated along the batch dimension.
                shape = (1,) + shape[1:]
                latents = [
                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
                    for i in range(num_images)
                ]
                latents = torch.cat(latents, dim=0).to(device)
            else:
                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        batch_size: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        profiling_warmup_steps: Optional[int] = 0,
        profiling_steps: Optional[int] = 0,
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                    `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
                specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
                accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
                and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
                `init`, images must be passed as a list such that each element of the list can be correctly batched for
                input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single
                ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple
                ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            batch_size (`int`, *optional*, defaults to 1):
                The number of images in a batch.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.GaudiStableDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
                the corresponding scale as a list.
            guess_mode (`bool`, *optional*, defaults to `False`):
                The ControlNet encoder tries to recognize the content of the input image even if you remove all
                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
                The percentage of total steps at which the ControlNet starts applying.
            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
                The percentage of total steps at which the ControlNet stops applying.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`. On Gaudi, generation runs batch by batch, so `latents` and
                `prompt_embeds` refer to the batch currently being denoised; returning them from the callback replaces
                that batch's tensors for the following steps.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            profiling_warmup_steps (`int`, *optional*):
                Number of steps to ignore for profiling.
            profiling_steps (`int`, *optional*):
                Number of steps to be captured when enabling profiling.

        Returns:
            [`~diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.GaudiStableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`GaudiStableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.

        Raises:
            ValueError: if `self.controlnet` is neither a `ControlNetModel` nor a `MultiControlNetModel`.
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

        # align format for control guidance: both start and end become lists of equal length
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = (
                mult * [control_guidance_start],
                mult * [control_guidance_end],
            )

        with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast):
            # 1. Check inputs. Raise error if not correct
            self.check_inputs(
                prompt=prompt,
                image=image,
                callback_steps=callback_steps,
                negative_prompt=negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
                control_guidance_start=control_guidance_start,
                control_guidance_end=control_guidance_end,
                callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            )

            self._guidance_scale = guidance_scale
            self._clip_skip = clip_skip
            self._cross_attention_kwargs = cross_attention_kwargs

            # 2. Define call parameters
            if prompt is not None and isinstance(prompt, str):
                num_prompts = 1
            elif prompt is not None and isinstance(prompt, list):
                num_prompts = len(prompt)
            else:
                num_prompts = prompt_embeds.shape[0]
            num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size)

            device = self._execution_device
            # Classifier-free guidance is decided by the inherited `self.do_classifier_free_guidance` property,
            # which reads `self._guidance_scale` set above.

            if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
                controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

            global_pool_conditions = (
                controlnet.config.global_pool_conditions
                if isinstance(controlnet, ControlNetModel)
                else controlnet.nets[0].config.global_pool_conditions
            )
            guess_mode = guess_mode or global_pool_conditions

            # 3. Encode input prompt
            text_encoder_lora_scale = (
                self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
            )
            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
                prompt,
                device,
                num_images_per_prompt,
                self.do_classifier_free_guidance,
                negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                lora_scale=text_encoder_lora_scale,
                clip_skip=self.clip_skip,
            )
            # Unlike upstream diffusers, the unconditional and text embeddings are NOT concatenated here: both are
            # handed to `_split_inputs_into_batches` (step 7.3), which presumably builds per-batch
            # [unconditional; conditional] embeddings (see the `chunk(2)[1]` use in the loop) — TODO confirm.

            if ip_adapter_image is not None:
                image_embeds = self.prepare_ip_adapter_image_embeds(
                    ip_adapter_image, device, batch_size * num_images_per_prompt
                )

            # 4. Prepare image
            if isinstance(controlnet, ControlNetModel):
                image = self.prepare_image(
                    image=image,
                    width=width,
                    height=height,
                    batch_size=batch_size,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )
                height, width = image.shape[-2:]
            elif isinstance(controlnet, MultiControlNetModel):
                images = []

                # Nested lists as ControlNet condition
                if isinstance(image[0], list):
                    # Transpose the nested image list
                    image = [list(t) for t in zip(*image)]

                for image_ in image:
                    image_ = self.prepare_image(
                        image=image_,
                        width=width,
                        height=height,
                        batch_size=batch_size,
                        num_images_per_prompt=num_images_per_prompt,
                        device=device,
                        dtype=controlnet.dtype,
                        do_classifier_free_guidance=self.do_classifier_free_guidance,
                        guess_mode=guess_mode,
                    )

                    images.append(image_)

                image = images
                height, width = image[0].shape[-2:]
            else:
                # Explicit error instead of `assert False`, which is stripped under `python -O` and would let the
                # pipeline continue with an unprepared control image.
                raise ValueError(
                    f"Unsupported controlnet type {type(controlnet).__name__}; expected ControlNetModel or MultiControlNetModel."
                )

            # 5. Prepare timesteps
            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
            self._num_timesteps = len(timesteps)

            # 6. Prepare latent variables
            num_channels_latents = self.unet.config.in_channels
            latents = self.prepare_latents(
                num_prompts * num_images_per_prompt,
                num_channels_latents,
                height,
                width,
                prompt_embeds.dtype,
                device,
                generator,
                latents,
            )

            # 6.5 Optionally get Guidance Scale Embedding
            timestep_cond = None
            if self.unet.config.time_cond_proj_dim is not None:
                guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
                    batch_size * num_images_per_prompt
                )
                timestep_cond = self.get_guidance_scale_embedding(
                    guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
                ).to(device=device, dtype=latents.dtype)

            # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
            extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

            # 7.1 Add image embeds for IP-Adapter
            added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

            # 7.2 Create tensor stating which controlnets to keep (1.0 inside [start, end], 0.0 outside)
            controlnet_keep = []
            for i in range(len(timesteps)):
                keeps = [
                    1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                    for s, e in zip(control_guidance_start, control_guidance_end)
                ]
                controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

            # 7.3 Split into batches (HPU-specific step)
            (
                latents_batches,
                text_embeddings_batches,
                num_dummy_samples,
            ) = GaudiStableDiffusionPipeline._split_inputs_into_batches(
                batch_size,
                latents,
                prompt_embeds,
                negative_prompt_embeds,
            )

            outputs = {
                "images": [],
                "has_nsfw_concept": [],
            }
            t0 = time.time()
            t1 = t0

            hb_profiler = HabanaProfile(
                warmup=profiling_warmup_steps,
                active=profiling_steps,
                record_shapes=False,
            )
            hb_profiler.start()

            # 8. Denoising loop
            throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
            use_warmup_inference_steps = (
                num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
            )
            # Loop-invariant; only used to decide when the deprecated `callback` fires.
            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

            for j in self.progress_bar(range(num_batches)):
                # The throughput is calculated from the 3rd iteration
                # because compilation occurs in the first two iterations
                if j == throughput_warmup_steps:
                    t1 = time.time()
                if use_warmup_inference_steps:
                    t0_inf = time.time()

                # Take the next batch and rotate the stacks so the following one is at index 0
                latents_batch = latents_batches[0]
                latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
                text_embeddings_batch = text_embeddings_batches[0]
                text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0)

                for i in range(num_inference_steps):
                    if use_warmup_inference_steps and i == throughput_warmup_steps:
                        t1_inf = time.time()
                        t1 += t1_inf - t0_inf

                    # Rotate timesteps instead of indexing them (HPU-friendly static access pattern)
                    t = timesteps[0]
                    timesteps = torch.roll(timesteps, shifts=-1, dims=0)

                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = (
                        torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch
                    )
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    # controlnet(s) inference
                    if guess_mode and self.do_classifier_free_guidance:
                        # Infer ControlNet only for the conditional batch.
                        control_model_input = latents_batch
                        control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                        controlnet_prompt_embeds = text_embeddings_batch.chunk(2)[1]
                    else:
                        control_model_input = latent_model_input
                        controlnet_prompt_embeds = text_embeddings_batch

                    if isinstance(controlnet_keep[i], list):
                        cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                    else:
                        controlnet_cond_scale = controlnet_conditioning_scale
                        if isinstance(controlnet_cond_scale, list):
                            controlnet_cond_scale = controlnet_cond_scale[0]
                        cond_scale = controlnet_cond_scale * controlnet_keep[i]

                    down_block_res_samples, mid_block_res_sample = self.controlnet_hpu(
                        control_model_input,
                        t,
                        controlnet_prompt_embeds,
                        image,
                        cond_scale,
                        guess_mode,
                    )

                    if guess_mode and self.do_classifier_free_guidance:
                        # Inferred ControlNet only for the conditional batch.
                        # To apply the output of ControlNet to both the unconditional and conditional batches,
                        # add 0 to the unconditional batch to keep it unchanged.
                        down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                        mid_block_res_sample = torch.cat(
                            [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]
                        )

                    # predict the noise residual
                    noise_pred = self.unet_hpu(
                        latent_model_input,
                        t,
                        text_embeddings_batch,
                        timestep_cond,
                        self.cross_attention_kwargs,
                        down_block_res_samples,
                        mid_block_res_sample,
                        added_cond_kwargs,
                    )

                    # perform guidance
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents_batch = self.scheduler.step(
                        noise_pred, t, latents_batch, **extra_step_kwargs, return_dict=False
                    )[0]

                    if not self.use_hpu_graphs:
                        self.htcore.mark_step()

                    if callback_on_step_end is not None:
                        # Expose the tensors of the batch currently being denoised under their public names, so the
                        # callback sees (and can replace) the actual per-step state rather than the pre-batch
                        # `latents` / `prompt_embeds` tensors.
                        batch_tensors = {"latents": latents_batch, "prompt_embeds": text_embeddings_batch}
                        step_locals = locals()
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = batch_tensors[k] if k in batch_tensors else step_locals[k]
                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                        # Feed the callback's results back into the denoising state of this batch.
                        latents_batch = callback_outputs.pop("latents", latents_batch)
                        text_embeddings_batch = callback_outputs.pop("prompt_embeds", text_embeddings_batch)
                        # negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                    # call the callback, if provided
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        if callback is not None and i % callback_steps == 0:
                            step_idx = i // getattr(self.scheduler, "order", 1)
                            callback(step_idx, t, latents_batch)

                    hb_profiler.step()

                if use_warmup_inference_steps:
                    t1 = warmup_inference_steps_time_adjustment(
                        t1, t1_inf, num_inference_steps, throughput_warmup_steps
                    )

                if not output_type == "latent":
                    # 8. Post-processing
                    output_image = self.vae.decode(
                        latents_batch / self.vae.config.scaling_factor, return_dict=False, generator=generator
                    )[0]
                else:
                    output_image = latents_batch
                outputs["images"].append(output_image)

                if not self.use_hpu_graphs:
                    self.htcore.mark_step()

            hb_profiler.stop()

            speed_metrics_prefix = "generation"
            speed_measures = speed_metrics(
                split=speed_metrics_prefix,
                start_time=t0,
                num_samples=num_batches * batch_size
                if t1 == t0 or use_warmup_inference_steps
                else (num_batches - throughput_warmup_steps) * batch_size,
                num_steps=num_batches * batch_size * num_inference_steps,
                start_time_after_warmup=t1,
            )
            logger.info(f"Speed metrics: {speed_measures}")

            # Remove dummy generations if needed (padding added to fill the last batch)
            if num_dummy_samples > 0:
                outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples]

            # Process generated images: iterate over a copy because the list is rebuilt in place
            for i, image in enumerate(outputs["images"][:]):
                if i == 0:
                    outputs["images"].clear()

                if output_type == "latent":
                    has_nsfw_concept = None
                else:
                    image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

                if has_nsfw_concept is None:
                    do_denormalize = [True] * image.shape[0]
                else:
                    do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

                image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

                if output_type == "pil":
                    outputs["images"] += image
                else:
                    outputs["images"] += [*image]

                if has_nsfw_concept is not None:
                    outputs["has_nsfw_concept"] += has_nsfw_concept
                else:
                    outputs["has_nsfw_concept"] = None

            # Offload all models
            self.maybe_free_model_hooks()

            if not return_dict:
                return (outputs["images"], outputs["has_nsfw_concept"])

            return GaudiStableDiffusionPipelineOutput(
                images=outputs["images"],
                nsfw_content_detected=outputs["has_nsfw_concept"],
                throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"],
            )

    @torch.no_grad()
    def unet_hpu(
        self,
        latent_model_input,
        timestep,
        encoder_hidden_states,
        timestep_cond,
        cross_attention_kwargs,
        down_block_additional_residuals,
        mid_block_additional_residual,
        added_cond_kwargs,
    ):
        """
        Run the UNet, either eagerly or through a cached HPU graph when `self.use_hpu_graphs` is set.

        NOTE: the HPU-graph path does not forward `timestep_cond`, `cross_attention_kwargs` or
        `added_cond_kwargs` (see `unet_capture_replay`), so those inputs only take effect in eager mode.

        Returns:
            The first element of the UNet output (the predicted noise).
        """
        if self.use_hpu_graphs:
            return self.unet_capture_replay(
                latent_model_input,
                timestep,
                encoder_hidden_states,
                down_block_additional_residuals,
                mid_block_additional_residual,
            )
        else:
            return self.unet(
                latent_model_input,
                timestep,
                encoder_hidden_states=encoder_hidden_states,
                timestep_cond=timestep_cond,
                cross_attention_kwargs=cross_attention_kwargs,
                down_block_additional_residuals=down_block_additional_residuals,
                mid_block_additional_residual=mid_block_additional_residual,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

    @torch.no_grad()
    def unet_capture_replay(
        self,
        latent_model_input,
        timestep,
        encoder_hidden_states,
        down_block_additional_residuals,
        mid_block_additional_residual,
    ):
        """
        Capture the UNet forward in an HPU graph on first use for a given input hash, then replay it.

        The graph is cached in `self.cache` keyed by `self.ht.hpu.graphs.input_hash(inputs)`; on a cache hit the
        new inputs are copied into the captured input buffers and the graph is replayed.
        """
        inputs = [
            latent_model_input,
            timestep,
            encoder_hidden_states,
            down_block_additional_residuals,
            mid_block_additional_residual,
            False,  # return_dict; part of the hashed inputs
        ]
        h = self.ht.hpu.graphs.input_hash(inputs)
        cached = self.cache.get(h)

        if cached is None:
            # Capture the graph and cache it
            with self.ht.hpu.stream(self.hpu_stream):
                graph = self.ht.hpu.HPUGraph()
                graph.capture_begin()
                # Keyword arguments instead of a long positional list padded with `None`s, so the call does not
                # depend on the exact parameter order of `UNet2DConditionModel.forward`.
                outputs = self.unet(
                    inputs[0],
                    inputs[1],
                    encoder_hidden_states=inputs[2],
                    down_block_additional_residuals=inputs[3],
                    mid_block_additional_residual=inputs[4],
                    return_dict=inputs[5],
                )[0]
                graph.capture_end()
                graph_inputs = inputs
                graph_outputs = outputs
                self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph)
            return outputs

        # Replay the cached graph with updated inputs
        self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs)
        cached.graph.replay()
        self.ht.core.hpu.default_stream().synchronize()

        return cached.graph_outputs

    @torch.no_grad()
    def controlnet_hpu(
        self,
        control_model_input,
        timestep,
        encoder_hidden_states,
        controlnet_cond,
        conditioning_scale,
        guess_mode,
    ):
        """
        Run the ControlNet(s), either eagerly or through a cached HPU graph when `self.use_hpu_graphs` is set.

        Returns:
            The ControlNet output tuple (`return_dict=False`), unpacked by the caller as
            `(down_block_res_samples, mid_block_res_sample)`.
        """
        if self.use_hpu_graphs:
            return self.controlnet_capture_replay(
                control_model_input,
                timestep,
                encoder_hidden_states,
                controlnet_cond,
                conditioning_scale,
                guess_mode,
            )
        else:
            return self.controlnet(
                control_model_input,
                timestep,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=controlnet_cond,
                conditioning_scale=conditioning_scale,
                guess_mode=guess_mode,
                return_dict=False,
            )

    @torch.no_grad()
    def controlnet_capture_replay(
        self,
        control_model_input,
        timestep,
        encoder_hidden_states,
        controlnet_cond,
        conditioning_scale,
        guess_mode,
    ):
        """
        Capture the ControlNet forward in an HPU graph on first use for a given input hash, then replay it.

        The graph is cached in `self.cache` keyed by `self.ht.hpu.graphs.input_hash(inputs)`; on a cache hit the
        new inputs are copied into the captured input buffers and the graph is replayed.
        """
        inputs = [
            control_model_input,
            timestep,
            encoder_hidden_states,
            controlnet_cond,
            conditioning_scale,
            guess_mode,
            False,  # return_dict; part of the hashed inputs
        ]
        h = self.ht.hpu.graphs.input_hash(inputs)
        cached = self.cache.get(h)

        if cached is None:
            # Capture the graph and cache it
            with self.ht.hpu.stream(self.hpu_stream):
                graph = self.ht.hpu.HPUGraph()
                graph.capture_begin()
                # Keyword arguments instead of a long positional list padded with `None`s, so the call does not
                # depend on the exact parameter order of the ControlNet `forward`.
                outputs = self.controlnet(
                    inputs[0],
                    inputs[1],
                    encoder_hidden_states=inputs[2],
                    controlnet_cond=inputs[3],
                    conditioning_scale=inputs[4],
                    guess_mode=inputs[5],
                    return_dict=inputs[6],
                )
                graph.capture_end()
                graph_inputs = inputs
                graph_outputs = outputs
                self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph)
            return outputs

        # Replay the cached graph with updated inputs
        self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs)
        cached.graph.replay()
        self.ht.core.hpu.default_stream().synchronize()

        return cached.graph_outputs