# optimum_benchmark/backends/base.py
import gc
from abc import ABC
from collections import OrderedDict
from logging import getLogger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, ClassVar, Dict, Generic, Optional

import datasets.utils.logging as datasets_logging
import transformers.utils.logging as transformers_logging
from safetensors.torch import save_model
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel, TrainerState, set_seed

from ..hub_utils import HF_API
from ..import_utils import is_torch_available
from ..task_utils import TEXT_GENERATION_TASKS
from .config import BackendConfigT
from .diffusers_utils import (
    extract_diffusers_shapes_from_model,
    get_diffusers_auto_pipeline_class_for_task,
    get_diffusers_pretrained_config,
)
from .timm_utils import extract_timm_shapes_from_config, get_timm_model_creator, get_timm_pretrained_config
from .transformers_utils import (
    PretrainedProcessor,
    extract_transformers_shapes_from_artifacts,
    fast_weights_init,
    get_transformers_auto_model_class_for_task,
    get_transformers_generation_config,
    get_transformers_pretrained_config,
    get_transformers_pretrained_processor,
)

if is_torch_available():
    import torch

datasets_logging.set_verbosity_error()
transformers_logging.set_verbosity_error()


class Backend(Generic[BackendConfigT], ABC):
    """
    Abstract base class for benchmark backends. It resolves library-specific
    artifacts (pretrained config, processor, auto-loader class, model shapes)
    at construction time and defines the benchmarking entry points
    (forward, prefill, generate, call, train) that subclasses implement.
    """

    NAME: ClassVar[str]

    tmpdir: TemporaryDirectory
    model_shapes: Dict[str, int]
    no_weights_model_path: Optional[Path]

    config: BackendConfigT
    pretrained_model: PreTrainedModel
    pretrained_config: Optional[PretrainedConfig]
    generation_config: Optional[GenerationConfig]
    pretrained_processor: Optional[PretrainedProcessor]

    def __init__(self, config: BackendConfigT):
        self.config = config
        self.logger = getLogger(self.NAME)
        self.logger.info(f"Allocating {self.NAME} backend")

        self.logger.info(f"\t+ Seeding backend with {self.config.seed}")
        self.seed()

        # each model library resolves its own pretrained config, processor,
        # auto-loader class and model shapes
        if self.config.library == "diffusers":
            self.logger.info("\t+ Benchmarking a Diffusers pipeline")
            self.pretrained_config = get_diffusers_pretrained_config(self.config.model, **self.config.model_kwargs)
            self.automodel_loader = get_diffusers_auto_pipeline_class_for_task(self.config.task)
            self.model_shapes = extract_diffusers_shapes_from_model()
            self.pretrained_processor = None
            self.generation_config = None
        elif self.config.library == "timm":
            self.logger.info("\t+ Benchmarking a Timm model")
            self.pretrained_config = get_timm_pretrained_config(self.config.model)
            self.model_shapes = extract_timm_shapes_from_config(self.pretrained_config)
            self.automodel_loader = get_timm_model_creator()
            self.pretrained_processor = None
            self.generation_config = None
        elif self.config.library == "llama_cpp":
            self.logger.info("\t+ Benchmarking a LlamaCpp model")
            self.pretrained_processor = None
            self.pretrained_config = None
            self.generation_config = None
            self.automodel_loader = None
            self.model_shapes = {}
        else:
            self.logger.info("\t+ Benchmarking a Transformers model")
            self.automodel_loader = get_transformers_auto_model_class_for_task(self.config.task, self.config.model_type)
            self.generation_config = get_transformers_generation_config(self.config.model, **self.config.model_kwargs)
            self.pretrained_config = get_transformers_pretrained_config(self.config.model, **self.config.model_kwargs)
            self.pretrained_processor = get_transformers_pretrained_processor(
                self.config.processor, **self.config.processor_kwargs
            )
            self.model_shapes = extract_transformers_shapes_from_artifacts(
                self.pretrained_config, self.pretrained_processor
            )
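
    # Illustrative note (an assumption, keys vary by model): for a Transformers
    # text model, `model_shapes` carries config/processor-derived dimensions that
    # downstream input generators use to build dummy inputs, e.g. something like:
    #
    #     {"vocab_size": 50257, "hidden_size": 768, "num_attention_heads": 12}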

    def seed(self) -> None:
        set_seed(self.config.seed)

    def download_pretrained_model(self) -> None:
        model_snapshot_folder = HF_API.snapshot_download(
            self.config.model,
            revision=self.config.model_kwargs.get("revision", None),
            cache_dir=self.config.model_kwargs.get("cache_dir", None),
            force_download=self.config.model_kwargs.get("force_download", False),
            local_files_only=self.config.model_kwargs.get("local_files_only", False),
        )

        if self.config.task in TEXT_GENERATION_TASKS:
            # unset eos/pad token ids so that benchmarked generation runs for its
            # full token budget instead of stopping early at an end-of-sequence token
            self.generation_config.eos_token_id = None
            self.generation_config.pad_token_id = None
            self.generation_config.save_pretrained(save_directory=model_snapshot_folder)
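
    # Illustrative sketch (an assumption about subclass usage, not a shipped
    # implementation): a backend's load() would typically download first, then
    # load from the local snapshot:
    #
    #     def load(self) -> None:
    #         self.download_pretrained_model()
    #         self.pretrained_model = self.automodel_loader.from_pretrained(
    #             self.config.model, **self.config.model_kwargs
    #         )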

    def create_no_weights_model_fast(self) -> None:
        # download only config.json; its parent folder becomes the skeleton
        # of the no-weights model
        model_path = Path(
            HF_API.hf_hub_download(self.config.model, filename="config.json", cache_dir=self.tmpdir.name)
        ).parent

        # save a tiny placeholder module so the folder contains a valid safetensors file
        save_model(model=torch.nn.Linear(1, 1), filename=model_path / "model.safetensors", metadata={"format": "pt"})
        self.pretrained_processor.save_pretrained(save_directory=model_path)
        self.pretrained_config.save_pretrained(save_directory=model_path)

        if self.config.task in TEXT_GENERATION_TASKS:
            self.generation_config.eos_token_id = None
            self.generation_config.pad_token_id = None
            self.generation_config.save_pretrained(save_directory=model_path)

        self.no_weights_model_path = model_path
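
    # Resulting folder layout (derived from the calls above; processor file
    # names are an assumption and depend on the processor type):
    #
    #     <no_weights_model_path>/
    #         config.json               # pretrained config
    #         model.safetensors         # placeholder weights (a 1x1 nn.Linear)
    #         generation_config.json    # text-generation tasks only
    #         tokenizer.json, ...       # processor artifacts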

    def create_no_weights_model_slow(self) -> None:
        self.create_no_weights_model_fast()

        with fast_weights_init():
            # unlike Transformers, TXI won't accept any missing tensors so we need to materialize the model
            dummy = self.automodel_loader.from_pretrained(
                self.no_weights_model_path, device_map="auto", **self.config.model_kwargs
            )

        dummy.save_pretrained(self.no_weights_model_path)
        del dummy

        torch.cuda.empty_cache()
        gc.collect()
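
    # Trade-off between the two variants (as implied above): the fast path ships
    # a placeholder safetensors file and relies on the loader tolerating missing
    # tensors, while the slow path materializes a full randomly-initialized
    # checkpoint for loaders that validate tensor completeness.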

    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        This method is used to prepare and register the inputs before passing them to the model.
        It can be used to move the inputs to the correct device or to rename their keys.
        """
        return inputs
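
    # Illustrative override (an assumption: a torch-based backend whose config
    # exposes a `device` field):
    #
    #     def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
    #         return {
    #             key: value.to(self.config.device) if isinstance(value, torch.Tensor) else value
    #             for key, value in inputs.items()
    #         }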

    def load(self) -> None:
        raise NotImplementedError("Backend must implement load method")

    def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """
        This method is used to perform the forward pass of the model.
        """
        raise NotImplementedError("Backend must implement forward method")

    def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """
        This method is used to perform the prefill pass of the model.
        """
        raise NotImplementedError("Backend must implement prefill method")

    def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """
        This method is used to perform the generation pass of the model.
        """
        raise NotImplementedError("Backend must implement generate method")

    def call(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """
        This method is used to call a whole pipeline.
        """
        raise NotImplementedError("Backend must implement call method")

    def train(self, **kwargs) -> TrainerState:
        """
        This method is used to train the model.
        """
        raise NotImplementedError("Backend must implement train method")