optimum_benchmark/backends/ipex/config.py:

from dataclasses import dataclass
from typing import Optional

from ...import_utils import ipex_version
from ..config import BackendConfig

TORCH_DTYPES = ["bfloat16", "float16", "float32", "auto"]


@dataclass
class IPEXConfig(BackendConfig):
    name: str = "ipex"
    version: Optional[str] = ipex_version()
    _target_: str = "optimum_benchmark.backends.ipex.backend.IPEXBackend"

    no_weights: bool = False

    # IPEXModel kwargs
    export: Optional[bool] = None
    torch_dtype: Optional[str] = None

    def __post_init__(self):
        super().__post_init__()

        self.device = self.device.lower()
        if self.device not in ["cpu", "gpu"]:
            raise ValueError(f"IPEXBackend only supports CPU and GPU devices, got {self.device}")

        if self.model_kwargs.get("torch_dtype", None) is not None:
            raise ValueError(
                "`torch_dtype` is an explicit argument in the IPEX backend config. "
                "Please remove it from the `model_kwargs` and set it in the backend config directly."
            )

        if self.torch_dtype is not None and self.torch_dtype not in TORCH_DTYPES:
            raise ValueError(f"`torch_dtype` must be one of {TORCH_DTYPES}. Got {self.torch_dtype} instead.")
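
A minimal usage sketch of how this config's validation behaves. It assumes the inherited BackendConfig fields referenced above (`device`, `model_kwargs`) and that any other required parent fields (e.g. a model identifier, shown here as the illustrative value "gpt2") are satisfied; the values are examples, not defaults:

# Sketch only: constructing an IPEXConfig and exercising __post_init__ validation.
# `model` is assumed to be an inherited BackendConfig field; "gpt2" is an illustrative value.
config = IPEXConfig(model="gpt2", device="CPU", torch_dtype="bfloat16")
# `device` is lowercased in __post_init__, so config.device == "cpu".

# Setting torch_dtype through model_kwargs instead of the config field raises a ValueError:
# IPEXConfig(model="gpt2", device="cpu", model_kwargs={"torch_dtype": "bfloat16"})

# An unsupported dtype string also raises, since it is not in TORCH_DTYPES:
# IPEXConfig(model="gpt2", device="cpu", torch_dtype="int8")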