# optimum/utils/normalized_config.py
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization configuration classes."""
import functools
from typing import TYPE_CHECKING, Callable, Dict, Type, Union
if TYPE_CHECKING:
from transformers import PretrainedConfig
class NormalizedConfig:
"""
Handles the normalization of [`PretrainedConfig`] attribute names, allowing to access attributes in a general way.
Attributes:
config ([`PretrainedConfig`]):
The config to normalize.
"""
def __init__(self, config: Union["PretrainedConfig", Dict], allow_new: bool = False, **kwargs):
self.config = config
for key, value in kwargs.items():
if allow_new or hasattr(self, key.upper()):
setattr(self, key.upper(), value)
else:
raise AttributeError(
f"{self.__class__} has not attribute {key}. Set allow_new=True to add a new attribute."
)
@classmethod
def with_args(cls, allow_new: bool = False, **kwargs) -> Callable[["PretrainedConfig"], "NormalizedConfig"]:
return functools.partial(cls, allow_new=allow_new, **kwargs)
def __getattr__(self, attr_name):
if attr_name == "config":
return super().__getattr__(attr_name)
try:
attr_name = super().__getattribute__(attr_name.upper())
except AttributeError: # e.g. in the NormalizedTextAndVisionConfig case
pass
attr_name = attr_name.split(".")
leaf_attr_name = attr_name[-1]
config = self.config
for attr in attr_name[:-1]:
config = getattr(config, attr)
attr = getattr(config, leaf_attr_name, None)
# If the attribute was not specified manually, try to fallback on the attribute_map.
if attr is None:
attribute_map = getattr(self.config, "attribute_map", {})
attr = getattr(self.config, attribute_map.get(leaf_attr_name, ""), None)
if attr is None:
raise AttributeError(f'Could not find the attribute named "{leaf_attr_name}" in the normalized config.')
return attr
def has_attribute(self, attr_name):
try:
self.__getattr__(attr_name)
except AttributeError:
return False
return True
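

# A minimal usage sketch (hedged, not part of the module's API): `my_attr` is a
# hypothetical normalized attribute, and 768 is the BertConfig default hidden size,
# assuming `transformers` is installed.
#
#     >>> from transformers import BertConfig
#     >>> normalized = NormalizedConfig(BertConfig(), allow_new=True, my_attr="hidden_size")
#     >>> normalized.my_attr  # resolved through the wrapped config
#     768
#     >>> normalized.has_attribute("something_else")
#     False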


class NormalizedTimeSeriesForecastingConfig(NormalizedConfig):
    NUM_INPUT_CHANNELS = "num_input_channels"
    CONTEXT_LENGTH = "context_length"


class NormalizedTextConfig(NormalizedConfig):
    VOCAB_SIZE = "vocab_size"
    HIDDEN_SIZE = "hidden_size"
    NUM_LAYERS = "num_hidden_layers"
    NUM_ATTENTION_HEADS = "num_attention_heads"
    EOS_TOKEN_ID = "eos_token_id"
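

# E.g. (a sketch, assuming `transformers` is installed; 12 is the BertConfig default):
#
#     >>> from transformers import BertConfig
#     >>> NormalizedTextConfig(BertConfig()).num_layers  # reads config.num_hidden_layers
#     12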


class NormalizedTextConfigWithGQA(NormalizedTextConfig):
    NUM_KEY_VALUE_HEADS = "num_key_value_heads"


class NormalizedSeq2SeqConfig(NormalizedTextConfig):
    ENCODER_NUM_LAYERS = NormalizedTextConfig.NUM_LAYERS
    DECODER_NUM_LAYERS = NormalizedTextConfig.NUM_LAYERS
    ENCODER_NUM_ATTENTION_HEADS = NormalizedTextConfig.NUM_ATTENTION_HEADS
    DECODER_NUM_ATTENTION_HEADS = NormalizedTextConfig.NUM_ATTENTION_HEADS
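

# By default the encoder and decoder share the same attribute names; `with_args` can
# remap them per side (a sketch with illustrative attribute names, mirroring
# SpeechToTextLikeNormalizedTextConfig below):
#
#     >>> MySeq2SeqNormalizedConfig = NormalizedSeq2SeqConfig.with_args(
#     ...     encoder_num_layers="encoder_layers", decoder_num_layers="decoder_layers"
#     ... )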


class NormalizedVisionConfig(NormalizedConfig):
    IMAGE_SIZE = "image_size"
    NUM_CHANNELS = "num_channels"
    INPUT_SIZE = "input_size"


class NormalizedSegformerConfig(NormalizedVisionConfig):
    NUM_ATTENTION_HEADS = "num_attention_heads"
    HIDDEN_SIZE = "hidden_sizes"

    # If the attribute value is a list (Segformer configs store per-stage values),
    # return 0: 0 means "let the optimizer infer the correct value from the model graph".
    def __getattr__(self, attr_name):
        attr_value = super().__getattr__(attr_name)
        if isinstance(attr_value, list):
            attr_value = 0
        return attr_value
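

# E.g. (a sketch, assuming `transformers` is installed; SegformerConfig's
# `hidden_sizes` defaults to a per-stage list):
#
#     >>> from transformers import SegformerConfig
#     >>> NormalizedSegformerConfig(SegformerConfig()).hidden_size  # list-valued, so 0
#     0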


class NormalizedTextAndVisionConfig(NormalizedTextConfig, NormalizedVisionConfig):
    TEXT_CONFIG = None
    VISION_CONFIG = None

    def __getattr__(self, attr_name):
        if self.TEXT_CONFIG is not None and attr_name.upper() in dir(NormalizedTextConfig):
            attr_name = f"{self.TEXT_CONFIG}.{attr_name}"
        elif self.VISION_CONFIG is not None and attr_name.upper() in dir(NormalizedVisionConfig):
            attr_name = f"{self.VISION_CONFIG}.{attr_name}"
        return super().__getattr__(attr_name)


Pix2StructNormalizedTextConfig = NormalizedTextAndVisionConfig.with_args(
    text_config="text_config", vision_config="vision_config"
)
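

# Attribute routing sketch (assumes `transformers` is installed): text attributes
# resolve against the text sub-config, vision attributes against the vision one.
#
#     >>> from transformers import Pix2StructConfig
#     >>> normalized = Pix2StructNormalizedTextConfig(Pix2StructConfig())
#     >>> normalized.vocab_size   # looked up as config.text_config.vocab_size
#     >>> normalized.image_size   # would be looked up as config.vision_config.image_size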


class NormalizedEncoderDecoderConfig(NormalizedConfig):
    ENCODER_NORMALIZED_CONFIG_CLASS = None
    DECODER_NORMALIZED_CONFIG_CLASS = None

    def __getattr__(self, attr_name):
        if self.ENCODER_NORMALIZED_CONFIG_CLASS is not None and attr_name.upper() in dir(
            self.ENCODER_NORMALIZED_CONFIG_CLASS
        ):
            return self.ENCODER_NORMALIZED_CONFIG_CLASS.__getattr__(attr_name)
        if self.DECODER_NORMALIZED_CONFIG_CLASS is not None and attr_name.upper() in dir(
            self.DECODER_NORMALIZED_CONFIG_CLASS
        ):
            return self.DECODER_NORMALIZED_CONFIG_CLASS.__getattr__(attr_name)
        return super().__getattr__(attr_name)
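

# Despite the *_CLASS names, these attributes must hold NormalizedConfig *instances*
# (the bound `__getattr__` call above requires one). A wiring sketch with a
# hypothetical composite `config` exposing `.encoder` and `.decoder` sub-configs:
#
#     >>> normalized = NormalizedEncoderDecoderConfig(config)
#     >>> normalized.ENCODER_NORMALIZED_CONFIG_CLASS = NormalizedTextConfig(config.encoder)
#     >>> normalized.DECODER_NORMALIZED_CONFIG_CLASS = NormalizedTextConfig(config.decoder)
#     >>> normalized.hidden_size  # delegated to the encoder's normalized config first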


# TODO: this config is bug-prone, as `encoder_attention_heads` and `decoder_attention_heads` may be different
BartLikeNormalizedTextConfig = NormalizedTextConfig.with_args(
    num_attention_heads="encoder_attention_heads",
    hidden_size="d_model",
)

GPT2LikeNormalizedTextConfig = NormalizedTextConfig.with_args(num_attention_heads="n_head", hidden_size="n_embd")

T5LikeNormalizedTextConfig = NormalizedTextConfig.with_args(
    num_attention_heads="num_heads",
    hidden_size="d_model",
)

MPTNormalizedTextConfig = NormalizedTextConfig.with_args(
    num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers"
)

GPTBigCodeNormalizedTextConfig = NormalizedTextConfig.with_args(
    num_attention_heads="n_head", hidden_size="n_embd", num_layers="n_layer"
)

WhisperLikeNormalizedTextConfig = NormalizedTextConfig.with_args(
    hidden_size="d_model",
)

TrOCRLikeNormalizedTextConfig = NormalizedTextConfig.with_args(
    num_layers="decoder_layers",
    num_attention_heads="decoder_attention_heads",
    hidden_size="hidden_size",
)

SpeechToTextLikeNormalizedTextConfig = NormalizedSeq2SeqConfig.with_args(
    decoder_num_layers="decoder_layers",
    num_layers="decoder_layers",
    input_features_per_channel="input_feat_per_channel",
    allow_new=True,
)
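

# Note: `allow_new=True` is required above because `input_features_per_channel` is not
# declared on NormalizedSeq2SeqConfig; the factory attaches it as a brand-new attribute.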


class NormalizedConfigManager:
    """
    A class that contains all the information needed by ONNX Runtime optimization for a given model type.

    Attributes:
        _conf (`Dict[str, Type[NormalizedConfig]]`):
            A dictionary mapping each supported model type to its normalized config class (or to a
            `with_args` factory producing one).
    """
"""
TODO: missing normalized configs (currently not useful)
['beit',
'clip',
'convbert',
'convnext',
'convnextv2',
'data2vec-text',
'data2vec-vision',
'detr',
'flaubert',
'groupvit',
'hiera',
'ibert',
'layoutlm',
'layoutlmv3',
'levit',
'mobilebert',
'mobilevit',
'owlv2',
'owlvit',
'perceiver',
'roformer',
'segformer',
'siglip',
'squeezebert',
'table-transformer',
"""

    # Contribution note: Please add new models in alphabetical order
    _conf = {
        "albert": NormalizedTextConfig,
        "bart": BartLikeNormalizedTextConfig,
        "bert": NormalizedTextConfig,
        "big_bird": NormalizedTextConfig,
        "bigbird_pegasus": BartLikeNormalizedTextConfig,
        "blenderbot": BartLikeNormalizedTextConfig,
        "blenderbot-small": BartLikeNormalizedTextConfig,
        "bloom": NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head"),
        "camembert": NormalizedTextConfig,
        "codegen": GPT2LikeNormalizedTextConfig,
        "cvt": NormalizedVisionConfig,
        "deberta": NormalizedTextConfig,
        "deberta-v2": NormalizedTextConfig,
        "deit": NormalizedVisionConfig,
        "dinov2": NormalizedVisionConfig,
        "distilbert": NormalizedTextConfig.with_args(num_attention_heads="n_heads", hidden_size="dim"),
        "donut-swin": NormalizedVisionConfig,
        "electra": NormalizedTextConfig,
        "encoder-decoder": NormalizedEncoderDecoderConfig,
        "falcon": NormalizedTextConfig,
        "gemma": NormalizedTextConfigWithGQA,
        "gpt2": GPT2LikeNormalizedTextConfig,
        "gpt_bigcode": GPTBigCodeNormalizedTextConfig,
        "gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
        "gpt_neox": NormalizedTextConfig,
        "gptj": GPT2LikeNormalizedTextConfig,
        "granite": NormalizedTextConfigWithGQA,
        "imagegpt": GPT2LikeNormalizedTextConfig,
        "internlm2": NormalizedTextConfigWithGQA,
        "llama": NormalizedTextConfigWithGQA,
        "longt5": T5LikeNormalizedTextConfig,
        "m2m_100": BartLikeNormalizedTextConfig,
        "marian": BartLikeNormalizedTextConfig,
        "markuplm": NormalizedTextConfig,
        "mbart": BartLikeNormalizedTextConfig,
        "mistral": NormalizedTextConfigWithGQA,
        "mixtral": NormalizedTextConfigWithGQA,
        "modernbert": NormalizedTextConfig,
        "mpnet": NormalizedTextConfig,
        "mpt": MPTNormalizedTextConfig,
        "mt5": T5LikeNormalizedTextConfig,
        "nystromformer": NormalizedTextConfig,
        "olmo": NormalizedTextConfig,
        "olmo2": NormalizedTextConfig,
        "opt": NormalizedTextConfig,
        "pegasus": BartLikeNormalizedTextConfig,
        "phi": NormalizedTextConfig,
        "phi3": NormalizedTextConfigWithGQA,
        "pix2struct": Pix2StructNormalizedTextConfig,
        "poolformer": NormalizedVisionConfig,
        "qwen2": NormalizedTextConfig,
        "qwen3": NormalizedTextConfig,
        "qwen3_moe": NormalizedTextConfig,
        "regnet": NormalizedVisionConfig,
        "resnet": NormalizedVisionConfig,
        "roberta": NormalizedTextConfig,
        "segformer": NormalizedSegformerConfig,
        "speech_to_text": SpeechToTextLikeNormalizedTextConfig,
        "splinter": NormalizedTextConfig,
        "t5": T5LikeNormalizedTextConfig,
        "trocr": TrOCRLikeNormalizedTextConfig,
        "vision-encoder-decoder": NormalizedEncoderDecoderConfig,
        "vit": NormalizedVisionConfig,
        "whisper": WhisperLikeNormalizedTextConfig,
        "xlm-roberta": NormalizedTextConfig,
        "yolos": NormalizedVisionConfig,
    }

    @classmethod
    def check_supported_model(cls, model_type: str):
        if model_type not in cls._conf:
            model_types = ", ".join(cls._conf.keys())
            raise KeyError(
                f"{model_type} model type is not supported yet in NormalizedConfig. Only {model_types} are supported. "
                f"If you want to support {model_type}, please propose a PR or open up an issue."
            )

    @classmethod
    def get_normalized_config_class(cls, model_type: str) -> Type:
        cls.check_supported_model(model_type)
        return cls._conf[model_type]
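

# An end-to-end sketch (assumes `transformers` is installed and the checkpoint can be
# fetched; 12 is the stock gpt2 value of `n_head`):
#
#     >>> from transformers import AutoConfig
#     >>> config = AutoConfig.from_pretrained("gpt2")
#     >>> normalized_cls = NormalizedConfigManager.get_normalized_config_class(config.model_type)
#     >>> normalized_cls(config).num_attention_heads  # reads config.n_head
#     12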