optimum_benchmark/generators/dataset_generator.py:

from typing import Dict, Optional

from datasets import Dataset

from .base import BaseGenerator
from .model_generator import MODEL_TYPE_TO_GENERATORS
from .task_generator import TASKS_TO_GENERATORS


class DatasetGenerator:
    generator: BaseGenerator

    def __init__(
        self,
        task: str,
        dataset_shapes: Dict[str, int],
        model_shapes: Dict[str, int],
        model_type: Optional[str] = None,
    ) -> None:
        # dataset_shapes take precedence over model_shapes
        all_shapes = {**model_shapes, **dataset_shapes}
        all_shapes["batch_size"] = all_shapes.pop("dataset_size", None)

        if model_type in MODEL_TYPE_TO_GENERATORS:
            self.generator = MODEL_TYPE_TO_GENERATORS[model_type](shapes=all_shapes, with_labels=True)
        elif task in TASKS_TO_GENERATORS:
            self.generator = TASKS_TO_GENERATORS[task](shapes=all_shapes, with_labels=True)
        else:
            raise NotImplementedError(
                f"Task {task} is not supported for dataset generation. "
                f"Available tasks: {list(TASKS_TO_GENERATORS.keys())}. "
                f"Available model types: {list(MODEL_TYPE_TO_GENERATORS.keys())}. "
                "If you want to add support for this task or model type, "
                "please submit a PR or a feature request to optimum-benchmark."
            )

    def __call__(self) -> Dataset:
        task_dataset = self.generator()
        task_dataset = Dataset.from_dict(task_dataset)
        task_dataset.set_format(type="torch", columns=list(task_dataset.features.keys()))
        return task_dataset
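

# Hedged usage sketch (not part of the original module): it shows how DatasetGenerator
# above could be driven end to end. The task name "text-classification" and the shape
# keys "dataset_size", "sequence_length", and "vocab_size" are assumptions about what
# TASKS_TO_GENERATORS and the selected task generator expect; adjust them to the
# generators actually registered in your version of optimum-benchmark. Because of the
# relative imports, run it as a module, e.g.
# `python -m optimum_benchmark.generators.dataset_generator`.
if __name__ == "__main__":
    dataset_generator = DatasetGenerator(
        task="text-classification",                                 # assumed key in TASKS_TO_GENERATORS
        dataset_shapes={"dataset_size": 8, "sequence_length": 16},  # "dataset_size" is renamed to "batch_size"
        model_shapes={"vocab_size": 32000},                         # shapes typically inferred from the model config
    )
    dataset = dataset_generator()  # a datasets.Dataset with torch-formatted columns
    print(dataset)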