in optimum_benchmark/backends/pytorch/config.py [0:0]
def __post_init__(self):
    super().__post_init__()

    # torch_dtype is a first-class field of this config, not a model_kwargs entry
    if self.model_kwargs.get("torch_dtype", None) is not None:
        raise ValueError(
            "`torch_dtype` is an explicit argument in the PyTorch backend config. "
            "Please remove it from the `model_kwargs` and set it in the backend config directly."
        )

    if self.torch_dtype is not None and self.torch_dtype not in TORCH_DTYPES:
        raise ValueError(f"`torch_dtype` must be one of {TORCH_DTYPES}. Got {self.torch_dtype} instead.")

    if self.autocast_dtype is not None and self.autocast_dtype not in AMP_DTYPES:
        raise ValueError(f"`autocast_dtype` must be one of {AMP_DTYPES}. Got {self.autocast_dtype} instead.")

    # deprecated: fold the legacy quantization_scheme into quantization_config["quant_method"]
    if self.quantization_scheme is not None:
        LOGGER.warning(
            "`backend.quantization_scheme` is deprecated and will be removed in a future version. "
            "Please use `quantization_config.quant_method` instead."
        )
        if self.quantization_config is None:
            self.quantization_config = {"quant_method": self.quantization_scheme}
        else:
            self.quantization_config["quant_method"] = self.quantization_scheme

    if self.quantization_config is not None:
        # merge scheme defaults with the user config; user values take precedence
        self.quantization_config = dict(
            QUANTIZATION_CONFIGS.get(self.quantization_scheme, {}),  # default config
            **self.quantization_config,  # user config (overwrites default)
        )
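
For context, a minimal, self-contained sketch of how the deprecated `quantization_scheme` path above resolves into `quantization_config`. The `resolve` helper and the `QUANTIZATION_CONFIGS` entry are assumptions for illustration only, not part of the library's API:

# Illustration only: mirrors the deprecation/merge logic above outside the dataclass.
# QUANTIZATION_CONFIGS here is a hypothetical stand-in for the constant imported in config.py.
QUANTIZATION_CONFIGS = {"torchao": {"quant_type": "int4_weight_only"}}  # assumed example entry

def resolve(quantization_scheme=None, quantization_config=None):
    if quantization_scheme is not None:
        # deprecated path: fold the scheme into the config dict as quant_method
        if quantization_config is None:
            quantization_config = {"quant_method": quantization_scheme}
        else:
            quantization_config["quant_method"] = quantization_scheme
    if quantization_config is not None:
        # scheme defaults first, user-provided keys overwrite them
        quantization_config = dict(QUANTIZATION_CONFIGS.get(quantization_scheme, {}), **quantization_config)
    return quantization_config

print(resolve(quantization_scheme="torchao"))
# -> {'quant_type': 'int4_weight_only', 'quant_method': 'torchao'}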