optimum_benchmark/generators/base.py (40 lines of code) (raw):

import logging
import random
import string
from abc import ABC
from typing import Dict, List, Tuple

import torch

LOGGER = logging.getLogger("generators")


class BaseGenerator(ABC):
    """Base class for input generators.

    Stores the shapes mapping and the ``with_labels`` flag, and offers static
    helpers that build constant / random tensors and random strings.
    Subclasses are expected to override ``__call__``; the base implementation
    raises ``NotImplementedError``.
    """

    def __init__(self, shapes: Dict[str, int], with_labels: bool):
        """Store ``shapes`` and ``with_labels`` on the instance unchanged."""
        self.shapes = shapes
        self.with_labels = with_labels

    def assert_not_missing_shapes(self, required_shapes: List[str]) -> None:
        """Check that every name in ``required_shapes`` has a non-``None`` shape.

        Args:
            required_shapes: Keys that must be present in ``self.shapes`` with
                a value other than ``None``.

        Raises:
            AssertionError: On the first key that is absent or maps to ``None``.
        """
        for shape in required_shapes:
            if self.shapes.get(shape) is None:
                # Raised explicitly (rather than via an `assert` statement) so the
                # check is not stripped when Python runs with -O; the exception
                # type and message are the same as the assert would produce.
                raise AssertionError(
                    f"{shape} either couldn't be inferred automatically from model artifacts or should be provided by the user. "
                    f"Please provide it under `scenario.input_shapes.{shape}` or open an issue/PR in optimum-benchmark repository. "
                )

    @staticmethod
    def generate_constant_integers(value: int, shape: Tuple[int, ...]) -> torch.Tensor:
        """Return an int64 tensor of the given ``shape`` filled with ``value``."""
        return torch.full(shape, value, dtype=torch.int64)

    @staticmethod
    def generate_constant_floats(value: float, shape: Tuple[int, ...]) -> torch.Tensor:
        """Return a float32 tensor of the given ``shape`` filled with ``value``."""
        return torch.full(shape, value, dtype=torch.float32)

    @staticmethod
    def generate_random_integers(min_value: int, max_value: int, shape: Tuple[int, ...]) -> torch.Tensor:
        """Return ``torch.randint(min_value, max_value, shape)``.

        Per the ``torch.randint`` documentation the upper bound is exclusive.
        """
        return torch.randint(min_value, max_value, shape)

    @staticmethod
    def generate_random_floats(min_value: float, max_value: float, shape: Tuple[int, ...]) -> torch.Tensor:
        """Return ``torch.rand(shape)`` rescaled from the unit interval to ``min_value``..``max_value``."""
        return torch.rand(shape) * (max_value - min_value) + min_value

    @staticmethod
    def generate_ranges(start: int, stop: int, shape: Tuple[int, ...]) -> torch.Tensor:
        """Return ``torch.arange(start, stop)`` stacked ``shape[0]`` times as rows.

        Only ``shape[0]`` is read; the row length is ``stop - start`` regardless
        of any further entries in ``shape``.
        """
        return torch.arange(start, stop).repeat(shape[0], 1)

    @staticmethod
    def generate_random_strings(num_seq: int) -> List[str]:
        """Return ``num_seq`` random strings of ASCII letters and digits.

        Each string has a length drawn with ``random.randint(10, 100)`` (both
        ends inclusive) and characters drawn with ``random.choice``.
        """
        # Built once instead of once per generated character.
        alphabet = string.ascii_letters + string.digits
        return [
            "".join(random.choice(alphabet) for _ in range(random.randint(10, 100)))
            for _ in range(num_seq)
        ]

    def __call__(self):
        """Placeholder entry point; always raises ``NotImplementedError`` here."""
        raise NotImplementedError("Generator must implement __call__ method")