# Source: optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py (475 lines of code) (raw)

# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import time
from math import ceil
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from diffusers.image_processor import PipelineImageInput
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    StableDiffusionImg2ImgPipeline,
    retrieve_latents,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import deprecate
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from optimum.utils import logging

from ....transformers.gaudi_configuration import GaudiConfig
from ....utils import HabanaProfile, speed_metrics, warmup_inference_steps_time_adjustment
from ..pipeline_utils import GaudiDiffusionPipeline


logger = logging.get_logger(__name__)


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        # The schedule is always built on CPU first and only then moved to `device`.
        scheduler.set_timesteps(timesteps=timesteps, device="cpu", **kwargs)
        timesteps = scheduler.timesteps.to(device)
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device="cpu", **kwargs)
        timesteps = scheduler.timesteps.to(device)

    # Handles the case where the scheduler cannot implement reset_timestep_dependent_params()
    # Example: UniPCMultiStepScheduler used for inference in ControlNet training as it has non-linear accesses to
    # timestep dependent parameter: sigma.
    if hasattr(scheduler, "reset_timestep_dependent_params") and callable(scheduler.reset_timestep_dependent_params):
        scheduler.reset_timestep_dependent_params()
    return timesteps, num_inference_steps


class GaudiStableDiffusionImg2ImgPipeline(GaudiDiffusionPipeline, StableDiffusionImg2ImgPipeline):
    """
    Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L161
    Changes:
        1. Use CPU to generate random tensor

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer (`~transformers.CLIPTokenizer`):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
        image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
            Pre-trained CLIP vision model used to obtain image features.
        use_habana (bool, defaults to `False`):
            Whether to use Gaudi (`True`) or CPU (`False`).
        use_hpu_graphs (bool, defaults to `False`):
            Whether to use HPU graphs or not.
        gaudi_config (Union[str, [`GaudiConfig`]], defaults to `None`):
            Gaudi configuration to use. Can be a string to download it from the Hub.
            Or a previously initialized config can be passed.
        bf16_full_eval (bool, defaults to `False`):
            Whether to use full bfloat16 evaluation instead of 32-bit.
            This will be faster and save memory compared to fp32/mixed precision but can harm generated images.
        sdp_on_bf16 (bool, defaults to `False`):
            Whether to allow PyTorch to use reduced precision in the SDPA math backend.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection = None,
        requires_safety_checker: bool = True,
        use_habana: bool = False,
        use_hpu_graphs: bool = False,
        gaudi_config: Union[str, GaudiConfig] = None,
        bf16_full_eval: bool = False,
        sdp_on_bf16: bool = False,
    ):
        # Both parents are initialised explicitly because they take different argument lists.
        GaudiDiffusionPipeline.__init__(
            self,
            use_habana,
            use_hpu_graphs,
            gaudi_config,
            bf16_full_eval,
            sdp_on_bf16,
        )

        StableDiffusionImg2ImgPipeline.__init__(
            self,
            vae,
            text_encoder,
            tokenizer,
            unet,
            scheduler,
            safety_checker,
            feature_extractor,
            image_encoder,
            requires_safety_checker,
        )

        self.to(self._device)

    # Copied from ./pipeline_stable_diffusion.py
    @classmethod
    def _split_inputs_into_batches(cls, batch_size, latents, prompt_embeds, negative_prompt_embeds):
        """
        Split latents and prompt embeddings into equally sized batches stacked along a new leading dimension.

        Args:
            batch_size (`int`):
                Number of samples per batch.
            latents (`torch.Tensor`):
                Latents to split along dimension 0.
            prompt_embeds (`torch.Tensor`):
                Prompt embeddings to split along dimension 0.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Negative prompt embeddings. When given, each batch of prompt embeddings is prefixed with the
                matching batch of negative embeddings.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor, int]`: the stacked latents batches, the stacked prompt-embeddings
            batches and the number of zero-filled dummy samples appended to the last batch.
        """
        # Use torch.split to generate num_batches batches of size batch_size
        latents_batches = list(torch.split(latents, batch_size))
        prompt_embeds_batches = list(torch.split(prompt_embeds, batch_size))
        if negative_prompt_embeds is not None:
            negative_prompt_embeds_batches = list(torch.split(negative_prompt_embeds, batch_size))

        # If the last batch has less samples than batch_size, pad it with dummy samples
        num_dummy_samples = 0
        if latents_batches[-1].shape[0] < batch_size:
            num_dummy_samples = batch_size - latents_batches[-1].shape[0]
            # Pad latents_batches
            sequence_to_stack = (latents_batches[-1],) + tuple(
                torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
            )
            latents_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad prompt_embeds_batches
            sequence_to_stack = (prompt_embeds_batches[-1],) + tuple(
                torch.zeros_like(prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
            )
            prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad negative_prompt_embeds_batches if necessary
            if negative_prompt_embeds is not None:
                sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple(
                    torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
                )
                negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)

        # Stack batches in the same tensor
        latents_batches = torch.stack(latents_batches)
        if negative_prompt_embeds is not None:
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            for i, (negative_prompt_embeds_batch, prompt_embeds_batch) in enumerate(
                zip(negative_prompt_embeds_batches, prompt_embeds_batches[:])
            ):
                prompt_embeds_batches[i] = torch.cat([negative_prompt_embeds_batch, prompt_embeds_batch])
        prompt_embeds_batches = torch.stack(prompt_embeds_batches)

        return latents_batches, prompt_embeds_batches, num_dummy_samples

    def prepare_latents(self, image, timestep, num_prompts, num_images_per_prompt, dtype, device, generator=None):
        """
        Build the noised initial latents from `image`.

        `image` is used directly as latents when its dimension 1 has size 4, otherwise it is encoded with the VAE and
        multiplied by `vae.config.scaling_factor`. Noise drawn with `torch.randn` is then added through
        `scheduler.add_noise` at `timestep`.

        Args:
            image: Input passed the `isinstance` check below and then moved with `.to(device=..., dtype=...)`.
            timestep: Timestep(s) forwarded to `scheduler.add_noise`.
            num_prompts (`int`): Number of prompts.
            num_images_per_prompt (`int`): Number of images to generate per prompt.
            dtype (`torch.dtype`): dtype used for the image and the noise.
            device (`torch.device`): Target device; its `.type` is read to decide where noise is generated.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*): Random generator(s).

        Returns:
            `torch.Tensor`: The noised latents.

        Raises:
            ValueError: If `image` has an unsupported type, if a list of generators does not match the effective
                batch size, or if the images cannot be evenly duplicated to the effective batch size.
        """
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        batch_size = num_prompts * num_images_per_prompt

        if image.shape[1] == 4:
            # `image` already has 4 channels: treat it as latents and skip VAE encoding.
            init_latents = image
        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )
            elif isinstance(generator, list):
                init_latents = [
                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                    for i in range(batch_size)
                ]
                init_latents = torch.cat(init_latents, dim=0)
            else:
                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

            init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        # Reuse first generator for noise
        if isinstance(generator, list):
            generator = generator[0]

        shape = init_latents.shape
        # HPU Patch: random tensors are generated on CPU when targeting HPU, then moved to the device.
        rand_device = "cpu" if device.type == "hpu" else device
        noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype)
        noise = noise.to(device)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        timesteps: List[int] = None,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        batch_size: int = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: int = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        profiling_warmup_steps: Optional[int] = 0,
        profiling_steps: Optional[int] = 0,
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
                numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
                or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
                latents as `image`, but if passing latents directly it is not encoded again.
            strength (`float`, *optional*, defaults to 0.8):
                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
                essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter is modulated by `strength`.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            batch_size (`int`, *optional*, defaults to 1):
                The number of images in a batch.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            profiling_warmup_steps (`int`, *optional*):
                Number of steps to ignore for profiling.
            profiling_steps (`int`, *optional*):
                Number of steps to be captured when enabling profiling.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )

        with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast):
            # 1. Check inputs. Raise error if not correct
            self.check_inputs(
                prompt,
                strength,
                callback_steps,
                negative_prompt,
                prompt_embeds,
                negative_prompt_embeds,
                callback_on_step_end_tensor_inputs,
            )

            self._guidance_scale = guidance_scale
            self._clip_skip = clip_skip
            self._cross_attention_kwargs = cross_attention_kwargs
            self._interrupt = False

            # 2. Define call parameters
            if prompt is not None and isinstance(prompt, str):
                num_prompts = 1
            elif prompt is not None and isinstance(prompt, list):
                num_prompts = len(prompt)
            else:
                num_prompts = prompt_embeds.shape[0]
            num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size)
            logger.info(
                f"{num_prompts} prompt(s) received, {num_images_per_prompt} generation(s) per prompt,"
                f" {batch_size} sample(s) per batch, {num_batches} total batch(es)."
            )
            if num_batches < 3:
                logger.warning("The first two iterations are slower so it is recommended to feed more batches.")

            device = self._execution_device

            # 3. Encode input prompt
            text_encoder_lora_scale = (
                self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
            )
            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
                prompt,
                device,
                num_images_per_prompt,
                self.do_classifier_free_guidance,
                negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                lora_scale=text_encoder_lora_scale,
                clip_skip=self.clip_skip,
            )

            if ip_adapter_image is not None:
                image_embeds = self.prepare_ip_adapter_image_embeds(
                    ip_adapter_image, device, batch_size * num_images_per_prompt
                )

            # 4. Preprocess image
            image = self.image_processor.preprocess(image)

            # 5. set timesteps
            # NOTE(review): `strength` is validated by `check_inputs` above but is not otherwise used in this
            # method, so the schedule is not shortened by it here - confirm whether this is intentional.
            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
            latent_timestep = timesteps[:1].repeat(num_prompts * num_images_per_prompt)

            # 6. Prepare latent variables
            latents = self.prepare_latents(
                image,
                latent_timestep,
                num_prompts,
                num_images_per_prompt,
                prompt_embeds.dtype,
                device,
                generator,
            )

            # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
            extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

            # 7.1 Add image embeds for IP-Adapter
            added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

            # 7.2 Optionally get Guidance Scale Embedding
            timestep_cond = None
            if self.unet.config.time_cond_proj_dim is not None:
                guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
                    batch_size * num_images_per_prompt
                )
                timestep_cond = self.get_guidance_scale_embedding(
                    guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
                ).to(device=device, dtype=latents.dtype)

            # 8. Split into batches (HPU-specific step)
            latents_batches, text_embeddings_batches, num_dummy_samples = self._split_inputs_into_batches(
                batch_size,
                latents,
                prompt_embeds,
                negative_prompt_embeds,
            )

            outputs = {
                "images": [],
                "has_nsfw_concept": [],
            }
            t0 = time.time()
            t1 = t0

            hb_profiler = HabanaProfile(
                warmup=profiling_warmup_steps,
                active=profiling_steps,
                record_shapes=False,
            )
            hb_profiler.start()

            # 9. Denoising loop
            throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
            use_warmup_inference_steps = num_batches < throughput_warmup_steps < num_inference_steps
            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
            self._num_timesteps = len(timesteps)
            for j in self.progress_bar(range(num_batches)):
                # The throughput is calculated from the 3rd iteration
                # because compilation occurs in the first two iterations
                if j == throughput_warmup_steps:
                    t1 = time.time()
                if use_warmup_inference_steps:
                    t0_inf = time.time()

                # Take the batch at the front and rotate so the next one comes first on the following iteration.
                latents_batch = latents_batches[0]
                latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
                text_embeddings_batch = text_embeddings_batches[0]
                text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0)

                for i in range(num_inference_steps):  # HPU Patch
                    if use_warmup_inference_steps and i == throughput_warmup_steps:
                        t1_inf = time.time()
                        t1 += t1_inf - t0_inf

                    # HPU Patch: read the current timestep from the front and rotate the schedule.
                    t = timesteps[0]
                    timesteps = torch.roll(timesteps, shifts=-1, dims=0)

                    if self.interrupt:
                        continue

                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = (
                        torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch
                    )
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    # predict the noise residual
                    noise_pred = self.unet_hpu(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings_batch,
                        timestep_cond=timestep_cond,
                        cross_attention_kwargs=self.cross_attention_kwargs,
                        added_cond_kwargs=added_cond_kwargs,
                    )

                    # perform guidance
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents_batch = self.scheduler.step(
                        noise_pred, t, latents_batch, **extra_step_kwargs, return_dict=False
                    )[0]

                    # HPU Patch
                    if not self.use_hpu_graphs:
                        self.htcore.mark_step()

                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                        latents_batch = callback_outputs.pop("latents", latents_batch)
                        text_embeddings_batch = callback_outputs.pop("prompt_embeds", text_embeddings_batch)
                        # negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                    # call the callback, if provided
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        if callback is not None and i % callback_steps == 0:
                            step_idx = i // getattr(self.scheduler, "order", 1)
                            # Pass the batch currently being denoised; `latents` holds the initial noised
                            # latents of all batches and is never updated inside this loop.
                            callback(step_idx, t, latents_batch)

                    hb_profiler.step()

                if use_warmup_inference_steps:
                    t1 = warmup_inference_steps_time_adjustment(
                        t1, t1_inf, num_inference_steps, throughput_warmup_steps
                    )

                if output_type != "latent":
                    image = self.vae.decode(
                        latents_batch / self.vae.config.scaling_factor, return_dict=False, generator=generator
                    )[0]
                else:
                    image = latents_batch
                outputs["images"].append(image)

            hb_profiler.stop()

            speed_metrics_prefix = "generation"
            speed_measures = speed_metrics(
                split=speed_metrics_prefix,
                start_time=t0,
                num_samples=num_batches * batch_size
                if t1 == t0 or use_warmup_inference_steps
                else (num_batches - throughput_warmup_steps) * batch_size,
                num_steps=num_batches,
                start_time_after_warmup=t1,
            )
            logger.info(f"Speed metrics: {speed_measures}")

            # Remove dummy generations if needed
            if num_dummy_samples > 0:
                outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples]

            # Process generated images
            # Iterate over a copy because `outputs["images"]` is cleared and rebuilt inside the loop.
            for i, image in enumerate(outputs["images"][:]):
                if i == 0:
                    outputs["images"].clear()

                if output_type == "latent":
                    has_nsfw_concept = None
                else:
                    image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

                if has_nsfw_concept is None:
                    do_denormalize = [True] * image.shape[0]
                else:
                    do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

                image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

                if output_type == "pil" and isinstance(image, list):
                    outputs["images"] += image
                elif output_type in ["np", "numpy"] and isinstance(image, np.ndarray):
                    if len(outputs["images"]) == 0:
                        outputs["images"] = image
                    else:
                        outputs["images"] = np.concatenate((outputs["images"], image), axis=0)
                else:
                    if len(outputs["images"]) == 0:
                        outputs["images"] = image
                    else:
                        outputs["images"] = torch.cat((outputs["images"], image), 0)

                if has_nsfw_concept is not None:
                    outputs["has_nsfw_concept"] += has_nsfw_concept
                else:
                    outputs["has_nsfw_concept"] = None

            # Offload all models
            self.maybe_free_model_hooks()

            if not return_dict:
                return (outputs["images"], outputs["has_nsfw_concept"])

            return StableDiffusionPipelineOutput(
                images=outputs["images"], nsfw_content_detected=outputs["has_nsfw_concept"]
            )

    @torch.no_grad()
    def unet_hpu(
        self,
        latent_model_input,
        timestep,
        encoder_hidden_states,
        timestep_cond,
        cross_attention_kwargs,
        added_cond_kwargs,
    ):
        """
        Run the UNet once, through `capture_replay` when `self.use_hpu_graphs` is set and directly otherwise.

        Returns the first element of the UNet output (called with `return_dict=False`).
        """
        if self.use_hpu_graphs:
            return self.capture_replay(
                latent_model_input,
                timestep,
                encoder_hidden_states,
                timestep_cond,
                cross_attention_kwargs,
                added_cond_kwargs,
            )
        else:
            return self.unet(
                latent_model_input,
                timestep,
                encoder_hidden_states=encoder_hidden_states,
                timestep_cond=timestep_cond,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

    @torch.no_grad()
    def capture_replay(
        self,
        latent_model_input,
        timestep,
        encoder_hidden_states,
        timestep_cond,
        cross_attention_kwargs,
        added_cond_kwargs,
    ):
        """
        Run the UNet through an HPU graph that is captured on first use and replayed afterwards.

        Graphs are cached in `self.cache` under the hash computed by `input_hash` over the inputs. On a cache
        miss the UNet call is captured and its output returned; on a hit the new inputs are copied into the cached
        graph inputs, the graph is replayed and the cached output object is returned.
        """
        inputs = [
            latent_model_input,
            timestep,
            encoder_hidden_states,
            timestep_cond,
            cross_attention_kwargs,
            added_cond_kwargs,
            False,  # return_dict
        ]
        h = self.ht.hpu.graphs.input_hash(inputs)
        cached = self.cache.get(h)

        if cached is None:
            # Capture the graph and cache it
            with self.ht.hpu.stream(self.hpu_stream):
                graph = self.ht.hpu.HPUGraph()
                graph.capture_begin()

                outputs = self.unet(
                    inputs[0],
                    inputs[1],
                    encoder_hidden_states=inputs[2],
                    timestep_cond=inputs[3],
                    cross_attention_kwargs=inputs[4],
                    added_cond_kwargs=inputs[5],
                    return_dict=inputs[6],
                )[0]

                graph.capture_end()
                graph_inputs = inputs
                graph_outputs = outputs
                self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph)
            return outputs

        # Replay the cached graph with updated inputs
        self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs)
        cached.graph.replay()
        self.ht.core.hpu.default_stream().synchronize()

        return cached.graph_outputs