in optimum/intel/openvino/quantization.py [0:0]
def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> OVCalibrationDataset:
    """
    Builds a calibration dataset from a quantization config object. The `config.dataset` property is used to
    infer the dataset name.

    Args:
        config (`OVQuantizationConfigBase`):
            The quantization configuration object.
    Returns:
        A calibration dataset as an instance of `OVCalibrationDataset` containing an `nncf.Dataset` for each model component.
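
    Example (an illustrative sketch; assumes `dataset_builder` is an instance of this class wrapping an
    `OVModelForCausalLM`, and that `"wikitext2"` is one of the supported dataset labels):

        from optimum.intel import OVQuantizationConfig

        config = OVQuantizationConfig(dataset="wikitext2", num_samples=32)
        calibration_dataset = dataset_builder.build_from_quantization_config(config)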
"""
from optimum.intel import (
OVModelForCausalLM,
OVModelForFeatureExtraction,
OVModelForMaskedLM,
OVModelForVisualCausalLM,
OVModelForZeroShotImageClassification,
OVSentenceTransformer,
)
from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper
if is_diffusers_available():
from optimum.intel.openvino.modeling_diffusion import OVDiffusionPipeline
if config.dataset is None:
raise ValueError("Please provide a dataset for calibration.")
    if isinstance(self.model, OVModelForCausalLM):
        return self._prepare_causal_lm_calibration_data(config)
    elif isinstance(
        self.model, (OVModelForVisualCausalLM, _OVModelForWhisper, OVModelForZeroShotImageClassification)
    ):
        if config.processor is None:
            raise ValueError(
                "`processor` must be specified in order to run data-aware quantization. Please provide it as a "
                "model id, or a path to a directory containing all the required configuration files."
            )

        if isinstance(self.model, OVModelForVisualCausalLM):
            dataset_metadata = PREDEFINED_VISUAL_LM_DATASETS[config.dataset]
            return self.build_from_dataset_name(
                config,
                dataset_metadata["id"],
                num_samples=config.num_samples,
                dataset_split=dataset_metadata["split"],
                trust_remote_code=config.trust_remote_code,
            )
        elif isinstance(self.model, _OVModelForWhisper):
            dataset_metadata = PREDEFINED_SPEECH_TO_TEXT_DATASETS[config.dataset]
            return self.build_from_dataset_name(
                config,
                dataset_metadata["id"],
                num_samples=config.num_samples,  # This is an upper bound on how many audios are needed
                dataset_config_name=dataset_metadata["name"],
                dataset_split=dataset_metadata["split"],
                trust_remote_code=config.trust_remote_code,
                streaming=dataset_metadata["streaming"],
            )
        elif isinstance(self.model, OVModelForZeroShotImageClassification):
            dataset_metadata = PREDEFINED_TEXT_IMAGE_ENCODER_DATASETS[config.dataset]
            return self.build_from_dataset_name(
                config,
                dataset_metadata["id"],
                num_samples=None,
                dataset_split=dataset_metadata["split"],
                trust_remote_code=config.trust_remote_code,
                streaming=dataset_metadata["streaming"],
            )
        else:
            # Unreachable unless the isinstance check above gets out of sync with these branches.
            raise RuntimeError(f"Unexpected model type: {type(self.model)}.")
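    # For diffusion pipelines, `config.dataset` may be either a predefined dataset label (resolved through
    # PREDEFINED_SD_DATASETS) or an explicit list of prompt strings.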
    elif is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline):
        if isinstance(config.dataset, str):
            dataset_name = config.dataset
            dataset_metadata = PREDEFINED_SD_DATASETS[dataset_name]
            dataset = self.load_dataset(
                dataset_name,
                num_samples=config.num_samples,  # This is an upper bound on how many prompts are needed
                dataset_split=dataset_metadata["split"],
                streaming=dataset_metadata["streaming"],
            )
        elif isinstance(config.dataset, list) and all(isinstance(it, str) for it in config.dataset):
            dataset = config.dataset
        else:
            raise RuntimeError(
                "Please provide dataset as one of the accepted dataset labels or as a list of string prompts."
            )

        return self.build_from_dataset(config, dataset)
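    # Text models for feature extraction, sentence embeddings and masked LM calibrate on plain text: either a
    # predefined language dataset label or a list of strings wrapped into a dataset with a "text" column.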
    elif isinstance(self.model, (OVModelForFeatureExtraction, OVSentenceTransformer, OVModelForMaskedLM)):
        if isinstance(config.dataset, str):
            dataset_metadata = PREDEFINED_LANGUAGE_DATASETS[config.dataset]
            dataset = self.load_dataset(
                dataset_metadata["id"],
                num_samples=None,
                dataset_config_name=dataset_metadata["name"],
                dataset_split=dataset_metadata["split"],
                trust_remote_code=config.trust_remote_code,
                streaming=dataset_metadata["streaming"],
            )
        elif isinstance(config.dataset, list) and all(isinstance(it, str) for it in config.dataset):
            dataset = datasets.Dataset.from_list([{"text": it} for it in config.dataset])
        else:
            raise ValueError(
                "Please provide dataset as one of the accepted dataset labels or as a list of strings."
            )

        return self.build_from_dataset(config, dataset)
    else:
        raise RuntimeError("Unsupported model type for calibration dataset collection.")