# optimum/onnxruntime/modeling_diffusion.py
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import logging
import os
from collections import OrderedDict
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, Optional, Sequence, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipelines import (
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
AutoPipelineForText2Image,
LatentConsistencyModelImg2ImgPipeline,
LatentConsistencyModelPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import SchedulerMixin
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils.constants import CONFIG_NAME
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from huggingface_hub import HfApi, create_repo
from huggingface_hub.utils import validate_hf_hub_args
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from transformers.file_utils import add_end_docstrings
from transformers.modeling_outputs import ModelOutput
from transformers.utils import http_user_agent
from onnxruntime import InferenceSession, SessionOptions
from ..exporters.onnx import main_export
from ..utils import (
DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER,
DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER,
DIFFUSION_MODEL_UNET_SUBFOLDER,
DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
DIFFUSION_PIPELINE_CONFIG_FILE_NAME,
ONNX_WEIGHTS_NAME,
is_diffusers_version,
)
from .base import ORTParentMixin, ORTSessionMixin
from .utils import get_device_for_provider, np_to_pt_generators, prepare_providers_and_provider_options
if is_diffusers_version(">=", "0.25.0"):
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
else:
from diffusers.models.vae import DiagonalGaussianDistribution # type: ignore
logger = logging.getLogger(__name__)
# TODO: support from_pipe()
class ORTDiffusionPipeline(ORTParentMixin, DiffusionPipeline):
config_name = DIFFUSION_PIPELINE_CONFIG_FILE_NAME
task = "auto"
library = "diffusers"
auto_model_class = DiffusionPipeline
def __init__(
self,
*,
# pipeline models
unet_session: Optional["InferenceSession"] = None,
transformer_session: Optional["InferenceSession"] = None,
vae_decoder_session: Optional["InferenceSession"] = None,
vae_encoder_session: Optional["InferenceSession"] = None,
text_encoder_session: Optional["InferenceSession"] = None,
text_encoder_2_session: Optional["InferenceSession"] = None,
text_encoder_3_session: Optional["InferenceSession"] = None,
# pipeline submodels
scheduler: Optional["SchedulerMixin"] = None,
tokenizer: Optional["CLIPTokenizer"] = None,
tokenizer_2: Optional["CLIPTokenizer"] = None,
tokenizer_3: Optional["CLIPTokenizer"] = None,
feature_extractor: Optional["CLIPFeatureExtractor"] = None,
# stable diffusion xl specific arguments
force_zeros_for_empty_prompt: bool = True,
requires_aesthetics_score: bool = False,
add_watermarker: Optional[bool] = None,
# onnxruntime specific arguments
use_io_binding: Optional[bool] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
**kwargs,
):
# We initialize all ort session mixins first
self.unet = ORTUnet(unet_session, self, use_io_binding) if unet_session is not None else None
self.transformer = (
ORTTransformer(transformer_session, self, use_io_binding) if transformer_session is not None else None
)
self.text_encoder = (
ORTTextEncoder(text_encoder_session, self, use_io_binding) if text_encoder_session is not None else None
)
self.text_encoder_2 = (
ORTTextEncoder(text_encoder_2_session, self, use_io_binding)
if text_encoder_2_session is not None
else None
)
self.text_encoder_3 = (
ORTTextEncoder(text_encoder_3_session, self, use_io_binding)
if text_encoder_3_session is not None
else None
)
self.vae_encoder = (
ORTVaeEncoder(vae_encoder_session, self, use_io_binding) if vae_encoder_session is not None else None
)
self.vae_decoder = (
ORTVaeDecoder(vae_decoder_session, self, use_io_binding) if vae_decoder_session is not None else None
)
# We register ort session mixins to the wrapper
super().initialize_ort_attributes(
parts=list(
filter(
None,
{
self.unet,
self.transformer,
self.vae_encoder,
self.vae_decoder,
self.text_encoder,
self.text_encoder_2,
self.text_encoder_3,
},
)
)
)
# We wrap the VAE Encoder & Decoder in a single object for convenience
self.vae = (
ORTVae(self.vae_encoder, self.vae_decoder)
if self.vae_encoder is not None or self.vae_decoder is not None
else None
)
# we allow passing these as torch models for now
self.image_encoder = kwargs.pop("image_encoder", None) # TODO: maybe implement ORTImageEncoder
self.safety_checker = kwargs.pop("safety_checker", None) # TODO: maybe implement ORTSafetyChecker
# We register the submodels to the pipeline
self.scheduler = scheduler
self.tokenizer = tokenizer
self.tokenizer_2 = tokenizer_2
self.tokenizer_3 = tokenizer_3
self.feature_extractor = feature_extractor
# We initialize diffusers pipeline specific attributes (registers modules and config)
all_pipeline_init_args = {
"vae": self.vae,
"unet": self.unet,
"transformer": self.transformer,
"text_encoder": self.text_encoder,
"text_encoder_2": self.text_encoder_2,
"text_encoder_3": self.text_encoder_3,
"safety_checker": self.safety_checker,
"image_encoder": self.image_encoder,
"scheduler": self.scheduler,
"tokenizer": self.tokenizer,
"tokenizer_2": self.tokenizer_2,
"tokenizer_3": self.tokenizer_3,
"feature_extractor": self.feature_extractor,
"requires_aesthetics_score": requires_aesthetics_score,
"force_zeros_for_empty_prompt": force_zeros_for_empty_prompt,
"add_watermarker": add_watermarker,
}
diffusers_pipeline_args = {}
for key in inspect.signature(self.auto_model_class).parameters.keys():
if key in all_pipeline_init_args:
diffusers_pipeline_args[key] = all_pipeline_init_args[key]
self.auto_model_class.__init__(self, **diffusers_pipeline_args)
# This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
        # would end up removing the directory containing the underlying ONNX model (and thus failing inference).
self.model_save_dir = model_save_dir
@property
def components(self) -> Dict[str, Optional[Union[ORTSessionMixin, torch.nn.Module]]]:
# TODO: all components should be ORTSessionMixin's at some point
components = {
"vae": self.vae,
"unet": self.unet,
"transformer": self.transformer,
"text_encoder": self.text_encoder,
"text_encoder_2": self.text_encoder_2,
"text_encoder_3": self.text_encoder_3,
"safety_checker": self.safety_checker,
"image_encoder": self.image_encoder,
}
components = {k: v for k, v in components.items() if v is not None}
return components
def to(self, device: Union[torch.device, str, int]):
"""
Changes the device of the pipeline components to the specified device.
Args:
device (`torch.device` or `str` or `int`):
                Device ordinal for CPU/GPU support. Setting this to -1 will use the CPU; a non-negative
                integer will run the model on the associated CUDA device. A native `torch.device` or a `str`
                can be passed as well.
Returns:
`ORTDiffusionPipeline`: The pipeline with the updated device.
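        Example:

        A minimal sketch, assuming an already-loaded pipeline and a CUDA-enabled
        `onnxruntime-gpu` build:

        ```python
        >>> pipeline = pipeline.to("cuda")  # run all ORT sessions on the CUDA device
        >>> pipeline = pipeline.to("cpu")  # and back on the CPU
        ```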
"""
for component in self.components.values():
if isinstance(component, (ORTSessionMixin, ORTParentMixin)):
component.to(device)
return self
@classmethod
def from_pretrained(
cls,
model_name_or_path: Union[str, Path],
# export options
export: bool = False,
# session options
provider: str = "CPUExecutionProvider",
providers: Optional[Sequence[str]] = None,
provider_options: Optional[Union[Sequence[Dict[str, Any]], Dict[str, Any]]] = None,
session_options: Optional[SessionOptions] = None,
# inference options
use_io_binding: Optional[bool] = None,
# hub options and preloaded models
**kwargs,
):
"""
        Instantiates an [`ORTDiffusionPipeline`] with ONNX Runtime sessions from a pretrained model.
This method can be used to load a model from the Hugging Face Hub or from a local directory.
Args:
model_name_or_path (`str` or `os.PathLike`):
Path to a folder containing the model files or a hub repository id.
export (`bool`, *optional*, defaults to `False`):
                Whether to export the model to ONNX format. If set to `True`, the model is exported to a
                temporary directory, from which the ONNX Runtime sessions are then loaded.
provider (`str`, *optional*, defaults to `"CPUExecutionProvider"`):
The execution provider for ONNX Runtime. Can be `"CUDAExecutionProvider"`, `"DmlExecutionProvider"`,
etc.
providers (`Sequence[str]`, *optional*):
A list of execution providers for ONNX Runtime. Overrides `provider`.
provider_options (`Union[Sequence[Dict[str, Any]], Dict[str, Any]]`, *optional*):
Options for each execution provider. Can be a single dictionary for the first provider or a list of
dictionaries for each provider. The order of the dictionaries should match the order of the providers.
session_options (`SessionOptions`, *optional*):
Options for the ONNX Runtime session. Can be used to set optimization levels, graph optimization,
etc.
            use_io_binding (`bool`, *optional*):
                Whether to use IOBinding for the ONNX Runtime sessions. If set to `True`, input and output
                tensors are bound directly to device buffers, avoiding unnecessary host-device copies.
**kwargs:
Can include the following:
                - Export arguments (e.g., `slim`, `dtype`, `no_dynamic_axes`, etc.).
- Hugging Face Hub arguments (e.g., `revision`, `cache_dir`, `force_download`, etc.).
- Preloaded models or sessions for the different components of the pipeline (e.g., `vae_encoder_session`,
`vae_decoder_session`, `unet_session`, `transformer_session`, `image_encoder`, `safety_checker`, etc.).
Returns:
[`ORTDiffusionPipeline`]: The loaded pipeline with ONNX Runtime sessions.
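        Example:

        A minimal sketch; the checkpoint identifiers below are placeholders for any
        diffusers-compatible checkpoint:

        ```python
        >>> from optimum.onnxruntime import ORTDiffusionPipeline

        >>> # load a checkpoint that already contains ONNX weights
        >>> pipeline = ORTDiffusionPipeline.from_pretrained("my-namespace/my-onnx-checkpoint")

        >>> # or export a PyTorch checkpoint to ONNX on the fly
        >>> pipeline = ORTDiffusionPipeline.from_pretrained("my-namespace/my-torch-checkpoint", export=True)
        ```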
"""
providers, provider_options = prepare_providers_and_provider_options(
provider=provider, providers=providers, provider_options=provider_options
)
hub_kwargs = {
"force_download": kwargs.get("force_download", False),
"resume_download": kwargs.get("resume_download", None),
"local_files_only": kwargs.get("local_files_only", False),
"cache_dir": kwargs.get("cache_dir", None),
"revision": kwargs.get("revision", None),
"proxies": kwargs.get("proxies", None),
"token": kwargs.get("token", None),
}
# get the pipeline config
config = cls.load_config(model_name_or_path, **hub_kwargs)
config = config[0] if isinstance(config, tuple) else config
model_save_tmpdir = None
model_save_path = Path(model_name_or_path)
# export the model if requested
if export:
model_save_tmpdir = TemporaryDirectory()
model_save_path = Path(model_save_tmpdir.name)
export_kwargs = {
"slim": kwargs.pop("slim", False),
"dtype": kwargs.pop("dtype", None),
"device": get_device_for_provider(provider, {}).type,
"no_dynamic_axes": kwargs.pop("no_dynamic_axes", False),
}
main_export(
model_name_or_path=model_name_or_path,
# export related arguments
output=model_save_path,
no_post_process=True,
do_validation=False,
task=cls.task,
# export related arguments
**export_kwargs,
# hub related arguments
**hub_kwargs,
)
# download the model if needed
if not model_save_path.is_dir():
            # all files in the component subfolders
all_components = {key for key in config.keys() if not key.startswith("_")} | {"vae_encoder", "vae_decoder"}
allow_patterns = {os.path.join(component, "*") for component in all_components}
# plus custom file names
allow_patterns.update(
{
ONNX_WEIGHTS_NAME,
DIFFUSION_PIPELINE_CONFIG_FILE_NAME,
SCHEDULER_CONFIG_NAME,
CONFIG_NAME,
}
)
model_save_folder = HfApi(user_agent=http_user_agent()).snapshot_download(
repo_id=str(model_name_or_path),
allow_patterns=allow_patterns,
ignore_patterns=["*.msgpack", "*.safetensors", "*.bin", "*.xml"],
**hub_kwargs,
)
model_save_path = Path(model_save_folder)
model_paths = {
"unet": model_save_path / DIFFUSION_MODEL_UNET_SUBFOLDER / ONNX_WEIGHTS_NAME,
"transformer": model_save_path / DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER / ONNX_WEIGHTS_NAME,
"vae_encoder": model_save_path / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
"vae_decoder": model_save_path / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
"text_encoder": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
"text_encoder_2": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / ONNX_WEIGHTS_NAME,
"text_encoder_3": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER / ONNX_WEIGHTS_NAME,
}
models = {}
sessions = {}
for model, path in model_paths.items():
if kwargs.get(model, None) is not None:
# this allows passing a model directly to from_pretrained
models[model] = kwargs.pop(model)
elif kwargs.get(f"{model}_session", None) is not None:
# this allows passing a session directly to from_pretrained
sessions[f"{model}_session"] = kwargs.pop(f"{model}_session")
elif path.is_file():
sessions[f"{model}_session"] = InferenceSession(
path,
providers=providers,
provider_options=provider_options,
sess_options=session_options,
)
submodels = {}
for submodel in {"scheduler", "tokenizer", "tokenizer_2", "tokenizer_3", "feature_extractor"}:
if kwargs.get(submodel, None) is not None:
submodels[submodel] = kwargs.pop(submodel)
elif config.get(submodel, (None, None))[0] is not None:
library_name, library_classes = config.get(submodel)
library = importlib.import_module(library_name)
class_obj = getattr(library, library_classes)
load_method = getattr(class_obj, "from_pretrained")
# Check if the module is in a subdirectory
if (model_save_path / submodel).is_dir():
submodels[submodel] = load_method(model_save_path / submodel)
else:
submodels[submodel] = load_method(model_save_path)
# Same as DiffusionPipeline.from_pretrained
if cls.__name__ == "ORTDiffusionPipeline":
pipeline_class_name = config["_class_name"]
ort_pipeline_class = _get_ort_class(pipeline_class_name)
else:
ort_pipeline_class = cls
ort_pipeline = ort_pipeline_class(
**sessions,
**submodels,
use_io_binding=use_io_binding,
model_save_dir=model_save_tmpdir,
**models,
**kwargs,
)
ort_pipeline.register_to_config(**config)
ort_pipeline.register_to_config(_name_or_path=config.get("_name_or_path", model_name_or_path))
return ort_pipeline
def save_pretrained(
self,
save_directory: Union[str, Path],
push_to_hub: Optional[bool] = False,
**kwargs,
):
"""
Saves a model and its configuration file to a directory, so that it can be re-loaded using the
[`from_pretrained`] class method.
Args:
save_directory (`Union[str, os.PathLike]`):
Directory to which to save. Will be created if it doesn't exist.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face model hub after saving it.
**kwargs:
Additional keyword arguments passed along to [`~huggingface_hub.create_repo`] and
[`~huggingface_hub.HfApi.upload_folder`] if `push_to_hub` is set to `True`.
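        Example:

        A minimal round-trip sketch (the target directory is arbitrary):

        ```python
        >>> pipeline.save_pretrained("./onnx-pipeline")
        >>> pipeline = ORTDiffusionPipeline.from_pretrained("./onnx-pipeline")
        ```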
"""
model_save_path = Path(save_directory)
model_save_path.mkdir(parents=True, exist_ok=True)
if push_to_hub:
token = kwargs.pop("token", None)
private = kwargs.pop("private", False)
create_pr = kwargs.pop("create_pr", False)
commit_message = kwargs.pop("commit_message", None)
            repo_id = kwargs.pop("repo_id", Path(save_directory).name)
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
self.save_config(model_save_path)
self.scheduler.save_pretrained(model_save_path / "scheduler")
if self.unet is not None:
self.unet.save_pretrained(model_save_path / DIFFUSION_MODEL_UNET_SUBFOLDER)
if self.transformer is not None:
self.transformer.save_pretrained(model_save_path / DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER)
if self.vae_encoder is not None:
self.vae_encoder.save_pretrained(model_save_path / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER)
if self.vae_decoder is not None:
self.vae_decoder.save_pretrained(model_save_path / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER)
if self.text_encoder is not None:
self.text_encoder.save_pretrained(model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER)
if self.text_encoder_2 is not None:
self.text_encoder_2.save_pretrained(model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER)
if self.text_encoder_3 is not None:
self.text_encoder_3.save_pretrained(model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER)
if self.image_encoder is not None:
self.image_encoder.save_pretrained(model_save_path / "image_encoder")
if self.safety_checker is not None:
self.safety_checker.save_pretrained(model_save_path / "safety_checker")
if self.tokenizer is not None:
self.tokenizer.save_pretrained(model_save_path / "tokenizer")
if self.tokenizer_2 is not None:
self.tokenizer_2.save_pretrained(model_save_path / "tokenizer_2")
if self.tokenizer_3 is not None:
self.tokenizer_3.save_pretrained(model_save_path / "tokenizer_3")
if self.feature_extractor is not None:
self.feature_extractor.save_pretrained(model_save_path / "feature_extractor")
if push_to_hub:
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
model_card = populate_model_card(model_card)
model_card.save(os.path.join(save_directory, "README.md"))
self._upload_folder(
save_directory,
repo_id,
token=token,
create_pr=create_pr,
commit_message=commit_message,
)
def __call__(self, *args, **kwargs):
        # We keep support for numpy random states for now by converting them to torch generators
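        # e.g. prefer `pipeline(prompt, generator=torch.Generator("cpu").manual_seed(0))`
        # over passing a `np.random.RandomState` instance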
args = list(args)
for i in range(len(args)):
            new_args = np_to_pt_generators(args[i], self.device)
            if args[i] is not new_args:
                logger.warning(
                    "Converting numpy random state to torch generator is deprecated. "
                    "Please pass a torch generator directly to the pipeline."
                )
            args[i] = new_args
for key, value in kwargs.items():
new_value = np_to_pt_generators(value, self.device)
if value is not new_value:
logger.warning(
"Converting numpy random state to torch generator is deprecated. "
"Please pass a torch generator directly to the pipeline."
)
kwargs[key] = new_value
return self.auto_model_class.__call__(self, *args, **kwargs)
class ORTModelMixin(ORTSessionMixin, ConfigMixin):
config_name: str = CONFIG_NAME
def __init__(
self,
session: "InferenceSession",
parent: "ORTDiffusionPipeline",
use_io_binding: Optional[bool] = None,
):
self.initialize_ort_attributes(session, use_io_binding=use_io_binding)
self.parent = parent
config_file_path = Path(session._model_path).parent / self.config_name
if not config_file_path.is_file():
# config is mandatory for the model part to be used for inference
raise ValueError(f"Configuration file for {self.__class__.__name__} not found at {config_file_path}")
config_dict = self._dict_from_json_file(config_file_path)
self.register_to_config(**config_dict)
def save_pretrained(self, save_directory: Union[str, Path]):
"""
Saves the ONNX model and its configuration file to a directory, so that it can be re-loaded using the
[`from_pretrained`] class method.
Args:
save_directory (`Union[str, os.PathLike]`):
Directory to which to save. Will be created if it doesn't exist.
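        Example:

        A minimal sketch (the target directory is arbitrary):

        ```python
        >>> pipeline.unet.save_pretrained("./exported/unet")  # writes the ONNX model and its config.json
        ```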
"""
# save onnx model and external data
self.save_session(save_directory)
# save model configuration
self.save_config(save_directory)
class ORTUnet(ORTModelMixin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# can be missing from models exported long ago
if not hasattr(self.config, "time_cond_proj_dim"):
logger.warning(
"The `time_cond_proj_dim` attribute is missing from the UNet configuration. "
"Please re-export the model with newer version of optimum and diffusers."
)
self.register_to_config(time_cond_proj_dim=None)
if len(self.input_shapes["timestep"]) > 0:
logger.warning(
"The exported unet onnx model expects a non scalar timestep input. "
"We will have to unsqueeze the timestep input at each iteration which might be inefficient. "
"Please re-export the pipeline with newer version of optimum and diffusers to avoid this warning."
)
def forward(
self,
sample: Union[np.ndarray, torch.Tensor],
timestep: Union[np.ndarray, torch.Tensor],
encoder_hidden_states: Union[np.ndarray, torch.Tensor],
timestep_cond: Optional[Union[np.ndarray, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
added_cond_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
use_torch = isinstance(sample, torch.Tensor)
if len(self.input_shapes["timestep"]) > 0:
timestep = timestep.unsqueeze(0)
model_inputs = {
"sample": sample,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"timestep_cond": timestep_cond,
**(cross_attention_kwargs or {}),
**(added_cond_kwargs or {}),
}
if self.use_io_binding:
known_output_shapes = {"out_sample": sample.shape}
known_output_buffers = None
if "LatentConsistencyModel" not in self.parent.__class__.__name__:
known_output_buffers = {"out_sample": sample}
output_shapes, output_buffers = self._prepare_io_binding(
model_inputs,
known_output_shapes=known_output_shapes,
known_output_buffers=known_output_buffers,
)
if self.device.type == "cpu":
self.session.run_with_iobinding(self._io_binding)
else:
self._io_binding.synchronize_inputs()
self.session.run_with_iobinding(self._io_binding)
self._io_binding.synchronize_outputs()
model_outputs = {name: output_buffers[name].view(output_shapes[name]) for name in self.output_names}
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
model_outputs["sample"] = model_outputs.pop("out_sample")
if not return_dict:
return tuple(model_outputs.values())
return ModelOutput(**model_outputs)
class ORTTransformer(ORTModelMixin):
def forward(
self,
hidden_states: Union[np.ndarray, torch.Tensor],
encoder_hidden_states: Union[np.ndarray, torch.Tensor],
pooled_projections: Union[np.ndarray, torch.Tensor],
timestep: Union[np.ndarray, torch.Tensor],
guidance: Optional[Union[np.ndarray, torch.Tensor]] = None,
txt_ids: Optional[Union[np.ndarray, torch.Tensor]] = None,
img_ids: Optional[Union[np.ndarray, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
use_torch = isinstance(hidden_states, torch.Tensor)
model_inputs = {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"timestep": timestep,
"guidance": guidance,
"txt_ids": txt_ids,
"img_ids": img_ids,
**(joint_attention_kwargs or {}),
}
if self.use_io_binding:
known_output_shapes = {"out_hidden_states": hidden_states.shape}
known_output_buffers = None
if "Flux" not in self.parent.__class__.__name__:
known_output_buffers = {"out_hidden_states": hidden_states}
output_shapes, output_buffers = self._prepare_io_binding(
model_inputs,
known_output_shapes=known_output_shapes,
known_output_buffers=known_output_buffers,
)
if self.device.type == "cpu":
self.session.run_with_iobinding(self._io_binding)
else:
self._io_binding.synchronize_inputs()
self.session.run_with_iobinding(self._io_binding)
self._io_binding.synchronize_outputs()
model_outputs = {name: output_buffers[name].view(output_shapes[name]) for name in self.output_names}
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
model_outputs["hidden_states"] = model_outputs.pop("out_hidden_states")
if not return_dict:
return tuple(model_outputs.values())
return ModelOutput(**model_outputs)
class ORTTextEncoder(ORTModelMixin):
def forward(
self,
input_ids: Union[np.ndarray, torch.Tensor],
attention_mask: Optional[Union[np.ndarray, torch.Tensor]] = None,
output_hidden_states: Optional[bool] = None,
return_dict: bool = True,
):
use_torch = isinstance(input_ids, torch.Tensor)
model_inputs = {
"input_ids": input_ids,
}
if self.use_io_binding:
output_shapes, output_buffers = self._prepare_io_binding(model_inputs)
if self.device.type == "cpu":
self.session.run_with_iobinding(self._io_binding)
else:
self._io_binding.synchronize_inputs()
self.session.run_with_iobinding(self._io_binding)
self._io_binding.synchronize_outputs()
model_outputs = {name: output_buffers[name].view(output_shapes[name]) for name in self.output_names}
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
        num_layers = self.num_hidden_layers if hasattr(self, "num_hidden_layers") else self.num_decoder_layers
        if output_hidden_states:
            model_outputs["hidden_states"] = []
            for i in range(num_layers):
                model_outputs["hidden_states"].append(model_outputs.pop(f"hidden_states.{i}"))
            model_outputs["hidden_states"].append(model_outputs.get("last_hidden_state"))
        else:
            for i in range(num_layers):
                model_outputs.pop(f"hidden_states.{i}", None)
if not return_dict:
return tuple(model_outputs.values())
return ModelOutput(**model_outputs)
class ORTVaeEncoder(ORTModelMixin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# can be missing from models exported long ago
if not hasattr(self.config, "scaling_factor"):
logger.warning(
"The `scaling_factor` attribute is missing from the VAE encoder configuration. "
"Please re-export the model with newer version of optimum and diffusers to avoid this warning."
)
self.register_to_config(scaling_factor=2 ** (len(self.config.block_out_channels) - 1))
def forward(
self,
sample: Union[np.ndarray, torch.Tensor],
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
):
use_torch = isinstance(sample, torch.Tensor)
model_inputs = {
"sample": sample,
}
if self.use_io_binding:
output_shapes, output_buffers = self._prepare_io_binding(model_inputs)
if self.device.type == "cpu":
self.session.run_with_iobinding(self._io_binding)
else:
self._io_binding.synchronize_inputs()
self.session.run_with_iobinding(self._io_binding)
self._io_binding.synchronize_outputs()
model_outputs = {name: output_buffers[name].view(output_shapes[name]) for name in self.output_names}
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
if "latent_sample" in model_outputs:
model_outputs["latents"] = model_outputs.pop("latent_sample")
if "latent_parameters" in model_outputs:
model_outputs["latent_dist"] = DiagonalGaussianDistribution(
parameters=model_outputs.pop("latent_parameters")
)
if not return_dict:
return tuple(model_outputs.values())
return ModelOutput(**model_outputs)
class ORTVaeDecoder(ORTModelMixin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# can be missing from models exported long ago
if not hasattr(self.config, "scaling_factor"):
logger.warning(
"The `scaling_factor` attribute is missing from the VAE decoder configuration. "
"Please re-export the model with newer version of optimum and diffusers to avoid this warning."
)
self.register_to_config(scaling_factor=2 ** (len(self.config.block_out_channels) - 1))
def forward(
self,
latent_sample: Union[np.ndarray, torch.Tensor],
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
):
use_torch = isinstance(latent_sample, torch.Tensor)
model_inputs = {
"latent_sample": latent_sample,
}
if self.use_io_binding:
output_shapes, output_buffers = self._prepare_io_binding(model_inputs)
if self.device.type == "cpu":
self.session.run_with_iobinding(self._io_binding)
else:
self._io_binding.synchronize_inputs()
self.session.run_with_iobinding(self._io_binding)
self._io_binding.synchronize_outputs()
model_outputs = {name: output_buffers[name].view(output_shapes[name]) for name in self.output_names}
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
if "latent_sample" in model_outputs:
model_outputs["latents"] = model_outputs.pop("latent_sample")
if not return_dict:
return tuple(model_outputs.values())
return ModelOutput(**model_outputs)
class ORTVae(ORTParentMixin):
def __init__(self, encoder: Optional[ORTVaeEncoder] = None, decoder: Optional[ORTVaeDecoder] = None):
self.encoder = encoder
self.decoder = decoder
self.initialize_ort_attributes(parts=list(filter(None, {self.encoder, self.decoder})))
def decode(self, *args, **kwargs):
return self.decoder(*args, **kwargs)
def encode(self, *args, **kwargs):
return self.encoder(*args, **kwargs)
@property
def config(self):
return self.decoder.config
ORT_PIPELINE_DOCSTRING = r"""
This pipeline inherits from [`ORTDiffusionPipeline`] and runs inference with ONNX Runtime.
It can be loaded from a pretrained pipeline using the [`ORTDiffusionPipeline.from_pretrained`] method.
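
Example:

A minimal loading sketch; the checkpoint identifier is a placeholder:

```python
>>> from optimum.onnxruntime import ORTDiffusionPipeline

>>> pipeline = ORTDiffusionPipeline.from_pretrained("my-namespace/my-onnx-checkpoint")
```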
"""
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusionPipeline(ORTDiffusionPipeline, StableDiffusionPipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline).
"""
task = "text-to-image"
main_input_name = "prompt"
auto_model_class = StableDiffusionPipeline
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusionImg2ImgPipeline(ORTDiffusionPipeline, StableDiffusionImg2ImgPipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img#diffusers.StableDiffusionImg2ImgPipeline).
"""
task = "image-to-image"
main_input_name = "image"
auto_model_class = StableDiffusionImg2ImgPipeline
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusionInpaintPipeline(ORTDiffusionPipeline, StableDiffusionInpaintPipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionInpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint#diffusers.StableDiffusionInpaintPipeline).
"""
task = "inpainting"
main_input_name = "prompt"
auto_model_class = StableDiffusionInpaintPipeline
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusionXLPipeline(ORTDiffusionPipeline, StableDiffusionXLPipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline).
"""
task = "text-to-image"
main_input_name = "prompt"
auto_model_class = StableDiffusionXLPipeline
def _get_add_time_ids(
self,
original_size,
crops_coords_top_left,
target_size,
dtype,
text_encoder_projection_dim=None,
):
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusionXLImg2ImgPipeline(ORTDiffusionPipeline, StableDiffusionXLImg2ImgPipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline).
"""
task = "image-to-image"
main_input_name = "prompt"
auto_model_class = StableDiffusionXLImg2ImgPipeline
def _get_add_time_ids(
self,
original_size,
crops_coords_top_left,
target_size,
aesthetic_score,
negative_aesthetic_score,
negative_original_size,
negative_crops_coords_top_left,
negative_target_size,
dtype,
text_encoder_projection_dim=None,
):
if self.config.requires_aesthetics_score:
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
add_neg_time_ids = list(
negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
)
else:
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
return add_time_ids, add_neg_time_ids
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusionXLInpaintPipeline(ORTDiffusionPipeline, StableDiffusionXLInpaintPipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLInpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLInpaintPipeline).
"""
main_input_name = "image"
task = "inpainting"
auto_model_class = StableDiffusionXLInpaintPipeline
def _get_add_time_ids(
self,
original_size,
crops_coords_top_left,
target_size,
aesthetic_score,
negative_aesthetic_score,
negative_original_size,
negative_crops_coords_top_left,
negative_target_size,
dtype,
text_encoder_projection_dim=None,
):
if self.config.requires_aesthetics_score:
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
add_neg_time_ids = list(
negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
)
else:
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
return add_time_ids, add_neg_time_ids
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTLatentConsistencyModelPipeline(ORTDiffusionPipeline, LatentConsistencyModelPipeline):
"""
    ONNX Runtime-powered latent consistency model pipeline corresponding to [diffusers.LatentConsistencyModelPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_consistency#diffusers.LatentConsistencyModelPipeline).
"""
task = "text-to-image"
main_input_name = "prompt"
auto_model_class = LatentConsistencyModelPipeline
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTLatentConsistencyModelImg2ImgPipeline(ORTDiffusionPipeline, LatentConsistencyModelImg2ImgPipeline):
"""
    ONNX Runtime-powered latent consistency model pipeline corresponding to [diffusers.LatentConsistencyModelImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_consistency_img2img#diffusers.LatentConsistencyModelImg2ImgPipeline).
"""
task = "image-to-image"
main_input_name = "image"
auto_model_class = LatentConsistencyModelImg2ImgPipeline
class ORTUnavailablePipeline:
MIN_VERSION = None
def __init__(self, *args, **kwargs):
raise NotImplementedError(
f"The pipeline {self.__class__.__name__} is not available in the current version of `diffusers`. "
f"Please upgrade `diffusers` to {self.MIN_VERSION} or later."
)
if is_diffusers_version(">=", "0.29.0"):
from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusion3Pipeline(ORTDiffusionPipeline, StableDiffusion3Pipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusion3Pipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusion3Pipeline).
"""
task = "text-to-image"
main_input_name = "prompt"
auto_model_class = StableDiffusion3Pipeline
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusion3Img2ImgPipeline(ORTDiffusionPipeline, StableDiffusion3Img2ImgPipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusion3Img2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img#diffusers.StableDiffusion3Img2ImgPipeline).
"""
task = "image-to-image"
main_input_name = "image"
auto_model_class = StableDiffusion3Img2ImgPipeline
else:
class ORTStableDiffusion3Pipeline(ORTUnavailablePipeline):
MIN_VERSION = "0.29.0"
class ORTStableDiffusion3Img2ImgPipeline(ORTUnavailablePipeline):
MIN_VERSION = "0.29.0"
if is_diffusers_version(">=", "0.30.0"):
from diffusers import FluxPipeline, StableDiffusion3InpaintPipeline
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTStableDiffusion3InpaintPipeline(ORTDiffusionPipeline, StableDiffusion3InpaintPipeline):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusion3InpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint#diffusers.StableDiffusion3InpaintPipeline).
"""
task = "inpainting"
main_input_name = "prompt"
auto_model_class = StableDiffusion3InpaintPipeline
@add_end_docstrings(ORT_PIPELINE_DOCSTRING)
class ORTFluxPipeline(ORTDiffusionPipeline, FluxPipeline):
"""
    ONNX Runtime-powered Flux pipeline corresponding to [diffusers.FluxPipeline](https://huggingface.co/docs/diffusers/api/pipelines/flux/text2img#diffusers.FluxPipeline).
"""
task = "text-to-image"
main_input_name = "prompt"
auto_model_class = FluxPipeline
else:
class ORTStableDiffusion3InpaintPipeline(ORTUnavailablePipeline):
MIN_VERSION = "0.30.0"
class ORTFluxPipeline(ORTUnavailablePipeline):
MIN_VERSION = "0.30.0"
SUPPORTED_ORT_PIPELINES = [
ORTStableDiffusionPipeline,
ORTStableDiffusionImg2ImgPipeline,
ORTStableDiffusionInpaintPipeline,
ORTStableDiffusionXLPipeline,
ORTStableDiffusionXLImg2ImgPipeline,
ORTStableDiffusionXLInpaintPipeline,
ORTLatentConsistencyModelPipeline,
ORTLatentConsistencyModelImg2ImgPipeline,
ORTStableDiffusion3Pipeline,
ORTStableDiffusion3Img2ImgPipeline,
ORTStableDiffusion3InpaintPipeline,
ORTFluxPipeline,
]
def _get_ort_class(pipeline_class_name: str, throw_error_if_not_exist: bool = True):
for ort_pipeline_class in SUPPORTED_ORT_PIPELINES:
if (
ort_pipeline_class.__name__ == pipeline_class_name
or ort_pipeline_class.auto_model_class.__name__ == pipeline_class_name
):
return ort_pipeline_class
if throw_error_if_not_exist:
raise ValueError(f"ORTDiffusionPipeline can't find a pipeline linked to {pipeline_class_name}")
ORT_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
[
("flux", ORTFluxPipeline),
("latent-consistency", ORTLatentConsistencyModelPipeline),
("stable-diffusion", ORTStableDiffusionPipeline),
("stable-diffusion-3", ORTStableDiffusion3Pipeline),
("stable-diffusion-xl", ORTStableDiffusionXLPipeline),
]
)
ORT_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
[
("latent-consistency", ORTLatentConsistencyModelImg2ImgPipeline),
("stable-diffusion", ORTStableDiffusionImg2ImgPipeline),
("stable-diffusion-3", ORTStableDiffusion3Img2ImgPipeline),
("stable-diffusion-xl", ORTStableDiffusionXLImg2ImgPipeline),
]
)
ORT_INPAINT_PIPELINES_MAPPING = OrderedDict(
[
("stable-diffusion", ORTStableDiffusionInpaintPipeline),
("stable-diffusion-3", ORTStableDiffusion3InpaintPipeline),
("stable-diffusion-xl", ORTStableDiffusionXLInpaintPipeline),
]
)
SUPPORTED_ORT_PIPELINES_MAPPINGS = [
ORT_TEXT2IMAGE_PIPELINES_MAPPING,
ORT_IMAGE2IMAGE_PIPELINES_MAPPING,
ORT_INPAINT_PIPELINES_MAPPING,
]
def _get_task_ort_class(mapping, pipeline_class_name):
def _get_model_name(pipeline_class_name):
for ort_pipelines_mapping in SUPPORTED_ORT_PIPELINES_MAPPINGS:
for model_name, ort_pipeline_class in ort_pipelines_mapping.items():
if (
ort_pipeline_class.__name__ == pipeline_class_name
or ort_pipeline_class.auto_model_class.__name__ == pipeline_class_name
):
return model_name
model_name = _get_model_name(pipeline_class_name)
if model_name is not None:
task_class = mapping.get(model_name, None)
if task_class is not None:
return task_class
raise ValueError(f"ORTPipelineForTask can't find a pipeline linked to {pipeline_class_name} for {model_name}")
class ORTPipelineForTask(ConfigMixin):
config_name = "model_index.json"
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_or_path, **kwargs) -> ORTDiffusionPipeline:
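        """
        Resolves the concrete ORT pipeline class from the checkpoint's `model_index.json`
        (via its `_class_name` entry) and delegates loading to it.

        Example:

        A minimal sketch; the checkpoint identifier is a placeholder:

        ```python
        >>> from optimum.onnxruntime import ORTPipelineForText2Image

        >>> pipeline = ORTPipelineForText2Image.from_pretrained("my-namespace/my-onnx-checkpoint")
        >>> image = pipeline(prompt="a photo of an astronaut riding a horse").images[0]
        ```
        """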
load_config_kwargs = {
"force_download": kwargs.get("force_download", False),
"resume_download": kwargs.get("resume_download", None),
"local_files_only": kwargs.get("local_files_only", False),
"cache_dir": kwargs.get("cache_dir", None),
"revision": kwargs.get("revision", None),
"proxies": kwargs.get("proxies", None),
"token": kwargs.get("token", None),
}
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
config = config[0] if isinstance(config, tuple) else config
class_name = config["_class_name"]
ort_pipeline_class = _get_task_ort_class(cls.ort_pipelines_mapping, class_name)
return ort_pipeline_class.from_pretrained(pretrained_model_or_path, **kwargs)
class ORTPipelineForText2Image(ORTPipelineForTask):
auto_model_class = AutoPipelineForText2Image
ort_pipelines_mapping = ORT_TEXT2IMAGE_PIPELINES_MAPPING
class ORTPipelineForImage2Image(ORTPipelineForTask):
auto_model_class = AutoPipelineForImage2Image
ort_pipelines_mapping = ORT_IMAGE2IMAGE_PIPELINES_MAPPING
class ORTPipelineForInpainting(ORTPipelineForTask):
auto_model_class = AutoPipelineForInpainting
ort_pipelines_mapping = ORT_INPAINT_PIPELINES_MAPPING