optimum/exporters/onnx/model_configs.py

# coding=utf-8 # Copyright 2022 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Model specific ONNX configurations.""" import math import warnings from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union from packaging import version from transformers.utils import is_tf_available from ...utils import ( DEFAULT_DUMMY_SHAPES, ASTDummyAudioInputGenerator, BartDummyTextInputGenerator, BloomDummyPastKeyValuesGenerator, Dinov2DummyInputGenerator, DummyCodegenDecoderTextInputGenerator, DummyDecisionTransformerInputGenerator, DummyDecoderTextInputGenerator, DummyEncodecInputGenerator, DummyFluxTransformerTextInputGenerator, DummyFluxTransformerVisionInputGenerator, DummyInputGenerator, DummyIntGenerator, DummyPastKeyValuesGenerator, DummyPatchTSTInputGenerator, DummyPix2StructInputGenerator, DummyPointsGenerator, DummySeq2SeqDecoderTextInputGenerator, DummySeq2SeqPastKeyValuesGenerator, DummySpeechT5InputGenerator, DummyTextInputGenerator, DummyTimestepInputGenerator, DummyTransformerTextInputGenerator, DummyTransformerTimestepInputGenerator, DummyTransformerVisionInputGenerator, DummyVisionEmbeddingsGenerator, DummyVisionEncoderDecoderPastKeyValuesGenerator, DummyVisionInputGenerator, DummyXPathSeqInputGenerator, FalconDummyPastKeyValuesGenerator, GemmaDummyPastKeyValuesGenerator, GPTBigCodeDummyPastKeyValuesGenerator, LongformerDummyTextInputGenerator, MCTCTDummyAudioInputGenerator, MistralDummyPastKeyValuesGenerator, NormalizedConfig, NormalizedEncoderDecoderConfig, NormalizedSeq2SeqConfig, NormalizedTextAndVisionConfig, NormalizedTextConfig, NormalizedTextConfigWithGQA, NormalizedTimeSeriesForecastingConfig, NormalizedVisionConfig, PerceiverDummyInputGenerator, Speech2TextDummyAudioInputGenerator, T5DummySeq2SeqPastKeyValuesGenerator, VitPoseDummyInputGenerator, is_diffusers_available, is_diffusers_version, is_transformers_version, logging, ) from ...utils.normalized_config import NormalizedConfigManager from ..tasks import TasksManager from .base import ConfigBehavior, OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast from .config import ( AudioOnnxConfig, AudioToTextOnnxConfig, EncoderDecoderBaseOnnxConfig, TextAndVisionOnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig, TextEncoderOnnxConfig, TextSeq2SeqOnnxConfig, VisionOnnxConfig, ) from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME from .model_patcher import ( CLIPModelPatcher, FalconModelPatcher, MgpstrModelPatcher, MistralModelPatcher, MusicgenModelPatcher, SAMModelPatcher, SentenceTransformersCLIPPatcher, SentenceTransformersTransformerPatcher, SpeechT5ModelPatcher, VisionEncoderDecoderPatcher, VitPoseModelPatcher, WavLMModelPatcher, ) # TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization if TYPE_CHECKING: from transformers import PretrainedConfig from transformers.modeling_utils import PreTrainedModel if 
is_tf_available(): from transformers.modeling_tf_utils import TFPreTrainedModel if is_diffusers_available(): from diffusers import ModelMixin logger = logging.get_logger(__name__) COMMON_TEXT_TASKS = [ "feature-extraction", "fill-mask", "multiple-choice", "question-answering", "text-classification", "token-classification", ] COMMON_TEXT_GENERATION_TASKS = [ "feature-extraction", "feature-extraction-with-past", "text-generation", "text-generation-with-past", ] COMMON_TEXT2TEXT_GENERATION_TASKS = COMMON_TEXT_GENERATION_TASKS + [ "text2text-generation", "text2text-generation-with-past", ] register_tasks_manager_onnx = TasksManager.create_register("onnx") @register_tasks_manager_onnx("bert", *COMMON_TEXT_TASKS) class BertOnnxConfig(TextEncoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig ATOL_FOR_VALIDATION = 1e-4 DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. @property def inputs(self) -> Dict[str, Dict[int, str]]: if self.task == "multiple-choice": dynamic_axis = {0: "batch_size", 1: "num_choices", 2: "sequence_length"} else: dynamic_axis = {0: "batch_size", 1: "sequence_length"} return { "input_ids": dynamic_axis, "attention_mask": dynamic_axis, "token_type_ids": dynamic_axis, } @register_tasks_manager_onnx("albert", *COMMON_TEXT_TASKS) class AlbertOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. @register_tasks_manager_onnx("convbert", *COMMON_TEXT_TASKS) class ConvBertOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("electra", *COMMON_TEXT_TASKS) class ElectraOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("roformer", *COMMON_TEXT_TASKS) class RoFormerOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("squeezebert", *COMMON_TEXT_TASKS) class SqueezeBertOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("mobilebert", *COMMON_TEXT_TASKS) class MobileBertOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("nystromformer", *COMMON_TEXT_TASKS) class NystromformerOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("xlm", *COMMON_TEXT_TASKS) class XLMOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("splinter", *["feature-extraction", "question-answering"]) class SplinterOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("rembert", *COMMON_TEXT_TASKS) class RemBertOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("longformer", *COMMON_TEXT_TASKS) class LongformerOnnxConfig(BertOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = (LongformerDummyTextInputGenerator,) DEFAULT_ONNX_OPSET = 14 @property def inputs(self) -> Dict[str, Dict[int, str]]: inputs = super().inputs inputs["global_attention_mask"] = inputs["attention_mask"] return inputs @register_tasks_manager_onnx("megatron-bert", *COMMON_TEXT_TASKS) class MegatronBertOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("distilbert", *COMMON_TEXT_TASKS) class DistilBertOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for transformers>=4.46.0 @property def inputs(self) -> Dict[str, Dict[int, str]]: if self.task == "multiple-choice": dynamic_axis = {0: "batch_size", 1: "num_choices", 2: "sequence_length"} else: dynamic_axis = {0: "batch_size", 1: "sequence_length"} 
return {"input_ids": dynamic_axis, "attention_mask": dynamic_axis} @register_tasks_manager_onnx( "modernbert", *[ "feature-extraction", "fill-mask", "text-classification", "token-classification", ], ) class ModernBertOnnxConfig(DistilBertOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.48.0") @register_tasks_manager_onnx("mpnet", *COMMON_TEXT_TASKS) class MPNetOnnxConfig(DistilBertOnnxConfig): DEFAULT_ONNX_OPSET = 12 # For lower opsets, results in: Type 'tensor(int64)' of input parameter (/0/auto_model/encoder/Add_1_output_0) of operator (Min) in node (/0/auto_model/encoder/Min) is invalid. @register_tasks_manager_onnx("roberta", *COMMON_TEXT_TASKS) class RobertaOnnxConfig(DistilBertOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. @register_tasks_manager_onnx("camembert", *COMMON_TEXT_TASKS) class CamembertOnnxConfig(DistilBertOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. @register_tasks_manager_onnx("flaubert", *COMMON_TEXT_TASKS) class FlaubertOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("ibert", *COMMON_TEXT_TASKS) class IBertOnnxConfig(DistilBertOnnxConfig): pass @register_tasks_manager_onnx("xlm-roberta", *COMMON_TEXT_TASKS) class XLMRobertaOnnxConfig(DistilBertOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. @register_tasks_manager_onnx( "deberta", *["feature-extraction", "fill-mask", "text-classification", "token-classification", "question-answering"], ) class DebertaOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 12 @property def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs = super().inputs if self._config.type_vocab_size == 0: common_inputs.pop("token_type_ids") return common_inputs @register_tasks_manager_onnx( "markuplm", *["feature-extraction", "text-classification", "token-classification", "question-answering"] ) class MarkupLMOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 11 DUMMY_INPUT_GENERATOR_CLASSES = ( DummyTextInputGenerator, DummyXPathSeqInputGenerator, ) @property def inputs(self) -> Dict[str, Dict[int, str]]: dynamic_axis = {0: "batch_size", 1: "sequence_length"} xpath_dynamic_axis = {0: "batch_size", 1: "sequence_length", 2: "max_depth"} return { "input_ids": dynamic_axis, "attention_mask": dynamic_axis, "token_type_ids": dynamic_axis, "xpath_subs_seq": xpath_dynamic_axis, "xpath_tags_seq": xpath_dynamic_axis, } @register_tasks_manager_onnx("deberta-v2", *COMMON_TEXT_TASKS) class DebertaV2OnnxConfig(DebertaOnnxConfig): pass @register_tasks_manager_onnx( "esm", *["feature-extraction", "fill-mask", "text-classification", "token-classification"] ) class EsmOnnxConfig(TextEncoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig ATOL_FOR_VALIDATION = 1e-4 DEFAULT_ONNX_OPSET = 12 @property def inputs(self) -> Dict[str, Dict[int, str]]: dynamic_axis = {0: "batch_size", 1: "sequence_length"} return { "input_ids": dynamic_axis, "attention_mask": dynamic_axis, } @register_tasks_manager_onnx("gpt2", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "token-classification"]) class GPT2OnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14. 
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")


@register_tasks_manager_onnx("gptj", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "question-answering"])
class GPTJOnnxConfig(GPT2OnnxConfig):
    pass


@register_tasks_manager_onnx("codegen", *COMMON_TEXT_GENERATION_TASKS)
class CodeGenOnnxConfig(GPT2OnnxConfig):
    pass


@register_tasks_manager_onnx("imagegpt", *["feature-extraction", "image-classification"])
class ImageGPTOnnxConfig(GPT2OnnxConfig):
    pass


@register_tasks_manager_onnx("decision_transformer", *["feature-extraction", "reinforcement-learning"])
class DecisionTransformerOnnxConfig(OnnxConfig):
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyDecisionTransformerInputGenerator,)
    NORMALIZED_CONFIG_CLASS = NormalizedConfig

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "states": {0: "batch_size", 1: "sequence_length"},
            "actions": {0: "batch_size", 1: "sequence_length"},
            "timesteps": {0: "batch_size", 1: "sequence_length"},
            "returns_to_go": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
        }

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "state_preds": {0: "batch_size", 1: "sequence_length"},
            "action_preds": {0: "batch_size", 1: "sequence_length"},
            "return_preds": {0: "batch_size", 1: "sequence_length"},
            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
        }


@register_tasks_manager_onnx("gpt_neo", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"])
class GPTNeoOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
    DEFAULT_ONNX_OPSET = 14
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_heads")


@register_tasks_manager_onnx("gpt_neox", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"])
class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
    DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


# OPT does not take position_ids as input for transformers < v4.46, needs it for transformers >= v4.46
if is_transformers_version(">=", "4.46.0"):

    @register_tasks_manager_onnx("opt", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "question-answering"])
    class OPTOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
        DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
        NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

else:

    @register_tasks_manager_onnx("opt", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "question-answering"])
    class OPTOnnxConfig(TextDecoderOnnxConfig):
        DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
        NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


@register_tasks_manager_onnx("llama", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"])
class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
    DEFAULT_ONNX_OPSET = 14  # Llama now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig @register_tasks_manager_onnx("olmo", *COMMON_TEXT_GENERATION_TASKS) class OlmoOnnxConfig(LlamaOnnxConfig): ATOL_FOR_VALIDATION = 1e-4 MIN_TRANSFORMERS_VERSION = version.parse("4.40.0") @register_tasks_manager_onnx("olmo2", *COMMON_TEXT_GENERATION_TASKS) class Olmo2OnnxConfig(OlmoOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.47.0") @register_tasks_manager_onnx("qwen2", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "token-classification"]) class Qwen2OnnxConfig(LlamaOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.37.0") @register_tasks_manager_onnx("qwen3", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"]) class Qwen3OnnxConfig(LlamaOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.51.0") @register_tasks_manager_onnx( "qwen3_moe", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "token-classification"] ) class Qwen3MoeOnnxConfig(LlamaOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.51.0") @register_tasks_manager_onnx("gemma", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"]) class GemmaOnnxConfig(LlamaOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator) DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator MIN_TRANSFORMERS_VERSION = version.parse("4.38.0") @register_tasks_manager_onnx("granite", *COMMON_TEXT_GENERATION_TASKS) class GraniteOnnxConfig(LlamaOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.45.0") MIN_TORCH_VERSION = version.parse("2.5.0") @register_tasks_manager_onnx("phi", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"]) class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 14 # Phi now uses F.scaled_dot_product_attention by default for torch>=2.1.1. 
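    # KV-cache sketch (illustrative): when a "-with-past" task is exported, `use_past_in_inputs=True`
    # adds per-layer `past_key_values.{i}.key/value` inputs (and `present.{i}.*` outputs) on top of
    # `input_ids`/`attention_mask`. The model id below is only an example.
    # >>> from transformers import AutoConfig
    # >>> onnx_config = PhiOnnxConfig(
    # ...     AutoConfig.from_pretrained("microsoft/phi-2"), task="text-generation", use_past=True, use_past_in_inputs=True
    # ... )
    # >>> any(name.startswith("past_key_values") for name in onnx_config.inputs)
    # True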
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig MIN_TRANSFORMERS_VERSION = version.parse("4.42.0") @register_tasks_manager_onnx("phi3", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"]) class Phi3OnnxConfig(PhiOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = ( MistralDummyPastKeyValuesGenerator, ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA MIN_TRANSFORMERS_VERSION = version.parse("4.50.0") @register_tasks_manager_onnx("internlm2", *["text-generation", "text-generation-with-past"]) class InternLM2OnnxConfig(LlamaOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.41.0") @register_tasks_manager_onnx("mistral", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"]) class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35 MIN_TRANSFORMERS_VERSION = version.parse("4.34.99") # The ONNX export of this architecture needs the Trilu operator support, available since opset 14 DEFAULT_ONNX_OPSET = 14 DUMMY_INPUT_GENERATOR_CLASSES = ( MistralDummyPastKeyValuesGenerator, ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True) _MODEL_PATCHER = MistralModelPatcher @register_tasks_manager_onnx("mpt", *["text-generation", "text-generation-with-past", "text-classification"]) class MPTOnnxConfig(TextDecoderOnnxConfig): # MPT does not require position_ids input. DEFAULT_ONNX_OPSET = 13 # TODO: fix inference for transformers < v4.41 for beam_search > 1 MIN_TRANSFORMERS_VERSION = version.parse("4.41.0") NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers" ) @register_tasks_manager_onnx("bloom", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "token-classification"]) class BloomOnnxConfig(TextDecoderOnnxConfig): # Bloom does not require position_ids input. DUMMY_INPUT_GENERATOR_CLASSES = ( BloomDummyPastKeyValuesGenerator, ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES DEFAULT_ONNX_OPSET = 14 # Bloom uses F.scaled_dot_product_attention MIN_TRANSFORMERS_VERSION = version.parse("4.44.0") DUMMY_PKV_GENERATOR_CLASS = BloomDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head") @register_tasks_manager_onnx( "gpt_bigcode", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "token-classification"] ) class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = ( GPTBigCodeDummyPastKeyValuesGenerator, ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES DEFAULT_ONNX_OPSET = 14 # GPT BigCode now uses F.scaled_dot_product_attention by default for torch>=2.1.1. 
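    # Cache-layout note (sketch): GPT BigCode uses multi-query attention, so each layer's key and
    # value are exported as a single fused `key_value` tensor (see `add_past_key_values` below),
    # i.e. the cache inputs are named `past_key_values.{i}.key_value` rather than separate `.key`/`.value`.
    # >>> cfg = GPTBigCodeOnnxConfig(config, task="text-generation", use_past=True, use_past_in_inputs=True)  # `config` is a gpt_bigcode PretrainedConfig
    # >>> "past_key_values.0.key_value" in cfg.inputs
    # True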
DUMMY_PKV_GENERATOR_CLASS = GPTBigCodeDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("gpt_bigcode") def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): if direction not in ["inputs", "outputs"]: raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') if direction == "inputs": decoder_sequence_name = "past_sequence_length" name = "past_key_values" else: decoder_sequence_name = "past_sequence_length + 1" name = "present" for i in range(self._normalized_config.num_layers): # No dim for `n_head` when using multi-query attention inputs_or_outputs[f"{name}.{i}.key_value"] = { 0: "batch_size", 1: decoder_sequence_name, } def flatten_past_key_values(self, flattened_output, name, idx, t): flattened_output[f"{name}.{idx}.key_value"] = t @register_tasks_manager_onnx("falcon", *COMMON_TEXT_GENERATION_TASKS + ["question-answering", "token-classification"]) class FalconOnnxConfig(TextDecoderOnnxConfig): # This is due to the cache refactoring for Falcon in 4.36 MIN_TRANSFORMERS_VERSION = version.parse("4.35.99") DUMMY_INPUT_GENERATOR_CLASSES = ( FalconDummyPastKeyValuesGenerator, ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES DEFAULT_ONNX_OPSET = 14 # Falcon uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention NORMALIZED_CONFIG_CLASS = NormalizedTextConfig DUMMY_PKV_GENERATOR_CLASS = FalconDummyPastKeyValuesGenerator # we need to set output_attentions=True in the model input to avoid calling # torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export _MODEL_PATCHER = FalconModelPatcher def __init__( self, config: "PretrainedConfig", task: str = "feature-extraction", int_dtype: str = "int64", float_dtype: str = "fp32", use_past: bool = False, use_past_in_inputs: bool = False, preprocessors: Optional[List[Any]] = None, legacy: bool = False, ): super().__init__( config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, use_past=use_past, use_past_in_inputs=use_past_in_inputs, preprocessors=preprocessors, legacy=legacy, ) # For some reason Falcon config.num_kv_heads can not be trusted, see in Transformers: # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L337 self._normalized_config.num_kv_heads = ( self._normalized_config.num_kv_heads if (self._normalized_config.new_decoder_architecture or not self._normalized_config.multi_query) else 1 ) @property def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs = super().inputs if not self.legacy and not self._config.alibi and self.task in ["text-generation", "feature-extraction"]: # When alibi is used, position_ids are not used in Falcon. 
# Reference: https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L1116 common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"} return common_inputs @register_tasks_manager_onnx( "t5", *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"], ) class T5OnnxConfig(TextSeq2SeqOnnxConfig): DEFAULT_ONNX_OPSET = 14 # T5 uses aten::triu that requires opset>=14 DUMMY_INPUT_GENERATOR_CLASSES = TextSeq2SeqOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[:-1] + ( T5DummySeq2SeqPastKeyValuesGenerator, ) DUMMY_PKV_GENERATOR_CLASS = T5DummySeq2SeqPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( hidden_size="d_model", num_attention_heads="num_heads", encoder_num_layers="num_layers", decoder_num_layers="num_decoder_layers", key_value_dim="d_kv", allow_new=True, ) def generate_dummy_inputs_for_validation( self, reference_model_inputs: Dict[str, Any], onnx_input_names: Optional[List[str]] = None ) -> Dict[str, Any]: if self._behavior is ConfigBehavior.DECODER: reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids") if onnx_input_names is not None: if "encoder_outputs" in reference_model_inputs: if "encoder_hidden_states" in onnx_input_names: reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0] else: reference_model_inputs.pop("encoder_outputs") else: # TODO: remove this else in optimum 2.0 and make onnx_input_names a required argument # T5 requires encoder_hidden_states as an input for both the without/with past models, # which is different than other architectures that require it only for the without past case reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0] return super().generate_dummy_inputs_for_validation(reference_model_inputs) @register_tasks_manager_onnx( "mt5", *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"], ) class MT5OnnxConfig(T5OnnxConfig): ATOL_FOR_VALIDATION = 1e-4 @register_tasks_manager_onnx( "longt5", *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"], ) class LongT5OnnxConfig(T5OnnxConfig): DEFAULT_ONNX_OPSET = 14 @register_tasks_manager_onnx( "m2m_100", *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"], ) class M2M100OnnxConfig(TextSeq2SeqOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( encoder_num_layers="encoder_layers", decoder_num_layers="decoder_layers", num_layers="decoder_layers", # Used for the text-generation task past key values input generation. 
encoder_num_attention_heads="encoder_attention_heads", decoder_num_attention_heads="decoder_attention_heads", eos_token_id="eos_token_id", ) DUMMY_INPUT_GENERATOR_CLASSES = ( BartDummyTextInputGenerator, { "feature-extraction": DummySeq2SeqDecoderTextInputGenerator, "text-generation": DummyDecoderTextInputGenerator, }, { "feature-extraction": DummySeq2SeqPastKeyValuesGenerator, "text-generation": DummyPastKeyValuesGenerator, }, ) def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]: dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0]( self.task, self._normalized_config, **kwargs ) task = "feature-extraction" if self.task != "text-generation" else "text-generation" dummy_decoder_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1][task]( self.task, self._normalized_config, **kwargs ) if self.task != "text-generation": kwargs["encoder_sequence_length"] = dummy_text_input_generator.sequence_length dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2][task]( self.task, self._normalized_config, **kwargs ) dummy_inputs_generators = [ dummy_text_input_generator, dummy_decoder_text_input_generator, dummy_seq2seq_past_key_values_generator, ] return dummy_inputs_generators @property def inputs_for_default_and_seq2seq_lm(self): return super().inputs @property def inputs_for_causal_lm(self): if self.use_past_in_inputs: common_inputs = { "input_ids": {0: "batch_size", 1: "sequence_length"}, "attention_mask": {0: "batch_size", 1: "past_sequence_length + 1"}, } for i in range(self._normalized_config.decoder_num_layers): common_inputs[f"past_key_values.{i}.key"] = { 0: "batch_size", 2: "past_sequence_length", } common_inputs[f"past_key_values.{i}.value"] = { 0: "batch_size", 2: "past_sequence_length", } else: common_inputs = { "input_ids": {0: "batch_size", 1: "sequence_length"}, "attention_mask": {0: "batch_size", 1: "sequence_length"}, } return common_inputs @property def inputs_for_other_tasks(self): return { "input_ids": {0: "batch_size", 1: "sequence_length"}, "attention_mask": {0: "batch_size", 1: "sequence_length"}, } @property def inputs(self) -> Dict[str, Dict[int, str]]: inputs_properties = { "feature-extraction": self.inputs_for_default_and_seq2seq_lm, "text2text-generation": self.inputs_for_default_and_seq2seq_lm, "text-generation": self.inputs_for_causal_lm, "other": self.inputs_for_other_tasks, } return inputs_properties.get(self.task, inputs_properties["other"]) @property def outputs(self) -> Dict[str, Dict[int, str]]: if self.task in ["feature-extraction", "text2text-generation"]: common_outputs = super().outputs else: common_outputs = super(OnnxConfigWithPast, self).outputs if self.use_past: # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output. for i in range( self._normalized_config.encoder_num_layers if self.task != "text-generation" else self._normalized_config.decoder_num_layers ): common_outputs[f"present.{i}.key"] = {0: "batch_size", 2: "past_sequence_length + sequence_length"} common_outputs[f"present.{i}.value"] = { 0: "batch_size", 2: "past_sequence_length + sequence_length", } return common_outputs def generate_dummy_inputs(self, framework: str = "pt", **kwargs): # This will handle the attention mask padding when Bart is used for text-generation. 
if self.task == "text-generation": self.PAD_ATTENTION_MASK_TO_PAST = True dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs) # Setting it back to the default version. self.PAD_ATTENTION_MASK_TO_PAST = False return dummy_inputs def flatten_past_key_values(self, flattened_output, name, idx, t): if self.task in ["feature-extraction", "text2text-generation"]: flattened_output = super().flatten_past_key_values(flattened_output, name, idx, t) else: flattened_output = super(OnnxSeq2SeqConfigWithPast, self).flatten_past_key_values( flattened_output, name, idx, t ) @register_tasks_manager_onnx( "bart", *COMMON_TEXT2TEXT_GENERATION_TASKS + ["text-classification", "question-answering"] ) class BartOnnxConfig(M2M100OnnxConfig): DEFAULT_ONNX_OPSET = 14 # Bart now uses F.scaled_dot_product_attention by default for torch>=2.1.1. MIN_TORCH_VERSION = version.parse("2.1.2") @register_tasks_manager_onnx( "mbart", *COMMON_TEXT2TEXT_GENERATION_TASKS + ["text-classification", "question-answering"] ) class MBartOnnxConfig(BartOnnxConfig): pass @register_tasks_manager_onnx("blenderbot", *COMMON_TEXT2TEXT_GENERATION_TASKS) class BlenderbotOnnxConfig(BartOnnxConfig): pass @register_tasks_manager_onnx("blenderbot-small", *COMMON_TEXT2TEXT_GENERATION_TASKS) class BlenderbotSmallOnnxConfig(BartOnnxConfig): pass @register_tasks_manager_onnx("big_bird", *COMMON_TEXT_TASKS) class BigBirdOnnxConfig(DistilBertOnnxConfig): pass @register_tasks_manager_onnx( "bigbird_pegasus", *COMMON_TEXT2TEXT_GENERATION_TASKS + ["text-classification", "question-answering"] ) class BigBirdPegasusOnnxConfig(BartOnnxConfig): @property def inputs(self) -> Dict[str, Dict[int, str]]: inputs = super().inputs if self._config.attention_type == "block_sparse": # BigBirdPegasusEncoder creates its own attention_mask internally # https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py#L1875 inputs.pop("attention_mask", None) return inputs @register_tasks_manager_onnx("pegasus", *COMMON_TEXT2TEXT_GENERATION_TASKS) class PegasusOnnxConfig(BartOnnxConfig): pass @register_tasks_manager_onnx("marian", *COMMON_TEXT2TEXT_GENERATION_TASKS) class MarianOnnxConfig(BartOnnxConfig): pass @register_tasks_manager_onnx("vit", *["feature-extraction", "image-classification", "masked-im"]) class ViTOnnxConfig(VisionOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig MIN_TORCH_VERSION = version.parse("1.11") DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. 
@property def inputs(self) -> Dict[str, Dict[int, str]]: return {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}} @property def outputs(self) -> Dict[str, Dict[int, str]]: common_outputs = super().outputs if self.task == "feature-extraction": common_outputs["last_hidden_state"] = {0: "batch_size"} return common_outputs @register_tasks_manager_onnx("vitpose", *["keypoint-detection"]) class VitPoseOnnxConfig(ViTOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = (VitPoseDummyInputGenerator,) ATOL_FOR_VALIDATION = 1e-4 _MODEL_PATCHER = VitPoseModelPatcher @property def inputs(self) -> Dict[str, Dict[int, str]]: return {"pixel_values": {0: "batch_size"}} @register_tasks_manager_onnx("cvt", *["feature-extraction", "image-classification"]) class CvTOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 13 ATOL_FOR_VALIDATION = 1e-2 @register_tasks_manager_onnx("levit", *["feature-extraction", "image-classification"]) class LevitOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("deit", *["feature-extraction", "image-classification", "masked-im"]) class DeiTOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. @register_tasks_manager_onnx("beit", *["feature-extraction", "image-classification"]) class BeitOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. @register_tasks_manager_onnx("convnext", *["feature-extraction", "image-classification"]) class ConvNextOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("convnextv2", *["feature-extraction", "image-classification"]) class ConvNextV2OnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("hiera", *["feature-extraction", "image-classification"]) class HieraOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("pvt", *["feature-extraction", "image-classification"]) class PvtOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("vit_mae", *["feature-extraction"]) class VitMAEOnnxConfig(ViTOnnxConfig): # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 11 is not supported. # Support for this operator was added in version 14, try exporting with this version. DEFAULT_ONNX_OPSET = 14 @register_tasks_manager_onnx("vit_msn", *["feature-extraction", "image-classification"]) class VitMSNOnnxConfig(ViTOnnxConfig): # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 11 is not supported. # Support for this operator was added in version 14, try exporting with this version. 
DEFAULT_ONNX_OPSET = 14 @register_tasks_manager_onnx("dinov2", *["feature-extraction", "image-classification"]) class Dinov2OnnxConfig(ViTOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = (Dinov2DummyInputGenerator,) @register_tasks_manager_onnx("mobilevit", *["feature-extraction", "image-classification", "image-segmentation"]) class MobileViTOnnxConfig(ViTOnnxConfig): ATOL_FOR_VALIDATION = 1e-4 DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("regnet", *["feature-extraction", "image-classification"]) class RegNetOnnxConfig(ViTOnnxConfig): # This config has the same inputs as ViTOnnxConfig DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("resnet", *["feature-extraction", "image-classification"]) class ResNetOnnxConfig(ViTOnnxConfig): ATOL_FOR_VALIDATION = 1e-3 DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("detr", *["feature-extraction", "object-detection", "image-segmentation"]) class DetrOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 12 @property def outputs(self) -> Dict[str, Dict[int, str]]: if self.task == "image-segmentation": return { "logits": {0: "batch_size", 1: "num_queries"}, "pred_masks": {0: "batch_size", 1: "num_queries"}, } else: return super().outputs @register_tasks_manager_onnx("table-transformer", *["feature-extraction", "object-detection"]) class TableTransformerOnnxConfig(DetrOnnxConfig): pass @register_tasks_manager_onnx("yolos", *["feature-extraction", "object-detection"]) class YolosOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. @register_tasks_manager_onnx("swin", *["feature-extraction", "image-classification", "masked-im"]) class SwinOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("swinv2", *["feature-extraction", "image-classification", "masked-im"]) class SwinV2OnnxConfig(SwinOnnxConfig): pass @register_tasks_manager_onnx("swin2sr", *["feature-extraction", "image-to-image"]) class Swin2srOnnxConfig(SwinOnnxConfig): pass @register_tasks_manager_onnx( "dpt", *["feature-extraction", "depth-estimation", "image-segmentation", "semantic-segmentation"] ) class DptOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 14 @register_tasks_manager_onnx("glpn", *["feature-extraction", "depth-estimation"]) class GlpnOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("poolformer", *["feature-extraction", "image-classification"]) class PoolFormerOnnxConfig(ViTOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig ATOL_FOR_VALIDATION = 2e-3 DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx( "segformer", *["feature-extraction", "image-classification", "image-segmentation", "semantic-segmentation"] ) class SegformerOnnxConfig(YolosOnnxConfig): @property def outputs(self) -> Dict[str, Dict[int, str]]: outputs = super().outputs if self.task == "image-segmentation": outputs["logits"] = {0: "batch_size"} return outputs @register_tasks_manager_onnx("mobilenet_v1", *["feature-extraction", "image-classification"]) class MobileNetV1OnnxConfig(ViTOnnxConfig): ATOL_FOR_VALIDATION = 1e-4 DEFAULT_ONNX_OPSET = 11 @property def inputs(self) -> Dict[str, Dict[int, str]]: return {"pixel_values": {0: "batch_size"}} @register_tasks_manager_onnx("mobilenet_v2", *["feature-extraction", "image-classification"]) class MobileNetV2OnnxConfig(MobileNetV1OnnxConfig): pass @register_tasks_manager_onnx("maskformer", *["feature-extraction", "image-segmentation"]) class MaskFormerOnnxConfig(ViTOnnxConfig): # torch.onnx.errors.UnsupportedOperatorError: Exporting 
    # the operator 'aten::einsum' to ONNX opset version 11 is not supported.
    # Support for this operator was added in version 12, try exporting with this version.
    DEFAULT_ONNX_OPSET = 12

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        if self.task == "image-segmentation":
            return {
                "class_queries_logits": {0: "batch_size", 1: "num_queries"},
                "masks_queries_logits": {0: "batch_size", 1: "num_queries", 2: "height", 3: "width"},
            }
        else:
            return super().outputs

    @property
    def torch_to_onnx_output_map(self) -> Dict[str, str]:
        return {
            "transformer_decoder_last_hidden_state": "last_hidden_state",
        }


@register_tasks_manager_onnx("donut-swin", *["feature-extraction"])
class DonutSwinOnnxConfig(ViTOnnxConfig):
    DEFAULT_ONNX_OPSET = 11


@register_tasks_manager_onnx("default-timm-config", *["image-classification"], library_name="timm")
class TimmDefaultOnnxConfig(ViTOnnxConfig):
    ATOL_FOR_VALIDATION = 1e-3
    DEFAULT_ONNX_OPSET = 12

    def rename_ambiguous_inputs(self, inputs):
        # The input name in the model signature is `x`, hence the export input name is updated.
        model_inputs = {}
        model_inputs["x"] = inputs["pixel_values"]

        return model_inputs

    @property
    def torch_to_onnx_input_map(self) -> Dict[str, str]:
        return {"x": "pixel_values"}


@register_tasks_manager_onnx("mgp-str", *["feature-extraction", "image-to-text"])
class MgpstrOnnxConfig(ViTOnnxConfig):
    _MODEL_PATCHER = MgpstrModelPatcher

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "char_logits": {0: "batch_size"},
            "bpe_logits": {0: "batch_size"},
            "wp_logits": {0: "batch_size"},
        }


@register_tasks_manager_onnx("efficientnet", *["feature-extraction", "image-classification"])
class EfficientNetOnnxConfig(ViTOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        common_outputs = super().outputs
        if self.task == "image-classification":
            common_outputs["logits"] = {0: "batch_size", 1: "num_classes"}
        return common_outputs


@register_tasks_manager_onnx(
    "transformer", *["feature-extraction", "sentence-similarity"], library_name="sentence_transformers"
)
class SentenceTransformersTransformerOnnxConfig(TextEncoderOnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    DEFAULT_ONNX_OPSET = 14  # Some bottleneck transformers models require a specific ONNX opset to be successfully exported. We put a rather high opset here for the export to work for all architectures.
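    # Export sketch (illustrative command, not validated here): sentence-transformers checkpoints
    # registered under library_name="sentence_transformers" can be exported with the CLI, e.g.
    #   optimum-cli export onnx --model sentence-transformers/all-MiniLM-L6-v2 onnx_output/
    # and the resulting graph exposes the `token_embeddings` / `sentence_embedding` outputs below.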
# we need to set output_attentions=True in the model input to avoid calling # torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export # due to the op torch.nn.functional.multi_head_attention_forward used for WavLM _MODEL_PATCHER = SentenceTransformersTransformerPatcher @property def inputs(self) -> Dict[str, Dict[int, str]]: return { "input_ids": {0: "batch_size", 1: "sequence_length"}, "attention_mask": {0: "batch_size", 1: "sequence_length"}, } @property def outputs(self) -> Dict[str, Dict[int, str]]: return { "token_embeddings": {0: "batch_size", 1: "sequence_length"}, "sentence_embedding": {0: "batch_size"}, } class CLIPNormalizedConfig(NormalizedTextAndVisionConfig): TEXT_CONFIG = "text_config" VISION_CONFIG = "vision_config" @register_tasks_manager_onnx("clip_vision_model", *["feature-extraction"]) class CLIPVisionModelOnnxConfig(VisionOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig _MODEL_PATCHER = CLIPModelPatcher DEFAULT_ONNX_OPSET = 14 # scaled_dot_product_attention support was added in opset 14 @property def inputs(self) -> Dict[str, Dict[int, str]]: return {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}} @property def outputs(self) -> Dict[str, Dict[int, str]]: common_outputs = super().outputs common_outputs["last_hidden_state"] = {0: "batch_size"} common_outputs["pooler_output"] = {0: "batch_size"} return common_outputs @register_tasks_manager_onnx("clip", *["feature-extraction", "zero-shot-image-classification"]) class CLIPOnnxConfig(TextAndVisionOnnxConfig): NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig _MODEL_PATCHER = CLIPModelPatcher DEFAULT_ONNX_OPSET = 14 # scaled_dot_product_attention support was added in opset 14 @property def inputs(self) -> Dict[str, Dict[int, str]]: return { "input_ids": {0: "text_batch_size", 1: "sequence_length"}, "pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"}, "attention_mask": {0: "text_batch_size", 1: "sequence_length"}, } @property def outputs(self) -> Dict[str, Dict[int, str]]: return { "logits_per_image": {0: "image_batch_size", 1: "text_batch_size"}, "logits_per_text": {0: "text_batch_size", 1: "image_batch_size"}, "text_embeds": {0: "text_batch_size"}, "image_embeds": {0: "image_batch_size"}, } @register_tasks_manager_onnx( "clip", *["feature-extraction", "sentence-similarity"], library_name="sentence_transformers" ) class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig): _MODEL_PATCHER = SentenceTransformersCLIPPatcher @property def outputs(self) -> Dict[str, Dict[int, str]]: return { "text_embeds": {0: "text_batch_size"}, "image_embeds": {0: "image_batch_size"}, } @register_tasks_manager_onnx("clip-text-with-projection", *["feature-extraction"], library_name="diffusers") class CLIPTextWithProjectionOnnxConfig(TextEncoderOnnxConfig): ATOL_FOR_VALIDATION = 1e-3 # The ONNX export of this architecture needs the Trilu operator support, available since opset 14 DEFAULT_ONNX_OPSET = 14 NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( vocab_size="vocab_size", sequence_length="max_position_embeddings", num_layers="num_hidden_layers", allow_new=True, ) _MODEL_PATCHER = CLIPModelPatcher @property def inputs(self) -> Dict[str, Dict[int, str]]: return { "input_ids": {0: "batch_size", 1: "sequence_length"}, } @property def outputs(self) -> Dict[str, Dict[int, str]]: common_outputs = { "text_embeds": {0: "batch_size", 1: "sequence_length"}, "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, } if 
self._normalized_config.output_hidden_states: for i in range(self._normalized_config.num_layers + 1): common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"} return common_outputs @register_tasks_manager_onnx("clip-text", *["feature-extraction"], library_name="diffusers") class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig): _MODEL_PATCHER = CLIPModelPatcher @property def outputs(self) -> Dict[str, Dict[int, str]]: common_outputs = { "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, "pooler_output": {0: "batch_size"}, } if self._normalized_config.output_hidden_states: for i in range(self._normalized_config.num_layers + 1): common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"} return common_outputs class SiglipNormalizedConfig(CLIPNormalizedConfig): pass @register_tasks_manager_onnx("chinese_clip", *["feature-extraction", "zero-shot-image-classification"]) class ChineseCLIPOnnxConfig(CLIPOnnxConfig): pass @register_tasks_manager_onnx("siglip", *["feature-extraction", "zero-shot-image-classification"]) class SiglipOnnxConfig(CLIPOnnxConfig): NORMALIZED_CONFIG_CLASS = SiglipNormalizedConfig # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 13 is not supported. # Support for this operator was added in version 14, try exporting with this version. DEFAULT_ONNX_OPSET = 14 @property def inputs(self) -> Dict[str, Dict[int, str]]: return { "input_ids": {0: "text_batch_size", 1: "sequence_length"}, "pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"}, # NOTE: No attention_mask } @register_tasks_manager_onnx("siglip-text-with-projection", *["feature-extraction"]) class SiglipTextWithProjectionOnnxConfig(CLIPTextWithProjectionOnnxConfig): pass @register_tasks_manager_onnx("siglip-text", *["feature-extraction"]) class SiglipTextOnnxConfig(CLIPTextOnnxConfig): pass @register_tasks_manager_onnx("siglip_vision_model", *["feature-extraction"]) class SiglipVisionModelOnnxConfig(CLIPVisionModelOnnxConfig): # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 11 is not supported. # Support for this operator was added in version 14, try exporting with this version. 
    DEFAULT_ONNX_OPSET = 14


@register_tasks_manager_onnx("unet-2d-condition", *["semantic-segmentation"], library_name="diffusers")
class UNetOnnxConfig(VisionOnnxConfig):
    ATOL_FOR_VALIDATION = 1e-4
    # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
    # operator support, available since opset 14
    DEFAULT_ONNX_OPSET = 14

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        image_size="sample_size",
        num_channels="in_channels",
        hidden_size="cross_attention_dim",
        vocab_size="norm_num_groups",
        allow_new=True,
    )

    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyVisionInputGenerator,
        DummyTimestepInputGenerator,
        DummySeq2SeqDecoderTextInputGenerator,
    )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        common_inputs = {
            "sample": {0: "batch_size", 2: "height", 3: "width"},
            "timestep": {},  # a scalar with no dimension
            "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
        }

        # TODO : add addition_embed_type == text_image, image and image_embeds
        # https://github.com/huggingface/diffusers/blob/9366c8f84bfe47099ff047272661786ebb54721d/src/diffusers/models/unets/unet_2d_condition.py#L671
        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            common_inputs["text_embeds"] = {0: "batch_size"}
            common_inputs["time_ids"] = {0: "batch_size"}

        if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None:
            common_inputs["timestep_cond"] = {0: "batch_size"}

        return common_inputs

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "out_sample": {0: "batch_size", 2: "height", 3: "width"},
        }

    @property
    def torch_to_onnx_output_map(self) -> Dict[str, str]:
        return {
            "sample": "out_sample",
        }

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
        dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0]

        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            dummy_inputs["added_cond_kwargs"] = {
                "text_embeds": dummy_inputs.pop("text_embeds"),
                "time_ids": dummy_inputs.pop("time_ids"),
            }

        return dummy_inputs

    def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]:
        inputs = super().ordered_inputs(model=model)
        # to fix mismatch between model forward signature and expected inputs
        # a dictionary of additional embeddings `added_cond_kwargs` is expected depending on config.addition_embed_type
        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            inputs["text_embeds"] = self.inputs["text_embeds"]
            inputs["time_ids"] = self.inputs["time_ids"]

        return inputs


@register_tasks_manager_onnx("vae-encoder", *["semantic-segmentation"], library_name="diffusers")
class VaeEncoderOnnxConfig(VisionOnnxConfig):
    ATOL_FOR_VALIDATION = 3e-4
    # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
    # operator support, available since opset 14
    DEFAULT_ONNX_OPSET = 14

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        num_channels="in_channels", image_size="sample_size", allow_new=True
    )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "sample": {0: "batch_size", 2: "sample_height", 3: "sample_width"},
        }

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        down_sampling_factor = 2 ** (len(self._normalized_config.down_block_types) - 1)
        return {
            "latent_parameters": {
                0: "batch_size",
                2: f"sample_height / {down_sampling_factor}",
                3: f"sample_width / {down_sampling_factor}",
            },
        }


@register_tasks_manager_onnx("vae-decoder", *["semantic-segmentation"], library_name="diffusers")
class VaeDecoderOnnxConfig(VisionOnnxConfig):
    ATOL_FOR_VALIDATION = 3e-4
    # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
    # operator support, available since opset 14
    DEFAULT_ONNX_OPSET = 14

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(num_channels="latent_channels", allow_new=True)

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "latent_sample": {0: "batch_size", 2: "latent_height", 3: "latent_width"},
        }

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        upsampling_factor = 2 ** (len(self._normalized_config.up_block_types) - 1)
        return {
            "sample": {
                0: "batch_size",
                2: f"latent_height * {upsampling_factor}",
                3: f"latent_width * {upsampling_factor}",
            },
        }


@register_tasks_manager_onnx("t5-encoder", *["feature-extraction"], library_name="diffusers")
class T5EncoderOnnxConfig(TextEncoderOnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    ATOL_FOR_VALIDATION = 1e-4
    DEFAULT_ONNX_OPSET = 12  # int64 was supported since opset 12

    @property
    def inputs(self):
        return {
            "input_ids": {0: "batch_size", 1: "sequence_length"},
        }

    @property
    def outputs(self):
        return {
            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
        }


@register_tasks_manager_onnx("sd3-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
class SD3TransformerOnnxConfig(VisionOnnxConfig):
    ATOL_FOR_VALIDATION = 1e-4
    # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
    # operator support, available since opset 14
    DEFAULT_ONNX_OPSET = 14

    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyTransformerTimestepInputGenerator,
        DummyTransformerVisionInputGenerator,
        DummyTransformerTextInputGenerator,
    )

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        image_size="sample_size",
        num_channels="in_channels",
        vocab_size="attention_head_dim",
        hidden_size="joint_attention_dim",
        projection_size="pooled_projection_dim",
        allow_new=True,
    )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        common_inputs = {
            "hidden_states": {0: "batch_size", 2: "height", 3: "width"},
            "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
            "pooled_projections": {0: "batch_size"},
            "timestep": {0: "step"},
        }

        return common_inputs

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "out_hidden_states": {0: "batch_size", 2: "height", 3: "width"},
        }

    @property
    def torch_to_onnx_output_map(self) -> Dict[str, str]:
        return {
            "sample": "out_hidden_states",
        }


@register_tasks_manager_onnx("flux-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
class FluxTransformerOnnxConfig(SD3TransformerOnnxConfig):
    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyTransformerTimestepInputGenerator,
        DummyFluxTransformerVisionInputGenerator,
        DummyFluxTransformerTextInputGenerator,
    )

    @property
    def inputs(self):
        common_inputs = super().inputs
        common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"}
        common_inputs["txt_ids"] = (
            {0: "sequence_length"} if is_diffusers_version(">=", "0.31.0") else {0: "batch_size", 1: "sequence_length"}
        )
        common_inputs["img_ids"] = (
            {0: "packed_height_width"}
            if is_diffusers_version(">=", "0.31.0")
            else {0: "batch_size", 1: "packed_height_width"}
        )

        if getattr(self._normalized_config, "guidance_embeds", False):
            common_inputs["guidance"] = {0: "batch_size"}

        return common_inputs

    @property
    def outputs(self):
        return {
            "out_hidden_states": {0: "batch_size", 1: "packed_height_width"},
        }
@register_tasks_manager_onnx("groupvit", *["feature-extraction"]) class GroupViTOnnxConfig(CLIPOnnxConfig): pass @register_tasks_manager_onnx("owlvit", *["feature-extraction", "zero-shot-object-detection"]) class OwlViTOnnxConfig(CLIPOnnxConfig): # Sets the absolute tolerance to when validating the exported ONNX model against the # reference model. ATOL_FOR_VALIDATION = 1e-4 MIN_TORCH_VERSION = version.parse("2.1") # needs einsum operator support, available since opset 12 DEFAULT_ONNX_OPSET = 12 def __init__( self, config: "PretrainedConfig", task: str = "feature-extraction", int_dtype: str = "int64", float_dtype: str = "fp32", preprocessors: Optional[List[Any]] = None, legacy: bool = False, ): super().__init__( config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors, legacy=legacy, ) if task == "zero-shot-object-detection": logger.warning( "The batch size of this model will not be dynamic because non-maximum suppression is performed. " "Make sure to export the model with the same batch size as the one you will use at inference " "with `--batch_size N`." ) @property def outputs(self) -> Dict[str, Dict[int, str]]: outputs = {} if self.task == "feature-extraction": outputs["logits_per_image"] = {0: "image_batch_size", 1: "text_batch_size"} outputs["logits_per_text"] = {0: "text_batch_size", 1: "image_batch_size"} elif self.task == "zero-shot-object-detection": outputs["logits"] = {0: "image_batch_size", 2: "num_queries"} outputs["pred_boxes"] = {0: "image_batch_size", 1: "num_boxes"} outputs["text_embeds"] = {0: "text_batch_size", 1: "max_text_queries"} outputs["image_embeds"] = {0: "image_batch_size"} return outputs @register_tasks_manager_onnx("owlv2", *["feature-extraction", "zero-shot-object-detection"]) class OwlV2OnnxConfig(OwlViTOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.35.0") @register_tasks_manager_onnx( "layoutlm", *["feature-extraction", "fill-mask", "text-classification", "token-classification"] ) class LayoutLMOnnxConfig(TextAndVisionOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( allow_new=True, MAX_2D_POSITION_EMBEDDINGS="max_2d_position_embeddings", ) @property def inputs(self) -> Dict[str, Dict[int, str]]: return { "input_ids": {0: "batch_size", 1: "sequence_length"}, "bbox": {0: "batch_size", 1: "sequence_length"}, "attention_mask": {0: "batch_size", 1: "sequence_length"}, "token_type_ids": {0: "batch_size", 1: "sequence_length"}, } @register_tasks_manager_onnx( "layoutlmv3", *["feature-extraction", "question-answering", "text-classification", "token-classification"] ) class LayoutLMv3OnnxConfig(TextAndVisionOnnxConfig): MIN_TORCH_VERSION = version.parse("1.12") NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( allow_new=True, MAX_2D_POSITION_EMBEDDINGS="max_2d_position_embeddings", image_size="input_size", ) DEFAULT_ONNX_OPSET = 12 @property def inputs(self) -> Dict[str, Dict[int, str]]: if self.task in ["text-classification", "question-answering"]: pixel_values_dynamic_axes = {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"} else: pixel_values_dynamic_axes = {0: "batch_size", 1: "num_channels"} return { "input_ids": {0: "batch_size", 1: "sequence_length"}, "attention_mask": {0: "batch_size", 1: "sequence_length"}, "bbox": {0: "batch_size", 1: "sequence_length"}, "pixel_values": pixel_values_dynamic_axes, } @register_tasks_manager_onnx( "lilt", *["feature-extraction", "question-answering", "text-classification", "token-classification"] ) class 
LiltOnnxConfig(TextAndVisionOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( allow_new=True, MAX_2D_POSITION_EMBEDDINGS="max_2d_position_embeddings", ) @property def inputs(self) -> Dict[str, Dict[int, str]]: return { "input_ids": {0: "batch_size", 1: "sequence_length"}, "bbox": {0: "batch_size", 1: "sequence_length"}, "attention_mask": {0: "batch_size", 1: "sequence_length"}, } @register_tasks_manager_onnx("data2vec-text", *COMMON_TEXT_TASKS) class Data2VecTextOnnxConfig(DistilBertOnnxConfig): pass @register_tasks_manager_onnx("data2vec-vision", *["feature-extraction", "image-classification"]) class Data2VecVisionOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. @register_tasks_manager_onnx( "data2vec-audio", *[ "feature-extraction", "automatic-speech-recognition", "audio-classification", "audio-frame-classification", "audio-xvector", ], ) class Data2VecAudioOnnxConfig(AudioOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. NORMALIZED_CONFIG_CLASS = NormalizedConfig @register_tasks_manager_onnx("perceiver", *["fill-mask", "text-classification", "image-classification"]) class PerceiverOnnxConfig(TextAndVisionOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig DUMMY_INPUT_GENERATOR_CLASSES = ( PerceiverDummyInputGenerator, ) + TextAndVisionOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES def __init__( self, config: "PretrainedConfig", task: str = "feature-extraction", int_dtype: str = "int64", float_dtype: str = "fp32", preprocessors: Optional[List[Any]] = None, legacy: bool = False, ): super().__init__( config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors, legacy=legacy, ) self.is_generating_dummy_inputs = False @property def inputs_name(self): if self.is_generating_dummy_inputs: if self.task in ["fill-mask", "text-classification"]: return "input_ids" else: return "pixel_values" else: return "inputs" @property def inputs(self) -> Dict[str, Dict[int, str]]: if self.inputs_name in ["input_ids", "inputs"]: dynamic_axis = {0: "batch_size", 1: "sequence_length"} return { "input_ids": dynamic_axis, "attention_mask": dynamic_axis, } else: dynamic_axis = {0: "batch_size", 1: "sequence_length", 2: "width", 3: "height"} return { "pixel_values": dynamic_axis, } @property def outputs(self) -> Dict[str, Dict[int, str]]: outputs = super().outputs if "logits" in outputs: # default is {0: "batch_size", 1: "sequence_length"} where sequence_length is dynamic axis # but perceiver always return the same max sequence length in the second dimension outputs["logits"] = {0: "batch_size"} return outputs def generate_dummy_inputs(self, framework: str = "pt", **kwargs): self.is_generating_dummy_inputs = True dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs) dummy_inputs[self.inputs_name] = dummy_inputs.pop(self.inputs_name) return dummy_inputs @register_tasks_manager_onnx("hubert", *["feature-extraction", "automatic-speech-recognition", "audio-classification"]) class HubertOnnxConfig(AudioOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedConfig DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. 
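# Waveform-input sketch (illustrative): audio encoder configs in this family are traced with a raw
# `input_values` tensor of shape (batch_size, sequence_length). The model id below is only an example.
# >>> from transformers import AutoConfig
# >>> onnx_config = HubertOnnxConfig(AutoConfig.from_pretrained("facebook/hubert-base-ls960"))
# >>> onnx_config.inputs
# {'input_values': {0: 'batch_size', 1: 'sequence_length'}}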
@register_tasks_manager_onnx( "wav2vec2", *[ "feature-extraction", "automatic-speech-recognition", "audio-classification", "audio-frame-classification", "audio-xvector", ], ) class Wav2Vec2OnnxConfig(HubertOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. @register_tasks_manager_onnx( "wav2vec2-conformer", *[ "feature-extraction", "automatic-speech-recognition", "audio-classification", "audio-frame-classification", "audio-xvector", ], ) class Wav2Vec2ConformerOnnxConfig(HubertOnnxConfig): DEFAULT_ONNX_OPSET = 11 @register_tasks_manager_onnx("sew", *["feature-extraction", "automatic-speech-recognition", "audio-classification"]) class SEWOnnxConfig(HubertOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. @register_tasks_manager_onnx("sew-d", *["feature-extraction", "automatic-speech-recognition", "audio-classification"]) class SEWDOnnxConfig(HubertOnnxConfig): DEFAULT_ONNX_OPSET = 12 @register_tasks_manager_onnx( "unispeech", *["feature-extraction", "automatic-speech-recognition", "audio-classification"] ) class UniSpeechOnnxConfig(HubertOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. @register_tasks_manager_onnx( "unispeech-sat", *[ "feature-extraction", "automatic-speech-recognition", "audio-classification", "audio-frame-classification", "audio-xvector", ], ) class UniSpeechSATOnnxConfig(HubertOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. @register_tasks_manager_onnx( "wavlm", *[ "feature-extraction", "automatic-speech-recognition", "audio-classification", "audio-frame-classification", "audio-xvector", ], ) class WavLMOnnxConfig(HubertOnnxConfig): DEFAULT_ONNX_OPSET = 12 # we need to set output_attentions=True in the model input to avoid calling # torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export # due to the op torch.nn.functional.multi_head_attention_forward used for WavLM _MODEL_PATCHER = WavLMModelPatcher @register_tasks_manager_onnx("audio-spectrogram-transformer", *["feature-extraction", "audio-classification"]) class ASTOnnxConfig(OnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( num_mel_bins="num_mel_bins", max_length="max_length", allow_new=True ) DUMMY_INPUT_GENERATOR_CLASSES = (ASTDummyAudioInputGenerator,) ATOL_FOR_VALIDATION = 1e-4 DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. @property def inputs(self) -> Dict[str, Dict[int, str]]: return {"input_values": {0: "batch_size"}} @register_tasks_manager_onnx("mctct", *["feature-extraction", "automatic-speech-recognition"]) class MCTCTOnnxConfig(OnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( input_features_per_channel="input_feat_per_channel", allow_new=True ) DUMMY_INPUT_GENERATOR_CLASSES = (MCTCTDummyAudioInputGenerator,) DEFAULT_ONNX_OPSET = 13 @property def inputs(self) -> Dict[str, Dict[int, str]]: return {"input_features": {0: "batch_size", 1: "sequence_classification"}} @register_tasks_manager_onnx( "moonshine", *[ "feature-extraction", "feature-extraction-with-past", "automatic-speech-recognition", "automatic-speech-recognition-with-past", ], ) class MoonshineOnnxConfig(AudioToTextOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::triu' to ONNX opset version 11 is not supported. 
# Support for this operator was added in version 14, try exporting with this version. DEFAULT_ONNX_OPSET = 14 @property def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs = {} if self._behavior is not ConfigBehavior.DECODER: common_inputs["input_values"] = {0: "batch_size", 1: "num_samples"} if self._behavior is not ConfigBehavior.ENCODER: if self.use_past_in_inputs: common_inputs["decoder_input_ids"] = {0: "batch_size"} self.add_past_key_values(common_inputs, direction="inputs") else: common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} if self._behavior is ConfigBehavior.DECODER: common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"} return common_inputs @register_tasks_manager_onnx( "whisper", *[ "feature-extraction", "feature-extraction-with-past", "audio-classification", "automatic-speech-recognition", "automatic-speech-recognition-with-past", ], ) class WhisperOnnxConfig(AudioToTextOnnxConfig): DEFAULT_ONNX_OPSET = 14 # Whisper now uses F.scaled_dot_product_attention by default for torch>=2.1.1. NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( encoder_num_layers="encoder_layers", decoder_num_layers="decoder_layers", feature_size="num_mel_bins", allow_new=True, ) ATOL_FOR_VALIDATION = 1e-3 @property def inputs(self) -> Dict[str, Dict[int, str]]: if self.task == "audio-classification": common_inputs = {"input_features": {0: "batch_size"}} else: common_inputs = super().inputs if self._behavior is not ConfigBehavior.DECODER: common_inputs["input_features"] = {0: "batch_size"} # Remove unnecessary dynamic axis. if is_transformers_version(">=", "4.43.0") and is_transformers_version("<", "4.46.0"): # since https://github.com/huggingface/transformers/pull/31166 if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs: common_inputs["cache_position"] = {0: "decoder_sequence_length"} if self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs: common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2" return common_inputs @property def outputs(self) -> Dict[str, Dict[int, str]]: common_outputs = super().outputs if self._behavior is ConfigBehavior.ENCODER: # For Whisper, we need to name the second axis as encoder_sequence_length / 2 as the axis name is used for # dummy input generation common_outputs["last_hidden_state"][1] = f"{common_outputs['last_hidden_state'][1]} / 2" return common_outputs @register_tasks_manager_onnx("musicgen", *["text-to-audio"]) class MusicgenOnnxConfig(OnnxSeq2SeqConfigWithPast): # NOTE: Several warnings during the export are not to worry about: # * for i, indices in enumerate(codes): --> can be unrolled, fixed length (num_quantizers). # * max_pad = max(padding_left, padding_right) --> does not impact later controlflows. # if length <= max_pad: --> appears to be always False for Musicgen. # opset>=13 needed to avoid a bug in T5 encoder SelfAttention. # opset>=14 needed for torch.tril export. DEFAULT_ONNX_OPSET = 14 VARIANTS = { "text-conditional-with-past": """Exports Musicgen to ONNX to generate audio samples conditioned on a text prompt (Reference: https://huggingface.co/docs/transformers/model_doc/musicgen#text-conditional-generation). This uses the decoder KV cache. The following subcomponents are exported: * text_encoder.onnx: corresponds to the text encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L1457. 
* encodec_decode.onnx: corresponds to the Encodec audio encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L2472-L2480. * decoder_model.onnx: The Musicgen decoder, without past key values input, and computing cross attention. Not required at inference (use decoder_model_merged.onnx instead). * decoder_with_past_model.onnx: The Musicgen decoder, with past_key_values input (KV cache filled), not computing cross attention. Not required at inference (use decoder_model_merged.onnx instead). * decoder_model_merged.onnx: The two previous models fused in one, to avoid duplicating weights. A boolean input `use_cache_branch` allows to select the branch to use. In the first forward pass where the KV cache is empty, dummy past key values inputs need to be passed and are ignored with use_cache_branch=False. * build_delay_pattern_mask.onnx: A model taking as input `input_ids`, `pad_token_id`, `max_length`, and building a delayed pattern mask to the input_ids. Implements https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/musicgen/modeling_musicgen.py#L1054.""", } # TODO: support audio-prompted generation (audio_encoder_encode.onnx: corresponds to the audio encoder part # in https://github.com/huggingface/transformers/blob/f01e1609bf4dba146d1347c1368c8c49df8636f6/src/transformers/models/musicgen/modeling_musicgen.py#L2087.) # With that, we have full Encodec support. DEFAULT_VARIANT = "text-conditional-with-past" NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig DUMMY_INPUT_GENERATOR_CLASSES = ( DummyTextInputGenerator, DummyCodegenDecoderTextInputGenerator, DummySeq2SeqPastKeyValuesGenerator, DummyEncodecInputGenerator, DummyIntGenerator, ) DUMMY_PKV_GENERATOR_CLASS = DummySeq2SeqPastKeyValuesGenerator _MODEL_PATCHER = MusicgenModelPatcher def __init__( self, config: "PretrainedConfig", task: str = "feature-extraction", int_dtype: str = "int64", float_dtype: str = "fp32", use_past: bool = False, use_past_in_inputs: bool = False, behavior: ConfigBehavior = ConfigBehavior.ENCODER, preprocessors: Optional[List[Any]] = None, model_part: Optional[Literal["text_encoder", "encodec_decode", "decoder", "build_delay_pattern_mask"]] = None, legacy: bool = False, variant: str = "text-conditional-with-past", ): super().__init__( config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, use_past=use_past, use_past_in_inputs=use_past_in_inputs, behavior=behavior, preprocessors=preprocessors, legacy=legacy, ) if legacy: raise ValueError("Musicgen does not support legacy=True.") if ( model_part in ["text_encoder", "encodec_decode", "build_delay_pattern_mask"] and behavior != ConfigBehavior.ENCODER ): raise ValueError( f"model_part is {model_part} and behavior is {behavior}. This is not supported, please open an issue at https://github.com/huggingface/optimum/issues." ) if model_part == "decoder" and behavior != ConfigBehavior.DECODER: raise ValueError( f"model_part is {model_part} and behavior is {behavior}. This is not supported, please open an issue at https://github.com/huggingface/optimum/issues." ) if behavior == ConfigBehavior.MONOLITH: raise ValueError( "Musicgen does not support behavior=ConfigBehavior.MONOLITH. Please open an issue at https://github.com/huggingface/optimum/issues." ) if config.audio_encoder.model_type != "encodec": raise ValueError( f"Optimum ONNX export for Musicgen supports only Encodec as the audio encoder, got: {config.audio_encoder.model_type}. 
Please open an issue at https://github.com/huggingface/optimum/issues." ) # Handling it would require to trace the audio_encoder.decode with torch.jit.script as we than have an unrollable loop. if config.audio_encoder.chunk_length_s is not None: raise ValueError( f"Musicgen ONNX export currently does not support audio_encoder.chunk_length_s not None (got {config.audio_encoder.chunk_length_s}). Please open an issue at https://github.com/huggingface/optimum/issues." ) self.model_part = model_part if self.model_part == "decoder": self.use_past = True # without past is not supported, hard-code it here. self._normalized_config.ENCODER_NORMALIZED_CONFIG_CLASS = NormalizedTextConfig(self._config.text_encoder) self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS = NormalizedConfig(self._config.decoder) self._normalized_config.decoder_num_layers = self._config.decoder.num_hidden_layers self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.num_layers = self._config.decoder.num_hidden_layers self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.encoder_num_attention_heads = ( self._config.decoder.num_attention_heads ) self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.decoder_num_attention_heads = ( self._config.decoder.num_attention_heads ) @property def inputs(self) -> Dict[str, Dict[int, str]]: # Batched inference is not supported in Transformers. if self.model_part == "text_encoder": common_inputs = { "input_ids": {0: "batch_size", 1: "encoder_sequence_length"}, "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"}, } elif self.model_part == "encodec_decode": # 0: always 1 for chunk_length_s=None, 2: num_quantizers fixed. common_inputs = {"audio_codes": {1: "batch_size", 3: "chunk_length"}} elif self.model_part == "build_delay_pattern_mask": common_inputs = { "input_ids": {0: "batch_size_x_num_codebooks"}, "pad_token_id": {}, "max_length": {}, } elif self._behavior is ConfigBehavior.DECODER: # Naming it total_batch_size as in case we use guidance_scale, the dimension 0 may be larger than simply the batch_size. # Reference: https://github.com/huggingface/transformers/blob/31c575bcf13c2b85b65d652dd1b5b401f99be999/src/transformers/models/musicgen/modeling_musicgen.py#L1932-L1935 common_inputs = { "decoder_input_ids": {0: "total_batch_size_x_num_codebooks"}, "encoder_outputs": {0: "total_batch_size", 1: "encoder_sequence_length"}, # MusicgenForConditionalGeneration maps attention_mask to encoder_attention_mask. "attention_mask": { 0: "batch_size", 1: "encoder_sequence_length", }, } if self.use_past_in_inputs: # TODO: validate the axis name for attention_mask # common_inputs["attention_mask"][1] = "past_encoder_sequence_length + sequence_length" self.add_past_key_values(common_inputs, direction="inputs") else: common_inputs["decoder_input_ids"] = { 0: "total_batch_size_x_num_codebooks", 1: "decoder_sequence_length", } else: raise ValueError( "This should not happen. Please open an issue at https://github.com/huggingface/optimum/issues." 
) return common_inputs @property def outputs(self) -> Dict[str, Dict[int, str]]: common_outputs = {} if self.model_part == "text_encoder": common_outputs = super().outputs elif self.model_part == "encodec_decode": common_outputs["audio_values"] = {0: "batch_size", 2: "audio_length"} elif self.model_part == "build_delay_pattern_mask": common_outputs["input_ids_edited"] = {0: "total_batch_size_x_num_codebooks"} common_outputs["delay_pattern_mask"] = {0: "total_batch_size_x_num_codebooks", 1: "max_length"} elif self._behavior is ConfigBehavior.DECODER: common_outputs = super().outputs # MusicgenForConditionalGeneration output is named logits, not last_hidden_state. # Rename last_hidden_state -> logits while keeping the order. common_outputs = { "logits" if name == "last_hidden_state" else name: value for name, value in common_outputs.items() } else: raise ValueError( "This should not happen. Please open an issue at https://github.com/huggingface/optimum/issues." ) return common_outputs def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): if direction not in ["inputs", "outputs"]: raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') if direction == "inputs": decoder_sequence_name = "past_decoder_sequence_length" name = "past_key_values" else: decoder_sequence_name = "past_decoder_sequence_length + 1" name = "present" for i in range(self._normalized_config.decoder_num_layers): inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "total_batch_size", 2: decoder_sequence_name} inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "total_batch_size", 2: decoder_sequence_name} if ( self.is_merged is True or (self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs) or direction == "inputs" ): # TODO: we only need to call it encoder_sequence_length_out in the merge case - but at torch.onnx.export() # time we have currently no case to check whether we will merge at a later step or not (self.is_merged is # not yet set at this time) inputs_or_outputs[f"{name}.{i}.encoder.key"] = { 0: "total_batch_size", 2: "encoder_sequence_length_out", } inputs_or_outputs[f"{name}.{i}.encoder.value"] = { 0: "total_batch_size", 2: "encoder_sequence_length_out", } @property def torch_to_onnx_input_map(self) -> Dict[str, str]: if self._behavior is ConfigBehavior.DECODER: return { "decoder_input_ids": "input_ids", "encoder_outputs": "encoder_hidden_states", "attention_mask": "encoder_attention_mask", } return {} def post_process_exported_models( self, path: Path, models_and_onnx_configs: Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"] ], onnx_files_subpaths: List[str], ): # Attempt to merge only if the decoder was exported without/with past, and ignore seq2seq models exported with text-generation task if "with-past" in self.variant: decoder_path = Path(path, onnx_files_subpaths[2]) decoder_with_past_path = Path(path, onnx_files_subpaths[3]) decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") try: from ...onnx import merge_decoders # The decoder with past does not output the cross attention past key values as they are constant, # hence the need for strict=False merge_decoders( decoder=decoder_path, decoder_with_past=decoder_with_past_path, save_path=decoder_merged_path, strict=False, ) except Exception as e: raise Exception(f"Unable to merge decoders. 
Detailed error: {e}") # In order to do the validation of the two branches on the same file text_encoder_path = onnx_files_subpaths[0] encodec_decode_path = onnx_files_subpaths[1] build_delay_pattern_mask_path = onnx_files_subpaths[4] onnx_files_subpaths_new = [ text_encoder_path, encodec_decode_path, decoder_merged_path.name, decoder_merged_path.name, build_delay_pattern_mask_path, ] # We validate the two branches of the decoder model then models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False # Past key values won't be generated by default, but added in the input models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True else: onnx_files_subpaths_new = onnx_files_subpaths return models_and_onnx_configs, onnx_files_subpaths_new def overwrite_shape_and_generate_input( self, dummy_input_gen: "DummyInputGenerator", input_name: str, framework: str, input_shapes: Dict ): if self.model_part == "build_delay_pattern_mask" and input_name == "input_ids": original_batch_size = dummy_input_gen.batch_size dummy_input_gen.batch_size = ( original_batch_size * dummy_input_gen.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.num_codebooks ) dummy_input = dummy_input_gen.generate( input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype ) dummy_input_gen.batch_size = original_batch_size else: dummy_input = super().overwrite_shape_and_generate_input( dummy_input_gen, input_name, framework, input_shapes ) return dummy_input @register_tasks_manager_onnx("speecht5", *["text-to-audio"]) class SpeechT5OnnxConfig(OnnxSeq2SeqConfigWithPast): # TODO: Transformers batched generation for Speecht5 is BROKEN (https://github.com/huggingface/transformers/pull/25943), # so we won't support for now. NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( hidden_size="hidden_size", num_attention_heads="encoder_attention_heads", # TODO: bugged in case encoder and decoder have different number of heads encoder_num_layers="encoder_layers", decoder_num_layers="decoder_layers", allow_new=True, ) DUMMY_INPUT_GENERATOR_CLASSES = ( DummyTextInputGenerator, DummySeq2SeqDecoderTextInputGenerator, DummySeq2SeqPastKeyValuesGenerator, DummySpeechT5InputGenerator, ) DUMMY_PKV_GENERATOR_CLASS = DummySeq2SeqPastKeyValuesGenerator VARIANTS = { "with-past": "The export follows the Transformers implementation using the KV cache, with the following components exported:\n\t - encoder_model.onnx: corresponds to the encoding part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2544-L2556.\n\t - decoder_model.onnx: corresponds to the decoder part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2572-L2602.\n\t - decoder_with_past_model.onnx: same as the above, with past_key_values input (KV cache filled).\n\t - decoder_postnet_and_vocoder.onnx: Decoder speech postnet and vocoder (e.g. a SpeechT5HifiGan) to generate speech from the spectrogram, as in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2605-L2614.", "without-past": "The same as `with-past`, just without KV cache support. 
This is not a recommended export as slower than `with-past`.", } DEFAULT_VARIANT = "with-past" _MODEL_PATCHER = SpeechT5ModelPatcher def __init__( self, config: "PretrainedConfig", task: str = "feature-extraction", int_dtype: str = "int64", float_dtype: str = "fp32", use_past: bool = False, use_past_in_inputs: bool = False, behavior: ConfigBehavior = ConfigBehavior.MONOLITH, preprocessors: Optional[List[Any]] = None, is_postnet_and_vocoder: bool = False, legacy: bool = False, ): super().__init__( config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, use_past=use_past, use_past_in_inputs=use_past_in_inputs, behavior=behavior, preprocessors=preprocessors, legacy=legacy, ) if float_dtype == "fp16": raise ValueError( "The ONNX export of SpeechT5 in float16 is currently not supported due to a bug in PyTorch: https://github.com/pytorch/pytorch/pull/110078. Please open an issue in Optimum if you would like to export SpeechT5 in float16." ) self.is_postnet_and_vocoder = is_postnet_and_vocoder @property def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs = {} # Batched inference is not supported in Transformers. if self._behavior is ConfigBehavior.ENCODER: common_inputs["input_ids"] = {1: "encoder_sequence_length"} elif self._behavior is ConfigBehavior.DECODER: # NOTE: even when past is used, the decoder takes the full sequence as input as the prenet seem to require it: # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2573 common_inputs["output_sequence"] = {1: "decoder_sequence_length"} common_inputs["speaker_embeddings"] = {} # No dynamic shape here. common_inputs["encoder_outputs"] = {1: "encoder_sequence_length"} common_inputs["encoder_attention_mask"] = {1: "encoder_sequence_length"} if self.variant == "with-past" and self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") elif self.is_postnet_and_vocoder: common_inputs["spectrogram"] = {0: "n_spectrums x reduction_factor"} else: raise ValueError( "self._behavior is neither encoder or decoder, and is_postnet_and_vocoder=False. This should not happen." ) return common_inputs @property def outputs(self) -> Dict[str, Dict[int, str]]: common_outputs = {} if self._behavior is ConfigBehavior.ENCODER: common_outputs["encoder_outputs"] = {1: "encoder_sequence_length"} common_outputs["encoder_attention_mask"] = {1: "encoder_sequence_length"} elif self._behavior is ConfigBehavior.DECODER: common_outputs["output_sequence_out"] = {1: "decoder_sequence_length + 1"} common_outputs["spectrum"] = {} # No dynamic shape here. common_outputs["prob"] = {} # No dynamic shape here. if self.variant == "with-past" and self.use_past: # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output. self.add_past_key_values(common_outputs, direction="outputs") elif self.is_postnet_and_vocoder: common_outputs["waveform"] = {0: "n_samples"} else: raise ValueError( "self._behavior is neither encoder or decoder, and is_postnet_and_vocoder=False. This should not happen." 
) return common_outputs @property def torch_to_onnx_input_map(self) -> Dict[str, str]: return {"encoder_outputs": "encoder_hidden_states"} def overwrite_shape_and_generate_input( self, dummy_input_gen: "DummyInputGenerator", input_name: str, framework: str, input_shapes: Dict ): dummy_input_gen.batch_size = 1 dummy_input = dummy_input_gen.generate( input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype ) return dummy_input def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): if direction not in ["inputs", "outputs"]: raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') if direction == "inputs": decoder_sequence_name = "past_decoder_sequence_length" name = "past_key_values" else: decoder_sequence_name = "past_decoder_sequence_length + 1" name = "present" for i in range(self._normalized_config.decoder_num_layers): inputs_or_outputs[f"{name}.{i}.decoder.key"] = {2: decoder_sequence_name} inputs_or_outputs[f"{name}.{i}.decoder.value"] = {2: decoder_sequence_name} if ( self.is_merged is True or (self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs) or direction == "inputs" ): inputs_or_outputs[f"{name}.{i}.encoder.key"] = {2: "encoder_sequence_length_out"} inputs_or_outputs[f"{name}.{i}.encoder.value"] = {2: "encoder_sequence_length_out"} @register_tasks_manager_onnx("vits", *["text-to-audio"]) class VitsOnnxConfig(TextEncoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig ATOL_FOR_VALIDATION = 1e-4 @property def inputs(self) -> Dict[str, Dict[int, str]]: return { "input_ids": {0: "text_batch_size", 1: "sequence_length"}, "attention_mask": {0: "text_batch_size", 1: "sequence_length"}, } @property def outputs(self) -> Dict[str, Dict[int, str]]: return { "waveform": {0: "text_batch_size", 1: "n_samples"}, "spectrogram": {0: "text_batch_size", 2: "num_bins"}, } @register_tasks_manager_onnx( "speech_to_text", *[ "feature-extraction", "feature-extraction-with-past", "automatic-speech-recognition", "automatic-speech-recognition-with-past", ], ) class Speech2TextOnnxConfig(AudioToTextOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( decoder_num_layers="decoder_layers", num_layers="decoder_layers", input_features_per_channel="input_feat_per_channel", allow_new=True, ) DUMMY_INPUT_GENERATOR_CLASSES = ( (Speech2TextDummyAudioInputGenerator,) + AudioToTextOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:] + (DummyTextInputGenerator,) ) ATOL_FOR_VALIDATION = 1e-4 @property def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs = {} if self._behavior is not ConfigBehavior.DECODER: common_inputs["input_features"] = {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"} common_inputs["attention_mask"] = {0: "batch_size", 1: "encoder_sequence_length"} if self._behavior is not ConfigBehavior.ENCODER: if self.use_past_in_inputs: common_inputs["decoder_input_ids"] = {0: "batch_size"} else: common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") if self._behavior is ConfigBehavior.DECODER: common_inputs["encoder_outputs"] = { 0: "batch_size", 1: f"encoder_sequence_length / {(2 * self._config.num_conv_layers)}", } return common_inputs @property def outputs(self) -> Dict[str, Dict[int, str]]: common_outputs = super().outputs if self._behavior is ConfigBehavior.ENCODER: # for Speech2text, we need to name the second axis as 
            # encoder_sequence_length / 2 * self._config.num_conv_layers as the axis name is
            # used for dummy input generation
            common_outputs["last_hidden_state"][1] = (
                f"{common_outputs['last_hidden_state'][1]} / {(2 * self._config.num_conv_layers)}"
            )
        return common_outputs


# TODO: Replace the TextSeq2SeqOnnxConfig inheritance with VisionToTextOnnxConfig when added.
# The change below, however, does not affect the export of the model.
@register_tasks_manager_onnx(
    "trocr", *["feature-extraction", "feature-extraction-with-past", "image-to-text", "image-to-text-with-past"]
)
class TrOCROnnxConfig(TextSeq2SeqOnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
        decoder_num_layers="decoder_layers",
        num_layers="decoder_layers",
        decoder_num_attention_heads="decoder_attention_heads",
        hidden_size="hidden_size",
    )


@register_tasks_manager_onnx(
    "donut",
    *[
        "image-to-text",
        "image-to-text-with-past",
        "document-question-answering",
        "document-question-answering-with-past",
    ],
)
@register_tasks_manager_onnx(
    "vision-encoder-decoder",
    *[
        "image-to-text",
        "image-to-text-with-past",
        "document-question-answering",
        "document-question-answering-with-past",
    ],
)
class VisionEncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
    ATOL_FOR_VALIDATION = 1e-3
    DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyVisionEncoderDecoderPastKeyValuesGenerator)
    _MODEL_PATCHER = VisionEncoderDecoderPatcher

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        common_inputs = {}
        if self._behavior is not ConfigBehavior.DECODER:
            common_inputs["pixel_values"] = {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}

        if self._behavior is not ConfigBehavior.ENCODER:
            if self.use_past_in_inputs:
                common_inputs["decoder_input_ids"] = {0: "batch_size"}
            else:
                common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}

            if self.use_past_in_inputs:
                self.add_past_key_values(common_inputs, direction="inputs")

        if self._behavior is ConfigBehavior.DECODER:
            common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}

        return common_inputs

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        if self._behavior == ConfigBehavior.ENCODER:
            # Some encoders have a static sequence length, so it is useful to rely on the encoder ONNX config to grab this information.
            return self._encoder_onnx_config.outputs
        else:
            # Ideally, we would want to use self._decoder_onnx_config.outputs here, which is currently not possible
            # as we hard-code the task to feature-extraction, which has the wrong output names (e.g. mbart does not
            # support document-question-answering, so we cannot initialize MBartONNXConfig with document-question-answering).
            return super().outputs


@register_tasks_manager_onnx("sam", *["feature-extraction"])
class SamOnnxConfig(OnnxConfig):
    MIN_TRANSFORMERS_VERSION = version.parse("4.29.0.dev0")
    # Since Transformers 4.32.0, SAM uses a repeat_interleave op that is broken in PyTorch 2.0.1: https://github.com/pytorch/pytorch/issues/100429
    MIN_TORCH_VERSION = version.parse("2.0.99")
    NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyPointsGenerator, DummyVisionEmbeddingsGenerator)
    DEFAULT_ONNX_OPSET = 13  # Opset 12 for repeat_interleave falls back on the opset 9 implementation, which raises: Unsupported: ONNX export of repeat_interleave in opset 9.
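    # Note (illustrative): the MIN_TORCH_VERSION = version.parse("2.0.99") above effectively requires
    # torch >= 2.1, since any 2.1.x release compares greater than 2.0.99 while the broken 2.0.1 does not:
    #
    #     from packaging import version
    #
    #     assert version.parse("2.1.0") > version.parse("2.0.99")
    #     assert not version.parse("2.0.1") > version.parse("2.0.99")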
    VARIANTS = {
        "monolith": "All the SAM model components are exported as a single model.onnx.",
        "split": "The vision encoder is exported as a separate vision_encoder.onnx, and the prompt encoder and mask decoder are exported as a prompt_encoder_mask_decoder.onnx. This allows encoding the image only once for multiple point queries.",
    }
    DEFAULT_VARIANT = "split"
    _MODEL_PATCHER = SAMModelPatcher

    def __init__(
        self,
        config: "PretrainedConfig",
        task: str = "feature-extraction",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
        variant: str = "split",
        vision_encoder: Optional[bool] = None,
        preprocessors: Optional[List[Any]] = None,
        legacy: bool = False,
    ):
        super().__init__(
            config=config,
            task=task,
            int_dtype=int_dtype,
            float_dtype=float_dtype,
            preprocessors=preprocessors,
            legacy=legacy,
        )
        self.variant = variant
        self.vision_encoder = vision_encoder
        self._normalized_config.ENCODER_NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig(self._config.vision_config)

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        if self.variant == "monolith":
            inputs = {
                "pixel_values": {0: "batch_size"},
                "input_points": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
                "input_labels": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
            }
        else:
            if self.vision_encoder:
                inputs = {"pixel_values": {0: "batch_size"}}
            else:
                inputs = {
                    "image_positional_embeddings": {0: "batch_size"},
                    "image_embeddings": {0: "batch_size"},
                    "input_points": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
                    "input_labels": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
                }
        return inputs

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        if self.variant == "split" and self.vision_encoder:
            return {"image_embeddings": {0: "batch_size"}, "image_positional_embeddings": {0: "batch_size"}}
        else:
            return {
                "iou_scores": {0: "batch_size", 1: "point_batch_size"},
                "pred_masks": {0: "batch_size", 1: "point_batch_size"},
            }


class Pix2StructNormalizedConfig(NormalizedSeq2SeqConfig):
    ENCODER_NUM_LAYERS = "vision_config.num_hidden_layers"
    DECODER_NUM_LAYERS = "text_config.num_layers"
    ENCODER_NUM_ATTENTION_HEADS = "vision_config.num_attention_heads"
    DECODER_NUM_ATTENTION_HEADS = "text_config.num_heads"
    HIDDEN_SIZE = "text_config.hidden_size"  # TODO: Isn't this bug prone?
    VOCAB_SIZE = "text_config.vocab_size"


@register_tasks_manager_onnx(
    "pix2struct",
    *["image-to-text", "image-to-text-with-past", "visual-question-answering", "visual-question-answering-with-past"],
)
class Pix2StructOnnxConfig(OnnxSeq2SeqConfigWithPast):
    NORMALIZED_CONFIG_CLASS = Pix2StructNormalizedConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyTextInputGenerator,
        DummySeq2SeqDecoderTextInputGenerator,
        DummySeq2SeqPastKeyValuesGenerator,
        DummyPix2StructInputGenerator,
    )
    DEFAULT_ONNX_OPSET = 14  # uses 'aten::triu' now, which requires opset 14

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if is_transformers_version("==", "4.46.0") and self._behavior is ConfigBehavior.DECODER:
            logger.error(
                "Found transformers v4.46.0 while trying to export a Pix2Struct model, this specific version of transformers is not supported. 
" "Please upgrade to v4.46.1 or higher, or downgrade your transformers version" ) @property def inputs(self): common_inputs = {} common_inputs["attention_mask"] = {0: "batch_size"} if self._behavior is not ConfigBehavior.DECODER: common_inputs["flattened_patches"] = {0: "batch_size"} if self._behavior is not ConfigBehavior.ENCODER: if self.use_past_in_inputs: common_inputs["decoder_input_ids"] = {0: "batch_size"} else: common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} if self._behavior is ConfigBehavior.DECODER: if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") common_inputs["encoder_outputs"] = {0: "batch_size"} # Contrary to other seq2seq archs as t5 and bart, Pix2Struct DO make use of the decoder_attention_mask input. common_inputs["decoder_attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"} return common_inputs @property def outputs(self) -> Dict[str, Dict[int, str]]: if self._behavior is ConfigBehavior.ENCODER: common_outputs = { "last_hidden_state": {0: "batch_size"} } # The last hidden state dim=1 is constant, no need for it to be dynamic. else: common_outputs = super(OnnxConfigWithPast, self).outputs # Renaming the outputs axes properly. for name, axes_names in common_outputs.items(): if self._behavior is ConfigBehavior.ENCODER or "encoder" in name: sequence_name = "encoder_sequence_length" else: sequence_name = "decoder_sequence_length" new_axes_names = {} for axis_idx, axis_name in axes_names.items(): if "sequence" in axis_name: if self.use_past_in_inputs is False or self.is_merged is True: new_axes_names[axis_idx] = sequence_name else: # Trick to force it since ONNX sometimes infer a dynamic axis where it's not. new_axes_names[axis_idx] = "1" else: new_axes_names[axis_idx] = axis_name common_outputs[name] = new_axes_names if self.use_past: # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output. 
self.add_past_key_values(common_outputs, direction="outputs") return common_outputs @property def torch_to_onnx_input_map(self) -> Dict[str, str]: if self._behavior is ConfigBehavior.DECODER: return { "decoder_input_ids": "input_ids", "encoder_outputs": "encoder_hidden_states", "attention_mask": "encoder_attention_mask", } return {} def generate_dummy_inputs_for_validation( self, reference_model_inputs: Dict[str, Any], onnx_input_names: Optional[List[str]] = None ) -> Dict[str, Any]: if self._behavior is ConfigBehavior.DECODER: reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids") if onnx_input_names is not None: if "encoder_outputs" in reference_model_inputs: if "encoder_hidden_states" in onnx_input_names: reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0] else: reference_model_inputs.pop("encoder_outputs") else: # TODO: remove this else in optimum 2.0 and make onnx_input_names a required argument # Pix2Struct requires encoder_hidden_states as an input for both the without/with past models, # which is different than other architectures that require it only for the without past case reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0] return super().generate_dummy_inputs_for_validation(reference_model_inputs) def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]: dummy_inputs_generators = [] dummy_inputs_generators.append(self.DUMMY_INPUT_GENERATOR_CLASSES[0](self.task, self._normalized_config)) if self._preprocessors is None or len(self._preprocessors) < 2: raise ValueError( f"Preprocessors for pix2struct need to be available for the ONNX export to infer input static shapes. Got: {self._preprocessors}" ) encoder_sequence_length = self._preprocessors[1].image_processor.max_patches # A hack for DummyPix2StructInputGenerator to gain access to the preprocessors. # TODO: we should probably pass preprocessors to all dummy input generators. kwargs["preprocessors"] = self._preprocessors for cls_ in self.DUMMY_INPUT_GENERATOR_CLASSES[1:]: dummy_inputs_generators.append( cls_(self.task, self._normalized_config, encoder_sequence_length=encoder_sequence_length, **kwargs) ) return dummy_inputs_generators def overwrite_shape_and_generate_input( self, dummy_input_gen: "DummyInputGenerator", input_name: str, framework: str, input_shapes: Dict ): if self._preprocessors is None or len(self._preprocessors) < 2: raise ValueError( f"Preprocessors for pix2struct need to be available for the ONNX export to infer input static shapes. Got: {self._preprocessors}" ) # models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name # while models from TextDecoderOnnxConfig use input_ids, hence the check for both if ( self.use_past and self.use_past_in_inputs and self.use_cache_branch is not False and input_name in ["decoder_input_ids", "input_ids"] ): sequence_length = dummy_input_gen.sequence_length # Use a sequence length of 1 when the KV cache is already populated. dummy_input_gen.sequence_length = 1 dummy_input = dummy_input_gen.generate( input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype ) dummy_input_gen.sequence_length = sequence_length elif input_name in ["encoder_outputs", "attention_mask"]: # pix2struct takes inputs whose so-called sequence length is **static** to max_patches, so we do NOT use # the passed sequence_length that behaves as a dynamic shape. 
original_seq_length = dummy_input_gen.sequence_length dummy_input_gen.sequence_length = self._preprocessors[1].image_processor.max_patches dummy_input = dummy_input_gen.generate( input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype ) dummy_input_gen.sequence_length = original_seq_length else: dummy_input = dummy_input_gen.generate( input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype ) return dummy_input @register_tasks_manager_onnx("encoder-decoder", *["text2text-generation", "text2text-generation-with-past"]) class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14. @register_tasks_manager_onnx("patchtst", *["feature-extraction", "time-series-forecasting"]) class PatchTSTOnnxConfig(OnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTimeSeriesForecastingConfig DUMMY_INPUT_GENERATOR_CLASSES = (DummyPatchTSTInputGenerator,) ATOL_FOR_VALIDATION = 1e-4 @property def inputs(self) -> Dict[str, Dict[int, str]]: return {"past_values": {0: "batch_size", 1: "sequence_length"}} @property def outputs(self) -> Dict[str, Dict[int, str]]: if self.task == "feature-extraction": return {"last_hidden_state": {0: "batch_size"}} else: return super().outputs @register_tasks_manager_onnx("patchtsmixer", *["feature-extraction", "time-series-forecasting"]) class PatchTSMixerOnnxConfig(PatchTSTOnnxConfig): pass @register_tasks_manager_onnx("rt_detr", *["object-detection"]) class RTDetrOnnxConfig(ViTOnnxConfig): # Export the operator 'aten::grid_sampler' to ONNX fails under opset 16. # Support for this operator was added in version 16. DEFAULT_ONNX_OPSET = 16 ATOL_FOR_VALIDATION = 1e-5 @property def inputs(self) -> Dict[str, Dict[int, str]]: return { "pixel_values": {0: "batch_size", 2: "height", 3: "width"}, } def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]: min_image_size = int(math.ceil(self._config.num_queries / 32) * 32) if kwargs["height"] < min_image_size: warnings.warn( f"Exporting model with image `height={kwargs['height']}` which is less than " f"minimal {min_image_size}, setting `height` to {min_image_size}." ) kwargs["height"] = min_image_size if kwargs["width"] < min_image_size: warnings.warn( f"Exporting model with image `width={kwargs['width']}` which is less than " f"minimal {min_image_size}, setting `width` to {min_image_size}." 
) kwargs["width"] = min_image_size return super()._create_dummy_input_generator_classes(**kwargs) @register_tasks_manager_onnx("rt_detr_v2", *["object-detection"]) class RTDetrV2OnnxConfig(RTDetrOnnxConfig): pass @register_tasks_manager_onnx("colpali", *["feature-extraction"]) class ColPaliOnnxConfig(GemmaOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyVisionInputGenerator) NORMALIZED_CONFIG_CLASS = NormalizedTextAndVisionConfig.with_args( allow_new=True, text_config="text_config", vision_config="vlm_config.vision_config", vlm_config="vlm_config", ) ATOL_FOR_VALIDATION = 1e-4 VARIANTS = { "vision": "Embedding extraction for image.", "text": "Embedding extraction for text.", } DEFAULT_VARIANT = "vision" @property def inputs(self) -> Dict[str, Dict[int, str]]: dynamic_axis = {0: "batch_size", 1: "sequence_length"} if self.variant == "vision": return { "input_ids": dynamic_axis, "attention_mask": dynamic_axis, "pixel_values": {0: "batch_size"}, } else: return { "input_ids": dynamic_axis, "attention_mask": dynamic_axis, } @property def outputs(self) -> Dict[str, Dict[int, str]]: return { "embeddings": {0: "batch_size", 1: "sequence_length"}, } def generate_dummy_inputs(self, framework: str = "pt", **kwargs): if self.variant == "vision": image_token_index = self._normalized_config.vlm_config.image_token_index num_image_tokens = self._normalized_config.vision_config.num_image_tokens if "sequence_length" in kwargs: kwargs["sequence_length"] += num_image_tokens else: kwargs["sequence_length"] = DEFAULT_DUMMY_SHAPES["sequence_length"] + num_image_tokens dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs) if self.variant == "vision": dummy_inputs["input_ids"][:, :num_image_tokens] = image_token_index return dummy_inputs @register_tasks_manager_onnx("d_fine", *["object-detection"]) class DFineOnnxConfig(RTDetrOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.52.0")
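# Example (illustrative sketch, not part of this module): the configurations registered above are
# resolved through the TasksManager when a model is exported. A typical high-level entry point is
# optimum.exporters.onnx.main_export; the checkpoint, output directory, and exact keyword names
# below are assumptions to be checked against the optimum documentation:
#
#     from optimum.exporters.onnx import main_export
#
#     main_export(
#         "PekingU/rtdetr_r50vd",  # a checkpoint whose model_type maps to the rt_detr config above
#         output="rtdetr_onnx",
#         task="object-detection",
#     )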