optimum/neuron/pipelines/diffusers/pipeline_controlnet.py (268 lines of code) (raw):

# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Override some diffusers API for NeuronStableDiffusionControlNetPipeline"""

import logging
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers import ControlNetModel
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.controlnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps


logger = logging.getLogger(__name__)


class NeuronStableDiffusionControlNetPipelineMixin:
    """Mixin providing the ControlNet-guided `__call__` of a Stable Diffusion pipeline.

    The host class is expected to provide the attributes and helpers used below, among them `controlnet`, `unet`,
    `vae`, `scheduler`, `image_processor`, `data_parallel_mode`, `check_inputs`, `encode_prompt`, `prepare_image`,
    `prepare_latents`, `prepare_extra_step_kwargs`, `run_safety_checker` and `progress_bar`.
    """

    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        image: Optional[PipelineImageInput] = None,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`Optional[Union[str, List[str]]]`, defaults to `None`):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`Optional["PipelineImageInput"]`, defaults to `None`):
                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
                as an image. The image is prepared with the static height and width read from
                `self.vae.config.neuron` (scaled by `self.vae_scale_factor`). If multiple ControlNets are specified
                in `init`, images must be passed as a list such that each element of the list can be correctly
                batched for input to a single ControlNet. When `prompt` is a list, and if a list of images is passed
                for a single ControlNet, each will be paired with each prompt in the `prompt` list. This also applies
                to multiple ControlNets, where a list of image lists can be passed to batch for each prompt and each
                ControlNet.
            num_inference_steps (`int`, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`Optional[List[int]]`, defaults to `None`):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`Optional[List[float]]`, defaults to `None`):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`Optional[Union[str, List[str]]]`, defaults to `None`):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
            num_images_per_prompt (`int`, defaults to 1):
                The number of images to generate per prompt. If it is different from the batch size used for the
                compilation, it will be overridden by the static batch size of neuron (except for dynamic batching).
            eta (`float`, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`diffusers.schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`Optional[torch.Tensor]`, defaults to `None`):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`Optional[torch.Tensor]`, defaults to `None`):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`Optional[torch.Tensor]`, defaults to `None`):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image (`Optional[PipelineImageInput]`, defaults to `None`):
                Optional image input to work with IP Adapters. IP-Adapter is not supported yet: when given, a message
                is logged and the value is ignored.
            ip_adapter_image_embeds (`Optional[List[torch.Tensor]]`, defaults to `None`):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument. IP-Adapter is not
                supported yet: when given, a message is logged and the value is ignored.
            output_type (`str`, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`. With `"latent"`,
                the VAE decoding and the safety checker are skipped and the latents are post-processed directly.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`]
                instead of a plain tuple.
            cross_attention_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            controlnet_conditioning_scale (`Union[float, List[float]]`, defaults to 1.0):
                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
                the corresponding scale as a list.
            guess_mode (`bool`, defaults to `False`):
                The ControlNet encoder tries to recognize the content of the input image even if you remove all
                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. Guess mode is not supported
                yet: when requested (or implied by `global_pool_conditions`), a message is logged and it is disabled.
            control_guidance_start (`Union[float, List[float]]`, defaults to 0.0):
                The percentage of total steps at which the ControlNet starts applying.
            control_guidance_end (`Union[float, List[float]]`, *optional*, defaults to 1.0):
                The percentage of total steps at which the ControlNet stops applying.
            clip_skip (`Optional[int]`, defaults to `None`):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]]`, defaults to `None`):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during the inference, with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List[str]`, defaults to `["latents"]`):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            kwargs:
                Accepted for signature compatibility; not used by this method.

        Returns:
            [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element is a list with the generated images
                and the second element is a list of `bool`s indicating whether the corresponding generated image
                contains "not-safe-for-work" (nsfw) content (`None` when `output_type` is `"latent"`).
        """
        # Callback objects carry their own list of requested tensors, which overrides the argument.
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        controlnet = self.controlnet

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = (
                mult * [control_guidance_start],
                mult * [control_guidance_end],
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt=prompt,
            image=image,
            callback_steps=None,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            ip_adapter_image=ip_adapter_image,
            ip_adapter_image_embeds=ip_adapter_image_embeds,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_start=control_guidance_start,
            control_guidance_end=control_guidance_end,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

        global_pool_conditions = (
            controlnet.config.global_pool_conditions
            if isinstance(controlnet, ControlNetModel)
            else controlnet.config[0].global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions
        # TODO: support guess mode of ControlNet
        if guess_mode:
            logger.info("Disabling the guess mode as this is not supported yet.")
            guess_mode = False

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            None,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # TODO: support ip adapter
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            logger.info(
                "IP adapter is not supported yet, `ip_adapter_image` and `ip_adapter_image_embeds` will be ignored."
            )

        # 4. Prepare image
        # The target resolution comes from the static shapes stored in the VAE's neuron config.
        height = self.vae.config.neuron["static_height"] * self.vae_scale_factor
        width = self.vae.config.neuron["static_width"] * self.vae_scale_factor

        if isinstance(controlnet, ControlNetModel):
            image = self.prepare_image(
                image=image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=None,
                dtype=None,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            height, width = image.shape[-2:]
        elif isinstance(controlnet, MultiControlNetModel):
            images = []

            # Nested lists as ControlNet condition
            if isinstance(image[0], list):
                # Transpose the nested image list
                image = [list(t) for t in zip(*image)]

            for image_ in image:
                image_ = self.prepare_image(
                    image=image_,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=None,
                    dtype=None,
                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )
                images.append(image_)

            image = images
            height, width = image[0].shape[-2:]
        else:
            assert False

        # 5. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler=self.scheduler,
            num_inference_steps=num_inference_steps,
            device=None,
            timesteps=timesteps,
            sigmas=sigmas,
        )
        self._num_timesteps = len(timesteps)

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            None,
            generator,
            latents,
        )

        # 6.5 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=None, dtype=latents.dtype)

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # TODO: 7.1 Add image embeds for IP-Adapter
        added_cond_kwargs = None

        # 7.2 Create tensor stating which controlnets to keep
        # For each step, a ControlNet gets 1.0 when the step lies inside its [start, end] guidance window, else 0.0.
        controlnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # controlnet(s) inference
                if guess_mode and self.do_classifier_free_guidance:
                    # Infer ControlNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                else:
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                # Duplicate inputs for ddp
                duplicate_for_unet = self.data_parallel_mode == "unet"
                t = torch.tensor([t] * 2) if duplicate_for_unet else t
                if isinstance(controlnet, ControlNetModel):
                    cond_scale = torch.tensor([cond_scale]).repeat(2) if duplicate_for_unet else torch.tensor(cond_scale)
                else:
                    # Build the per-ControlNet scale tensors with a comprehension rather than an index loop:
                    # an inner `for i, ...` would overwrite the denoising-step index `i` that the callback
                    # and the progress-bar update below depend on.
                    cond_scale = [
                        torch.tensor([scale]).repeat(2) if duplicate_for_unet else torch.tensor(scale)
                        for scale in cond_scale
                    ]

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=image,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )

                if guess_mode and self.do_classifier_free_guidance:
                    # Inferred ControlNet only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    # add 0 to the unconditional batch to keep it unchanged.
                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # De-Duplicate inputs for ddp
                t = t[0] if self.data_parallel_mode == "unet" else t

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # Requested tensors are looked up by name among this function's local variables.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latent":
            image = self.vae.decode(
                latents / getattr(self.vae.config, "scaling_factor", 0.18215), return_dict=False, generator=generator
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(image, None, dtype=prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        # Images flagged by the safety checker are not denormalized.
        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)