optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import time
from math import ceil
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
from diffusers.image_processor import PipelineImageInput
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline,
retrieve_latents,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import deprecate
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from optimum.utils import logging
from ....transformers.gaudi_configuration import GaudiConfig
from ....utils import HabanaProfile, speed_metrics, warmup_inference_steps_time_adjustment
from ..pipeline_utils import GaudiDiffusionPipeline
logger = logging.get_logger(__name__)
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
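    Example:
        Illustrative only; `pipeline` is assumed to be an already-loaded pipeline object, and the explicit
        `timesteps` form requires a scheduler whose `set_timesteps` accepts a `timesteps` argument:
        ```py
        >>> timesteps, num_inference_steps = retrieve_timesteps(pipeline.scheduler, num_inference_steps=50, device="hpu")
        >>> timesteps, num_inference_steps = retrieve_timesteps(pipeline.scheduler, device="hpu", timesteps=[999, 749, 499, 249, 0])
        ```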
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device="cpu", **kwargs)
timesteps = scheduler.timesteps.to(device)
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device="cpu", **kwargs)
timesteps = scheduler.timesteps.to(device)
    # Handles the case where the scheduler cannot implement reset_timestep_dependent_params().
    # Example: UniPCMultistepScheduler, used for inference in ControlNet training, as it accesses the timestep-dependent parameter sigma non-linearly.
if hasattr(scheduler, "reset_timestep_dependent_params") and callable(scheduler.reset_timestep_dependent_params):
scheduler.reset_timestep_dependent_params()
return timesteps, num_inference_steps
class GaudiStableDiffusionImg2ImgPipeline(GaudiDiffusionPipeline, StableDiffusionImg2ImgPipeline):
"""
Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L161
Changes:
        1. Generate the initial random noise tensor on CPU with the provided generator, then move it to the target device (see `prepare_latents`)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder ([`~transformers.CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer (`~transformers.CLIPTokenizer`):
A `CLIPTokenizer` to tokenize text.
unet ([`UNet2DConditionModel`]):
A `UNet2DConditionModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
Pre-trained CLIP vision model used to obtain image features.
use_habana (bool, defaults to `False`):
Whether to use Gaudi (`True`) or CPU (`False`).
use_hpu_graphs (bool, defaults to `False`):
Whether to use HPU graphs or not.
gaudi_config (Union[str, [`GaudiConfig`]], defaults to `None`):
            Gaudi configuration to use. Can be a string to download it from the Hub,
            or a previously initialized config can be passed directly.
bf16_full_eval (bool, defaults to `False`):
Whether to use full bfloat16 evaluation instead of 32-bit.
This will be faster and save memory compared to fp32/mixed precision but can harm generated images.
sdp_on_bf16 (bool, defaults to `False`):
Whether to allow PyTorch to use reduced precision in the SDPA math backend.
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
use_habana: bool = False,
use_hpu_graphs: bool = False,
gaudi_config: Union[str, GaudiConfig] = None,
bf16_full_eval: bool = False,
sdp_on_bf16: bool = False,
):
GaudiDiffusionPipeline.__init__(
self,
use_habana,
use_hpu_graphs,
gaudi_config,
bf16_full_eval,
sdp_on_bf16,
)
StableDiffusionImg2ImgPipeline.__init__(
self,
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker,
feature_extractor,
image_encoder,
requires_safety_checker,
)
self.to(self._device)
# Copied from ./pipeline_stable_diffusion.py
@classmethod
def _split_inputs_into_batches(cls, batch_size, latents, prompt_embeds, negative_prompt_embeds):
# Use torch.split to generate num_batches batches of size batch_size
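        # e.g. with 5 samples and batch_size=2, torch.split yields chunks of sizes [2, 2, 1]; the last chunk is
        # padded below with zero-filled dummy samples so that every batch ends up with the same shape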
latents_batches = list(torch.split(latents, batch_size))
prompt_embeds_batches = list(torch.split(prompt_embeds, batch_size))
if negative_prompt_embeds is not None:
negative_prompt_embeds_batches = list(torch.split(negative_prompt_embeds, batch_size))
        # If the last batch has fewer samples than batch_size, pad it with dummy samples
num_dummy_samples = 0
if latents_batches[-1].shape[0] < batch_size:
num_dummy_samples = batch_size - latents_batches[-1].shape[0]
# Pad latents_batches
sequence_to_stack = (latents_batches[-1],) + tuple(
torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
)
latents_batches[-1] = torch.vstack(sequence_to_stack)
# Pad prompt_embeds_batches
sequence_to_stack = (prompt_embeds_batches[-1],) + tuple(
torch.zeros_like(prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
)
prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
# Pad negative_prompt_embeds_batches if necessary
if negative_prompt_embeds is not None:
sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple(
torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
)
negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
# Stack batches in the same tensor
latents_batches = torch.stack(latents_batches)
if negative_prompt_embeds is not None:
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
for i, (negative_prompt_embeds_batch, prompt_embeds_batch) in enumerate(
zip(negative_prompt_embeds_batches, prompt_embeds_batches[:])
):
prompt_embeds_batches[i] = torch.cat([negative_prompt_embeds_batch, prompt_embeds_batch])
prompt_embeds_batches = torch.stack(prompt_embeds_batches)
return latents_batches, prompt_embeds_batches, num_dummy_samples
def prepare_latents(self, image, timestep, num_prompts, num_images_per_prompt, dtype, device, generator=None):
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
image = image.to(device=device, dtype=dtype)
batch_size = num_prompts * num_images_per_prompt
if image.shape[1] == 4:
init_latents = image
else:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
elif isinstance(generator, list):
init_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
init_latents = self.vae.config.scaling_factor * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicated to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
)
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
additional_image_per_prompt = batch_size // init_latents.shape[0]
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
)
else:
init_latents = torch.cat([init_latents], dim=0)
# Reuse first generator for noise
if isinstance(generator, list):
generator = generator[0]
shape = init_latents.shape
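        # HPU Patch: sample the noise on CPU with the provided generator, then move it to the target device
        # (this is the HPU-specific change called out in the class docstring)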
rand_device = "cpu" if device.type == "hpu" else device
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) # HPU Patch
noise = noise.to(device) # HPU Patch
# get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
latents = init_latents
return latents
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: PipelineImageInput = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
timesteps: List[int] = None,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
batch_size: int = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: int = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
profiling_warmup_steps: Optional[int] = 0,
profiling_steps: Optional[int] = 0,
**kwargs,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
                of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
                latents as `image`, but if passing latents directly they are not encoded again.
strength (`float`, *optional*, defaults to 0.8):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
batch_size (`int`, *optional*, defaults to 1):
The number of images in a batch.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image (`PipelineImageInput`, *optional*):
                Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
profiling_warmup_steps (`int`, *optional*):
                Number of steps to ignore for profiling.
profiling_steps (`int`, *optional*):
Number of steps to be captured when enabling profiling.
Examples:
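            A minimal usage sketch (illustrative only: the checkpoint, Gaudi configuration name and image URL
            below are example values, not requirements of this pipeline):
            ```py
            >>> import requests
            >>> from io import BytesIO
            >>> from PIL import Image
            >>> from optimum.habana.diffusers import GaudiStableDiffusionImg2ImgPipeline
            >>> pipeline = GaudiStableDiffusionImg2ImgPipeline.from_pretrained(
            ...     "CompVis/stable-diffusion-v1-4",
            ...     use_habana=True,
            ...     use_hpu_graphs=True,
            ...     gaudi_config="Habana/stable-diffusion",
            ... )
            >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
            >>> init_image = Image.open(BytesIO(requests.get(url).content)).convert("RGB").resize((768, 512))
            >>> images = pipeline(
            ...     prompt="A fantasy landscape, trending on artstation",
            ...     image=init_image,
            ...     strength=0.75,
            ...     guidance_scale=7.5,
            ... ).images
            ```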
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast):
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
strength,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
num_prompts = 1
elif prompt is not None and isinstance(prompt, list):
num_prompts = len(prompt)
else:
num_prompts = prompt_embeds.shape[0]
num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size)
logger.info(
f"{num_prompts} prompt(s) received, {num_images_per_prompt} generation(s) per prompt,"
f" {batch_size} sample(s) per batch, {num_batches} total batch(es)."
)
if num_batches < 3:
                logger.warning("The first two iterations are slower, so it is recommended to feed more batches.")
device = self._execution_device
# 3. Encode input prompt
text_encoder_lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
if ip_adapter_image is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image, device, batch_size * num_images_per_prompt
)
# 4. Preprocess image
image = self.image_processor.preprocess(image)
# 5. set timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
latent_timestep = timesteps[:1].repeat(num_prompts * num_images_per_prompt)
# 6. Prepare latent variables
latents = self.prepare_latents(
image,
latent_timestep,
num_prompts,
num_images_per_prompt,
prompt_embeds.dtype,
device,
generator,
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7.1 Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
# 7.2 Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
batch_size * num_images_per_prompt
)
timestep_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
).to(device=device, dtype=latents.dtype)
# 8. Split into batches (HPU-specific step)
latents_batches, text_embeddings_batches, num_dummy_samples = self._split_inputs_into_batches(
batch_size,
latents,
prompt_embeds,
negative_prompt_embeds,
)
outputs = {
"images": [],
"has_nsfw_concept": [],
}
t0 = time.time()
t1 = t0
hb_profiler = HabanaProfile(
warmup=profiling_warmup_steps,
active=profiling_steps,
record_shapes=False,
)
hb_profiler.start()
# 9. Denoising loop
throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
use_warmup_inference_steps = num_batches < throughput_warmup_steps < num_inference_steps
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
for j in self.progress_bar(range(num_batches)):
                # Throughput is measured starting from iteration `throughput_warmup_steps` (3 by default),
                # because graph compilation makes the first iterations slower
if j == throughput_warmup_steps:
t1 = time.time()
if use_warmup_inference_steps:
t0_inf = time.time()
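                # Rotate the stacked batches with torch.roll and always read index 0 instead of indexing by `j`,
                # so the access pattern and tensor shapes stay identical across iterations (HPU-friendly)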
latents_batch = latents_batches[0]
latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
text_embeddings_batch = text_embeddings_batches[0]
text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0)
for i in range(num_inference_steps): # HPU Patch
if use_warmup_inference_steps and i == throughput_warmup_steps:
t1_inf = time.time()
t1 += t1_inf - t0_inf
t = timesteps[0] # HPU Patch
timesteps = torch.roll(timesteps, shifts=-1, dims=0) # HPU Patch
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch
)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet_hpu(
latent_model_input,
t,
encoder_hidden_states=text_embeddings_batch,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
)
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_batch = self.scheduler.step(
noise_pred, t, latents_batch, **extra_step_kwargs, return_dict=False
)[0]
# HPU Patch
if not self.use_hpu_graphs:
self.htcore.mark_step()
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents_batch = callback_outputs.pop("latents", latents_batch)
text_embeddings_batch = callback_outputs.pop("prompt_embeds", text_embeddings_batch)
# negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
hb_profiler.step()
if use_warmup_inference_steps:
t1 = warmup_inference_steps_time_adjustment(
t1, t1_inf, num_inference_steps, throughput_warmup_steps
)
if not output_type == "latent":
image = self.vae.decode(
latents_batch / self.vae.config.scaling_factor, return_dict=False, generator=generator
)[0]
else:
image = latents_batch
outputs["images"].append(image)
hb_profiler.stop()
speed_metrics_prefix = "generation"
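            # Throughput excludes the warmup batches unless no post-warmup timestamp was recorded (t1 == t0)
            # or the per-step warmup adjustment above was applied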
speed_measures = speed_metrics(
split=speed_metrics_prefix,
start_time=t0,
num_samples=num_batches * batch_size
if t1 == t0 or use_warmup_inference_steps
else (num_batches - throughput_warmup_steps) * batch_size,
num_steps=num_batches,
start_time_after_warmup=t1,
)
logger.info(f"Speed metrics: {speed_measures}")
# Remove dummy generations if needed
if num_dummy_samples > 0:
outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples]
# Process generated images
for i, image in enumerate(outputs["images"][:]):
if i == 0:
outputs["images"].clear()
if output_type == "latent":
has_nsfw_concept = None
else:
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
if output_type == "pil" and isinstance(image, list):
outputs["images"] += image
elif output_type in ["np", "numpy"] and isinstance(image, np.ndarray):
if len(outputs["images"]) == 0:
outputs["images"] = image
else:
outputs["images"] = np.concatenate((outputs["images"], image), axis=0)
else:
if len(outputs["images"]) == 0:
outputs["images"] = image
else:
outputs["images"] = torch.cat((outputs["images"], image), 0)
if has_nsfw_concept is not None:
outputs["has_nsfw_concept"] += has_nsfw_concept
else:
outputs["has_nsfw_concept"] = None
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (outputs["images"], outputs["has_nsfw_concept"])
return StableDiffusionPipelineOutput(
images=outputs["images"], nsfw_content_detected=outputs["has_nsfw_concept"]
)
@torch.no_grad()
def unet_hpu(
self,
latent_model_input,
timestep,
encoder_hidden_states,
timestep_cond,
cross_attention_kwargs,
added_cond_kwargs,
):
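        # Run the UNet through HPU graph capture/replay when HPU graphs are enabled, otherwise as a regular forward pass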
if self.use_hpu_graphs:
return self.capture_replay(
latent_model_input,
timestep,
encoder_hidden_states,
timestep_cond,
cross_attention_kwargs,
added_cond_kwargs,
)
else:
return self.unet(
latent_model_input,
timestep,
encoder_hidden_states=encoder_hidden_states,
timestep_cond=timestep_cond,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
@torch.no_grad()
def capture_replay(
self,
latent_model_input,
timestep,
encoder_hidden_states,
timestep_cond,
cross_attention_kwargs,
added_cond_kwargs,
):
inputs = [
latent_model_input,
timestep,
encoder_hidden_states,
timestep_cond,
cross_attention_kwargs,
added_cond_kwargs,
False,
]
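        # The graph cache is keyed by a hash of the inputs: the first call for a given hash captures an HPU graph,
        # while later calls copy the new inputs into the captured buffers and replay the graph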
h = self.ht.hpu.graphs.input_hash(inputs)
cached = self.cache.get(h)
if cached is None:
# Capture the graph and cache it
with self.ht.hpu.stream(self.hpu_stream):
graph = self.ht.hpu.HPUGraph()
graph.capture_begin()
outputs = self.unet(
inputs[0],
inputs[1],
encoder_hidden_states=inputs[2],
timestep_cond=inputs[3],
cross_attention_kwargs=inputs[4],
added_cond_kwargs=inputs[5],
return_dict=inputs[6],
)[0]
graph.capture_end()
graph_inputs = inputs
graph_outputs = outputs
self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph)
return outputs
# Replay the cached graph with updated inputs
self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs)
cached.graph.replay()
self.ht.core.hpu.default_stream().synchronize()
return cached.graph_outputs