optimum/utils/normalized_config.py

# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization configuration classes."""

import functools
from typing import TYPE_CHECKING, Callable, Dict, Type, Union


if TYPE_CHECKING:
    from transformers import PretrainedConfig


class NormalizedConfig:
    """
    Handles the normalization of [`PretrainedConfig`] attribute names, allowing attributes to be accessed in a
    general way.

    Attributes:
        config ([`PretrainedConfig`]):
            The config to normalize.
    """

    def __init__(self, config: Union["PretrainedConfig", Dict], allow_new: bool = False, **kwargs):
        self.config = config
        for key, value in kwargs.items():
            if allow_new or hasattr(self, key.upper()):
                setattr(self, key.upper(), value)
            else:
                raise AttributeError(
                    f"{self.__class__} has no attribute {key}. Set allow_new=True to add a new attribute."
                )

    @classmethod
    def with_args(cls, allow_new: bool = False, **kwargs) -> Callable[["PretrainedConfig"], "NormalizedConfig"]:
        return functools.partial(cls, allow_new=allow_new, **kwargs)

    def __getattr__(self, attr_name):
        if attr_name == "config":
            return super().__getattr__(attr_name)

        try:
            attr_name = super().__getattribute__(attr_name.upper())
        except AttributeError:
            # e.g. in the NormalizedTextAndVisionConfig case
            pass

        attr_name = attr_name.split(".")
        leaf_attr_name = attr_name[-1]
        config = self.config
        for attr in attr_name[:-1]:
            config = getattr(config, attr)

        attr = getattr(config, leaf_attr_name, None)

        # If the attribute was not specified manually, try to fall back on the attribute_map.
        if attr is None:
            attribute_map = getattr(self.config, "attribute_map", {})
            attr = getattr(self.config, attribute_map.get(leaf_attr_name, ""), None)

        if attr is None:
            raise AttributeError(f'Could not find the attribute named "{leaf_attr_name}" in the normalized config.')

        return attr

    def has_attribute(self, attr_name):
        try:
            self.__getattr__(attr_name)
        except AttributeError:
            return False
        return True
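
# Usage sketch (illustrative only, kept as a comment so nothing runs at import time).
# A `types.SimpleNamespace` stands in for a real `PretrainedConfig`, since attributes
# are resolved with plain `getattr` calls; the attribute names are examples:
#
#     >>> from types import SimpleNamespace
#     >>> dummy = SimpleNamespace(num_hidden_layers=12, hidden_size=768)
#     >>> normalized = NormalizedConfig(dummy, num_layers="num_hidden_layers", allow_new=True)
#     >>> normalized.num_layers  # resolved through the NUM_LAYERS mapping registered above
#     12
#     >>> normalized.hidden_size  # no mapping registered: falls through to the raw attribute name
#     768
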
class NormalizedTimeSeriesForecastingConfig(NormalizedConfig):
    NUM_INPUT_CHANNELS = "num_input_channels"
    CONTEXT_LENGTH = "context_length"


class NormalizedTextConfig(NormalizedConfig):
    VOCAB_SIZE = "vocab_size"
    HIDDEN_SIZE = "hidden_size"
    NUM_LAYERS = "num_hidden_layers"
    NUM_ATTENTION_HEADS = "num_attention_heads"
    EOS_TOKEN_ID = "eos_token_id"


class NormalizedTextConfigWithGQA(NormalizedTextConfig):
    NUM_KEY_VALUE_HEADS = "num_key_value_heads"


class NormalizedSeq2SeqConfig(NormalizedTextConfig):
    ENCODER_NUM_LAYERS = NormalizedTextConfig.NUM_LAYERS
    DECODER_NUM_LAYERS = NormalizedTextConfig.NUM_LAYERS
    ENCODER_NUM_ATTENTION_HEADS = NormalizedTextConfig.NUM_ATTENTION_HEADS
    DECODER_NUM_ATTENTION_HEADS = NormalizedTextConfig.NUM_ATTENTION_HEADS


class NormalizedVisionConfig(NormalizedConfig):
    IMAGE_SIZE = "image_size"
    NUM_CHANNELS = "num_channels"
    INPUT_SIZE = "input_size"


class NormalizedSegformerConfig(NormalizedVisionConfig):
    NUM_ATTENTION_HEADS = "num_attention_heads"
    HIDDEN_SIZE = "hidden_sizes"

    # If the attribute is a list, return 0.
    # 0 means: let the optimizer infer the correct value based on the model graph.
    def __getattr__(self, attr_name):
        attr_value = super().__getattr__(attr_name)
        if isinstance(attr_value, list):
            attr_value = 0
        return attr_value


class NormalizedTextAndVisionConfig(NormalizedTextConfig, NormalizedVisionConfig):
    TEXT_CONFIG = None
    VISION_CONFIG = None

    def __getattr__(self, attr_name):
        if self.TEXT_CONFIG is not None and attr_name.upper() in dir(NormalizedTextConfig):
            attr_name = f"{self.TEXT_CONFIG}.{attr_name}"
        elif self.VISION_CONFIG is not None and attr_name.upper() in dir(NormalizedVisionConfig):
            attr_name = f"{self.VISION_CONFIG}.{attr_name}"
        return super().__getattr__(attr_name)


Pix2StructNormalizedTextConfig = NormalizedTextAndVisionConfig.with_args(
    text_config="text_config", vision_config="vision_config"
)


class NormalizedEncoderDecoderConfig(NormalizedConfig):
    ENCODER_NORMALIZED_CONFIG_CLASS = None
    DECODER_NORMALIZED_CONFIG_CLASS = None

    def __getattr__(self, attr_name):
        if self.ENCODER_NORMALIZED_CONFIG_CLASS is not None and attr_name.upper() in dir(
            self.ENCODER_NORMALIZED_CONFIG_CLASS
        ):
            return self.ENCODER_NORMALIZED_CONFIG_CLASS.__getattr__(attr_name)
        if self.DECODER_NORMALIZED_CONFIG_CLASS is not None and attr_name.upper() in dir(
            self.DECODER_NORMALIZED_CONFIG_CLASS
        ):
            return self.DECODER_NORMALIZED_CONFIG_CLASS.__getattr__(attr_name)
        return super().__getattr__(attr_name)
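
# Delegation sketch (illustrative only; nested `SimpleNamespace` objects stand in for a
# multi-modal `PretrainedConfig` exposing `text_config`/`vision_config` sub-configs):
#
#     >>> from types import SimpleNamespace
#     >>> dummy = SimpleNamespace(
#     ...     text_config=SimpleNamespace(hidden_size=256),
#     ...     vision_config=SimpleNamespace(image_size=224),
#     ... )
#     >>> normalized = Pix2StructNormalizedTextConfig(dummy)
#     >>> normalized.hidden_size  # text attribute, routed to dummy.text_config.hidden_size
#     256
#     >>> normalized.image_size  # vision attribute, routed to dummy.vision_config.image_size
#     224
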
num_attention_heads="n_head", hidden_size="n_embd", num_layers="n_layer" ) WhisperLikeNormalizedTextConfig = NormalizedTextConfig.with_args( hidden_size="d_model", ) TrOCRLikeNormalizedTextConfig = NormalizedTextConfig.with_args( num_layers="decoder_layers", num_attention_heads="decoder_attention_heads", hidden_size="hidden_size", ) SpeechToTextLikeNormalizedTextConfig = NormalizedSeq2SeqConfig.with_args( decoder_num_layers="decoder_layers", num_layers="decoder_layers", input_features_per_channel="input_feat_per_channel", allow_new=True, ) class NormalizedConfigManager: """ A class that contains all the information needed by ONNX Runtime optimization for a given model type. Attributes: _conf (`Dict[str, tuple]`): A dictionary mapping each supported model type to a tuple containing the number of attention heads and the hidden size model config attribute names as well as the corresponding ONNX Runtime model type. """ """ TODO: missing normalized configs (currently not useful) ['beit', 'clip', 'convbert', 'convnext', 'convnextv2', 'data2vec-text', 'data2vec-vision', 'detr', 'flaubert', 'groupvit', 'hiera', 'ibert', 'layoutlm', 'layoutlmv3', 'levit', 'mobilebert', 'mobilevit', 'owlv2', 'owlvit', 'perceiver', 'roformer', 'segformer', 'siglip', 'squeezebert', 'table-transformer', """ # Contribution note: Please add new models in alphabetical order _conf = { "albert": NormalizedTextConfig, "bart": BartLikeNormalizedTextConfig, "bert": NormalizedTextConfig, "big_bird": NormalizedTextConfig, "bigbird_pegasus": BartLikeNormalizedTextConfig, "blenderbot": BartLikeNormalizedTextConfig, "blenderbot-small": BartLikeNormalizedTextConfig, "bloom": NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head"), "falcon": NormalizedTextConfig, "camembert": NormalizedTextConfig, "codegen": GPT2LikeNormalizedTextConfig, "cvt": NormalizedVisionConfig, "deberta": NormalizedTextConfig, "deberta-v2": NormalizedTextConfig, "deit": NormalizedVisionConfig, "dinov2": NormalizedVisionConfig, "distilbert": NormalizedTextConfig.with_args(num_attention_heads="n_heads", hidden_size="dim"), "donut-swin": NormalizedVisionConfig, "electra": NormalizedTextConfig, "encoder-decoder": NormalizedEncoderDecoderConfig, "gemma": NormalizedTextConfigWithGQA, "gpt2": GPT2LikeNormalizedTextConfig, "gpt_bigcode": GPTBigCodeNormalizedTextConfig, "gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"), "gpt_neox": NormalizedTextConfig, "gptj": GPT2LikeNormalizedTextConfig, "imagegpt": GPT2LikeNormalizedTextConfig, "internlm2": NormalizedTextConfigWithGQA, "llama": NormalizedTextConfigWithGQA, "longt5": T5LikeNormalizedTextConfig, "marian": BartLikeNormalizedTextConfig, "markuplm": NormalizedTextConfig, "mbart": BartLikeNormalizedTextConfig, "mistral": NormalizedTextConfigWithGQA, "mixtral": NormalizedTextConfigWithGQA, "modernbert": NormalizedTextConfig, "mpnet": NormalizedTextConfig, "mpt": MPTNormalizedTextConfig, "mt5": T5LikeNormalizedTextConfig, "m2m_100": BartLikeNormalizedTextConfig, "nystromformer": NormalizedTextConfig, "olmo": NormalizedTextConfig, "olmo2": NormalizedTextConfig, "opt": NormalizedTextConfig, "pegasus": BartLikeNormalizedTextConfig, "pix2struct": Pix2StructNormalizedTextConfig, "phi": NormalizedTextConfig, "phi3": NormalizedTextConfigWithGQA, "poolformer": NormalizedVisionConfig, "regnet": NormalizedVisionConfig, "resnet": NormalizedVisionConfig, "roberta": NormalizedTextConfig, "segformer": NormalizedSegformerConfig, "speech_to_text": SpeechToTextLikeNormalizedTextConfig, 
"splinter": NormalizedTextConfig, "t5": T5LikeNormalizedTextConfig, "trocr": TrOCRLikeNormalizedTextConfig, "vision-encoder-decoder": NormalizedEncoderDecoderConfig, "vit": NormalizedVisionConfig, "whisper": WhisperLikeNormalizedTextConfig, "xlm-roberta": NormalizedTextConfig, "yolos": NormalizedVisionConfig, "qwen2": NormalizedTextConfig, "qwen3": NormalizedTextConfig, "qwen3_moe": NormalizedTextConfig, "granite": NormalizedTextConfigWithGQA, } @classmethod def check_supported_model(cls, model_type: str): if model_type not in cls._conf: model_types = ", ".join(cls._conf.keys()) raise KeyError( f"{model_type} model type is not supported yet in NormalizedConfig. Only {model_types} are supported. " f"If you want to support {model_type} please propose a PR or open up an issue." ) @classmethod def get_normalized_config_class(cls, model_type: str) -> Type: cls.check_supported_model(model_type) return cls._conf[model_type]