# optimum_benchmark/generators/input_generator.py

from typing import Any, Dict, Optional

from .base import BaseGenerator
from .model_generator import MODEL_TYPE_TO_GENERATORS
from .task_generator import TASKS_TO_GENERATORS


class InputGenerator:
    """Builds dummy model inputs by delegating to a registered generator.

    Resolution order: a model-type-specific generator is preferred when
    ``model_type`` has an entry in ``MODEL_TYPE_TO_GENERATORS``; otherwise a
    task-level generator is looked up in ``TASKS_TO_GENERATORS``.
    """

    # The concrete generator instance selected during construction.
    generator: BaseGenerator

    def __init__(
        self,
        task: str,
        input_shapes: Dict[str, int],
        model_shapes: Dict[str, int],
        model_type: Optional[str] = None,
    ) -> None:
        """Select and instantiate the generator for ``model_type``/``task``.

        Args:
            task: Name of the task used as a fallback lookup key.
            input_shapes: User-provided shape overrides (highest priority).
            model_shapes: Shapes derived from the model configuration.
            model_type: Optional model type; takes priority over ``task``
                when a dedicated generator is registered for it.

        Raises:
            NotImplementedError: If neither ``model_type`` nor ``task`` has a
                registered generator.
        """
        # input_shapes take precedence over model_shapes
        merged_shapes = {**model_shapes, **input_shapes}

        # Resolve the generator class: model-type registry first, task second.
        if model_type in MODEL_TYPE_TO_GENERATORS:
            generator_cls = MODEL_TYPE_TO_GENERATORS[model_type]
        elif task in TASKS_TO_GENERATORS:
            generator_cls = TASKS_TO_GENERATORS[task]
        else:
            raise NotImplementedError(
                f"Task {task} is not supported for input generation. "
                f"Available tasks: {list(TASKS_TO_GENERATORS.keys())}. "
                f"Available model types: {list(MODEL_TYPE_TO_GENERATORS.keys())}. "
                "If you want to add support for this task or model type, "
                "please submit a PR or a feature request to optimum-benchmark."
            )

        self.generator = generator_cls(shapes=merged_shapes, with_labels=False)

    def __call__(self) -> Dict[str, Any]:
        """Generate and return one batch of inputs from the selected generator."""
        return self.generator()