optimum_benchmark/backends/transformers_utils.py (175 lines of code) (raw):

from contextlib import contextmanager
from typing import Any, Dict, Optional, Type, Union

import torch
import transformers
from torch import Tensor
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoModel,
    AutoProcessor,
    AutoTokenizer,
    FeatureExtractionMixin,
    GenerationConfig,
    ImageProcessingMixin,
    PretrainedConfig,
    ProcessorMixin,
    SpecialTokensMixin,
)

from ..task_utils import TASKS_TO_AUTO_MODEL_CLASS_NAMES, map_from_synonym_task


def get_transformers_auto_model_class_for_task(task: str, model_type: Optional[str] = None) -> Type["AutoModel"]:
    """Resolve the ``transformers`` auto-model class to use for ``task``.

    ``TASKS_TO_AUTO_MODEL_CLASS_NAMES`` maps a task either to a single auto-class name or to an
    iterable of candidate names. In the latter case, the first candidate whose model mapping
    contains ``model_type`` is returned.

    Args:
        task: Task name; synonyms are normalized with ``map_from_synonym_task``.
        model_type: Model type to look up; required when the task has several candidate classes.

    Returns:
        The auto-model class, looked up by name on the ``transformers`` module.

    Raises:
        ValueError: If the task is unknown, if ``model_type`` is required but missing, or if no
            candidate auto-class supports ``model_type``.
    """
    task = map_from_synonym_task(task)

    if task not in TASKS_TO_AUTO_MODEL_CLASS_NAMES:
        raise ValueError(f"Task {task} not supported for transformers")

    automodel_class_names = TASKS_TO_AUTO_MODEL_CLASS_NAMES[task]

    if isinstance(automodel_class_names, str):
        return getattr(transformers, automodel_class_names)

    if model_type is None:
        raise ValueError(f"Task {task} requires a model_type to be specified")

    for automodel_class_name in automodel_class_names:
        automodel_class = getattr(transformers, automodel_class_name)
        # relies on the private `_model_mapping` attribute of auto classes to test model_type support
        if model_type in automodel_class._model_mapping._model_mapping:
            return automodel_class

    raise ValueError(f"Task {task} not supported for model type {model_type}")


# Any of the pre-processing artifacts that `get_transformers_pretrained_processor` may return.
PretrainedProcessor = Union["FeatureExtractionMixin", "ImageProcessingMixin", "SpecialTokensMixin", "ProcessorMixin"]


def get_transformers_pretrained_config(model: str, **kwargs) -> "PretrainedConfig":
    """Load the model config of ``model`` with ``AutoConfig.from_pretrained`` (errors propagate)."""
    return AutoConfig.from_pretrained(model, **kwargs)


def get_transformers_generation_config(model: str, **kwargs) -> Optional["GenerationConfig"]:
    """Load the generation config of ``model``, falling back to a default ``GenerationConfig``.

    Loading is best-effort. Any exception raised by ``GenerationConfig.from_pretrained`` is
    swallowed, and a default-constructed ``GenerationConfig`` is returned instead.
    """
    try:
        return GenerationConfig.from_pretrained(model, **kwargs)
    except Exception:
        return GenerationConfig()


def get_transformers_pretrained_processor(model: str, **kwargs) -> Optional["PretrainedProcessor"]:
    """Load the first pre-processing artifact available for ``model``.

    The loaders are tried in this order: ``AutoProcessor``, ``AutoFeatureExtractor``,
    ``AutoImageProcessor``, ``AutoTokenizer``. The first one that loads without raising is
    returned. A processor sometimes carries input-shape information (e.g. an image ``size``)
    that is not available in the model config.

    Returns:
        The loaded artifact, or ``None`` if every loader raised.
    """
    for loader in (AutoProcessor, AutoFeatureExtractor, AutoImageProcessor, AutoTokenizer):
        try:
            return loader.from_pretrained(model, **kwargs)
        except Exception:
            # best-effort: fall through to the next loader
            continue

    return None


def get_flat_dict(d: Dict[str, Any], keep_top_level_dicts: bool = False) -> Dict[str, Any]:
    """Flatten nested dicts into a single level.

    The leaves of nested dicts are merged into the result under their own keys. If the same key
    is seen more than once, the value visited last wins.

    Args:
        d: The dict to flatten.
        keep_top_level_dicts: If ``True``, dict values of ``d``'s own keys are also kept as-is
            under their key, in addition to having their leaves merged. Dicts nested deeper are
            always flattened away. The default (``False``) discards every dict-valued key.

    Returns:
        A new flat dict; ``d`` is not modified.
    """
    flat_dict = {}
    for k, v in d.items():
        if isinstance(v, dict):
            if keep_top_level_dicts:
                flat_dict[k] = v
            flat_dict.update(get_flat_dict(v))
        else:
            flat_dict[k] = v
    return flat_dict


