# optimum_benchmark/backends/config.py
import os
from abc import ABC
from dataclasses import dataclass, field
from logging import getLogger
from typing import Any, Dict, Optional, TypeVar

from psutil import cpu_count

from ..system_utils import get_gpu_device_ids, is_nvidia_system, is_rocm_system
from ..task_utils import (
    infer_library_from_model_name_or_path,
    infer_model_type_from_model_name_or_path,
    infer_task_from_model_name_or_path,
)

LOGGER = getLogger("backend")


@dataclass
class BackendConfig(ABC):
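    """Base dataclass shared by all backend configurations.

    `__post_init__` validates the user-provided fields and infers the missing
    ones (library, task, model_type, device, ...) from the model repository
    and the host system.
    """
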
    name: str
    version: str
    _target_: str

    model: Optional[str] = None
    processor: Optional[str] = None

    task: Optional[str] = None
    library: Optional[str] = None
    model_type: Optional[str] = None

    device: Optional[str] = None
    # we use a string here instead of a list because it's easier to pass
    # from a yaml file or the cli, and it's consistent with the format of
    # GPU environment variables (e.g. CUDA_VISIBLE_DEVICES)
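    # e.g. device_ids="0,2" targets the first and third GPUs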
    device_ids: Optional[str] = None

    seed: int = 42
    inter_op_num_threads: Optional[int] = None
    intra_op_num_threads: Optional[int] = None

    # kwargs forwarded to the model's init method/constructor
    model_kwargs: Dict[str, Any] = field(default_factory=dict)
    # kwargs forwarded to the processor's init method/constructor
    processor_kwargs: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        if self.model is None:
            raise ValueError("`model` must be specified.")

        if self.model_kwargs.get("token", None) is not None:
            LOGGER.info(
                "You have passed a `token` in `model_kwargs`. Keeping it in the config is unsafe "
                "since configs are stored in plain text. It will be moved to the `HF_TOKEN` environment "
                "variable instead, to avoid saving it or pushing it to the hub by mistake."
            )
            os.environ["HF_TOKEN"] = self.model_kwargs.pop("token")

        if self.processor is None:
            self.processor = self.model
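
        # when no processor kwargs are given, model_kwargs is reused (same dict object, not a copy)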
        if not self.processor_kwargs:
            self.processor_kwargs = self.model_kwargs
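
        # infer library/task/model_type from the model repo when not explicitly provided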
        if self.library is None:
            self.library = infer_library_from_model_name_or_path(
                model_name_or_path=self.model,
                revision=self.model_kwargs.get("revision", None),
                cache_dir=self.model_kwargs.get("cache_dir", None),
            )

        if self.library not in ["transformers", "diffusers", "timm", "llama_cpp"]:
            raise ValueError(
                f"`library` must be one of `transformers`, `diffusers`, `timm` or `llama_cpp`, but got {self.library}"
            )

        if self.task is None:
            self.task = infer_task_from_model_name_or_path(
                model_name_or_path=self.model,
                revision=self.model_kwargs.get("revision", None),
                cache_dir=self.model_kwargs.get("cache_dir", None),
                library_name=self.library,
            )

        if self.model_type is None:
            self.model_type = infer_model_type_from_model_name_or_path(
                model_name_or_path=self.model,
                revision=self.model_kwargs.get("revision", None),
                cache_dir=self.model_kwargs.get("cache_dir", None),
                library_name=self.library,
            )
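
        # default to a GPU device when NVIDIA or ROCm drivers are detected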
        if self.device is None:
            self.device = "cuda" if is_nvidia_system() or is_rocm_system() else "cpu"
if ":" in self.device:
LOGGER.warning("`device` was specified using PyTorch format (e.g. `cuda:0`) which is not recommended.")
            self.device_ids = self.device.split(":")[1]
            self.device = self.device.split(":")[0]
            LOGGER.warning(f"`device` and `device_ids` are now set to `{self.device}` and `{self.device_ids}`.")

        if self.device not in ["cuda", "cpu", "mps", "xla", "gpu"]:
            raise ValueError(f"`device` must be one of `cuda`, `cpu`, `mps`, `xla` or `gpu`, but got {self.device}")
if self.device == "cuda":
if self.device_ids is None:
LOGGER.warning("`device_ids` was not specified, using all available GPUs.")
self.device_ids = get_gpu_device_ids()
LOGGER.warning(f"`device_ids` is now set to `{self.device_ids}` based on system configuration.")

            if is_nvidia_system():
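                # enforce PCI bus ordering so device indices match nvidia-smi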
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = self.device_ids
LOGGER.info(f"CUDA_VISIBLE_DEVICES was set to {os.environ['CUDA_VISIBLE_DEVICES']}.")
elif is_rocm_system():
os.environ["ROCR_VISIBLE_DEVICES"] = self.device_ids
LOGGER.info(f"ROCR_VISIBLE_DEVICES was set to {os.environ['ROCR_VISIBLE_DEVICES']}.")
else:
raise RuntimeError("CUDA device is only supported on systems with NVIDIA or ROCm drivers.")
        if self.inter_op_num_threads == -1:
            self.inter_op_num_threads = cpu_count()
        if self.intra_op_num_threads == -1:
            self.intra_op_num_threads = cpu_count()


BackendConfigT = TypeVar("BackendConfigT", bound=BackendConfig)
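

# --- Usage sketch (illustrative only, not part of the module) ---
# Concrete backends are expected to subclass BackendConfig and fill in `name`,
# `version` and `_target_`. Everything below (the subclass, the `_target_` path,
# the model name) is a hypothetical example; note that `__post_init__` may reach
# the hub to infer `library`, `task` and `model_type`.
if __name__ == "__main__":

    @dataclass
    class ExampleBackendConfig(BackendConfig):
        name: str = "example"
        version: str = "0.0.0"
        _target_: str = "optimum_benchmark.backends.example.ExampleBackend"  # hypothetical import path

    config = ExampleBackendConfig(model="bert-base-uncased", device="cpu")
    print(config.library, config.task, config.model_type)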