# Value types retained when turning an artifact into a flat dict.
_SERIALIZABLE_VALUE_TYPES = (int, str, float, bool, list, tuple, dict)


def _filter_serializable(items: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the entries whose value is an instance of one of ``_SERIALIZABLE_VALUE_TYPES``."""
    return {k: v for k, v in items.items() if isinstance(v, _SERIALIZABLE_VALUE_TYPES)}


def get_flat_artifact_dict(
    artifact: Union["PretrainedConfig", "PretrainedProcessor"], keep_top_level_dicts: bool = False
) -> Dict[str, Any]:
    """Turn a config or processor into a flat dict of its plain-valued attributes.

    Handling depends on the artifact type:

    - ``ProcessorMixin``: its own ``__dict__`` entries are included, followed by the flattened
      dicts of every sub-artifact named in ``artifact.attributes``. Missing or ``None``
      sub-artifacts are skipped.
    - Objects with ``to_dict``: the output of ``to_dict()`` is used.
    - Anything else: ``__dict__`` is used.

    Only values of ``_SERIALIZABLE_VALUE_TYPES`` are kept, and the result is flattened with
    ``get_flat_dict``.

    Args:
        artifact: The config or processor to flatten.
        keep_top_level_dicts: Forwarded to ``get_flat_dict``.
    """
    artifact_dict = {}

    if isinstance(artifact, ProcessorMixin):
        artifact_dict.update(_filter_serializable(artifact.__dict__))
        for attribute in artifact.attributes:
            sub_artifact = getattr(artifact, attribute, None)
            # a None sub-artifact has no `__dict__`/`to_dict` and would crash the recursion
            if sub_artifact is not None:
                artifact_dict.update(get_flat_artifact_dict(sub_artifact, keep_top_level_dicts))
    elif hasattr(artifact, "to_dict"):
        artifact_dict.update(_filter_serializable(artifact.to_dict()))
    else:
        artifact_dict.update(_filter_serializable(artifact.__dict__))

    return get_flat_dict(artifact_dict, keep_top_level_dicts=keep_top_level_dicts)


def _image_size_to_shapes(image_size: Any) -> Dict[str, Any]:
    """Convert an ``image_size``/``size`` artifact value into ``height``/``width`` shapes.

    Supported forms:

    - A scalar: a square image.
    - A non-empty sequence ``(height, width)``: a single-element sequence means a square image.
    - A dict with explicit ``height`` and ``width`` keys.
    - A dict with one or two values, read in insertion order.

    Any other value (including ``None``) yields an empty dict.
    """
    if isinstance(image_size, (int, float)):
        return {"height": image_size, "width": image_size}

    if isinstance(image_size, (list, tuple)) and len(image_size) > 0:
        width = image_size[1] if len(image_size) > 1 else image_size[0]
        return {"height": image_size[0], "width": width}

    if isinstance(image_size, dict):
        if "height" in image_size and "width" in image_size:
            return {"height": image_size["height"], "width": image_size["width"]}
        values = list(image_size.values())
        if len(values) == 2:
            return {"height": values[0], "width": values[1]}
        if len(values) == 1:
            return {"height": values[0], "width": values[0]}

    return {}


def extract_transformers_shapes_from_artifacts(
    config: Optional["PretrainedConfig"] = None,
    processor: Optional["PretrainedProcessor"] = None,
) -> Dict[str, Any]:
    """Collect input-shape hints from a model config and/or a processor.

    Both artifacts are flattened with ``get_flat_artifact_dict`` and merged. Processor values
    override config values on key collisions. Well-known keys are then mapped onto shape names
    such as ``vocab_size``, ``max_position_embeddings``, ``num_channels``, ``height``, ``width``
    and ``num_labels``. A shape is only reported when its source key is present.
    """
    flat_artifacts_dict = {}
    # top-level dict values (e.g. `id2label`, a processor's `size`) must survive flattening to be read below
    if config is not None:
        flat_artifacts_dict.update(get_flat_artifact_dict(config, keep_top_level_dicts=True))
    if processor is not None:
        flat_artifacts_dict.update(get_flat_artifact_dict(processor, keep_top_level_dicts=True))

    shapes = {}

    # text input
    if "vocab_size" in flat_artifacts_dict:
        shapes["vocab_size"] = flat_artifacts_dict["vocab_size"]

    if "type_vocab_size" in flat_artifacts_dict:
        shapes["type_vocab_size"] = flat_artifacts_dict["type_vocab_size"]

    if "max_position_embeddings" in flat_artifacts_dict:
        shapes["max_position_embeddings"] = flat_artifacts_dict["max_position_embeddings"]
    elif "n_positions" in flat_artifacts_dict:
        shapes["max_position_embeddings"] = flat_artifacts_dict["n_positions"]

    # image input
    if "num_channels" in flat_artifacts_dict:
        shapes["num_channels"] = flat_artifacts_dict["num_channels"]

    if "image_size" in flat_artifacts_dict:
        image_size = flat_artifacts_dict["image_size"]
    elif "size" in flat_artifacts_dict:
        image_size = flat_artifacts_dict["size"]
    else:
        image_size = None

    shapes.update(_image_size_to_shapes(image_size))

    # `input_size` is read as (num_channels, height, width) and overrides the values found above;
    # other value types (e.g. a scalar) are not indexable that way and are ignored
    input_size = flat_artifacts_dict.get("input_size")
    if isinstance(input_size, (list, tuple)) and len(input_size) == 3:
        shapes["num_channels"], shapes["height"], shapes["width"] = input_size

    # classification labels
    if "id2label" in flat_artifacts_dict:
        shapes["num_labels"] = len(flat_artifacts_dict["id2label"])
    elif "num_classes" in flat_artifacts_dict:
        shapes["num_labels"] = flat_artifacts_dict["num_classes"]

    # object detection labels
    if "num_queries" in flat_artifacts_dict:
        shapes["num_queries"] = flat_artifacts_dict["num_queries"]

    # image-text input
    if "patch_size" in flat_artifacts_dict:
        shapes["patch_size"] = flat_artifacts_dict["patch_size"]

    if "in_chans" in flat_artifacts_dict:
        shapes["num_channels"] = flat_artifacts_dict["in_chans"]

    for key in ("image_seq_len", "image_token_id", "spatial_merge_size", "do_image_splitting", "temporal_patch_size"):
        if key in flat_artifacts_dict:
            shapes[key] = flat_artifacts_dict[key]

    return shapes


# Original `torch.nn.init` functions, captured at import time so `fast_weights_init` can restore them.
TORCH_INIT_FUNCTIONS = {
    "normal_": torch.nn.init.normal_,
    "uniform_": torch.nn.init.uniform_,
    "trunc_normal_": torch.nn.init.trunc_normal_,
    "xavier_normal_": torch.nn.init.xavier_normal_,
    "xavier_uniform_": torch.nn.init.xavier_uniform_,
    "kaiming_normal_": torch.nn.init.kaiming_normal_,
    "kaiming_uniform_": torch.nn.init.kaiming_uniform_,
    "normal": torch.nn.init.normal,
    "uniform": torch.nn.init.uniform,
    "xavier_normal": torch.nn.init.xavier_normal,
    "xavier_uniform": torch.nn.init.xavier_uniform,
    "kaiming_normal": torch.nn.init.kaiming_normal,
    "kaiming_uniform": torch.nn.init.kaiming_uniform,
}


def fast_random_tensor(tensor: "Tensor", *args: Any, **kwargs: Any) -> "Tensor":
    """Fill ``tensor`` in place with ``torch.nn.init.uniform_`` defaults, ignoring all extra arguments.

    The extra arguments are accepted so this function can stand in for any init function in
    ``TORCH_INIT_FUNCTIONS``.
    """
    return torch.nn.init.uniform_(tensor)


@contextmanager
def fast_weights_init():
    """Temporarily replace ``torch.nn.init`` functions with ``fast_random_tensor``.

    Every entry of ``TORCH_INIT_FUNCTIONS`` except ``uniform_`` is patched on ``torch.nn.init``.
    ``uniform_`` is left untouched because ``fast_random_tensor`` calls it, and patching it would
    recurse. The originals from ``TORCH_INIT_FUNCTIONS`` are restored on exit, even if the body
    (or the patching itself) raises.
    """
    patched_names = [name for name in TORCH_INIT_FUNCTIONS if name != "uniform_"]

    try:
        for name in patched_names:
            setattr(torch.nn.init, name, fast_random_tensor)
        yield
    finally:
        for name in patched_names:
            setattr(torch.nn.init, name, TORCH_INIT_FUNCTIONS[name])