# optimum_benchmark/generators/task_generator.py
import logging
from .base import BaseGenerator
# Module-level logger registered under the name "generators".
LOGGER = logging.getLogger("generators")
# Fallback bounds handed to the random-integer helpers as `max_value` when the
# corresponding entry ("num_labels", "vocab_size", "type_vocab_size") is not
# provided in a generator's `shapes` mapping.
# NOTE(review): whether `max_value` is inclusive or exclusive is decided by
# BaseGenerator.generate_random_integers, which is not visible here — confirm.
DEFAULT_NUM_LABELS = 2
DEFAULT_VOCAB_SIZE = 2
DEFAULT_TYPE_VOCAB_SIZE = 2
class TextGenerator(BaseGenerator):
    """Shared building blocks for dummy text-model inputs.

    Each builder validates the entries it needs in ``self.shapes`` and then
    delegates value generation to a helper inherited from ``BaseGenerator``.
    """

    def input_ids(self):
        """Build ``input_ids`` with shape ``(batch_size, sequence_length)``.

        Values come from ``generate_random_integers`` with a lower bound of 0
        and ``max_value`` set to ``shapes["vocab_size"]`` (falling back to
        ``DEFAULT_VOCAB_SIZE`` when the key is absent).
        """
        self.assert_not_missing_shapes(["batch_size", "sequence_length"])
        upper_bound = self.shapes.get("vocab_size", DEFAULT_VOCAB_SIZE)
        batch_size = self.shapes["batch_size"]
        sequence_length = self.shapes["sequence_length"]
        return self.generate_random_integers(
            min_value=0,
            max_value=upper_bound,
            shape=(batch_size, sequence_length),
        )

    def attention_mask(self):
        """Build an all-ones mask with shape ``(batch_size, sequence_length)``."""
        self.assert_not_missing_shapes(["batch_size", "sequence_length"])
        batch_size = self.shapes["batch_size"]
        sequence_length = self.shapes["sequence_length"]
        # A constant 1 everywhere means no position is masked out (no sparsity).
        return self.generate_constant_integers(value=1, shape=(batch_size, sequence_length))

    def token_type_ids(self):
        """Build ``token_type_ids`` with shape ``(batch_size, sequence_length)``.

        ``max_value`` is ``shapes["type_vocab_size"]`` or, when absent,
        ``DEFAULT_TYPE_VOCAB_SIZE``.
        """
        self.assert_not_missing_shapes(["batch_size", "sequence_length"])
        upper_bound = self.shapes.get("type_vocab_size", DEFAULT_TYPE_VOCAB_SIZE)
        batch_size = self.shapes["batch_size"]
        sequence_length = self.shapes["sequence_length"]
        return self.generate_random_integers(
            min_value=0,
            max_value=upper_bound,
            shape=(batch_size, sequence_length),
        )

    def position_ids(self):
        """Build ``position_ids`` via ``generate_ranges`` from 0 to ``sequence_length``."""
        self.assert_not_missing_shapes(["batch_size", "sequence_length"])
        batch_size = self.shapes["batch_size"]
        sequence_length = self.shapes["sequence_length"]
        return self.generate_ranges(
            start=0,
            stop=sequence_length,
            shape=(batch_size, sequence_length),
        )

    def requires_token_type_ids(self):
        """Return whether ``shapes["type_vocab_size"]`` is set and greater than 1."""
        type_vocab_size = self.shapes.get("type_vocab_size", None)
        if type_vocab_size is None:
            return False
        return type_vocab_size > 1

    def requires_position_ids(self):
        """Return whether ``shapes["max_position_embeddings"]`` is set and greater than 1."""
        max_positions = self.shapes.get("max_position_embeddings", None)
        if max_positions is None:
            return False
        return max_positions > 1
class ImageGenerator(BaseGenerator):
    """Shared building block for dummy image-model inputs."""

    def pixel_values(self):
        """Build ``pixel_values`` shaped ``(batch_size, num_channels, height, width)``.

        Values come from ``generate_random_floats`` with bounds 0 and 1.
        """
        self.assert_not_missing_shapes(["batch_size", "num_channels", "height", "width"])
        image_shape = tuple(
            self.shapes[dim] for dim in ("batch_size", "num_channels", "height", "width")
        )
        return self.generate_random_floats(min_value=0, max_value=1, shape=image_shape)
class AudioGenerator(BaseGenerator):
    """Shared building blocks for dummy audio-model inputs."""

    def input_values(self):
        """Build ``input_values`` shaped ``(batch_size, sequence_length)``.

        Values come from ``generate_random_floats`` with bounds -1 and 1.
        """
        self.assert_not_missing_shapes(["batch_size", "sequence_length"])
        batch_size = self.shapes["batch_size"]
        sequence_length = self.shapes["sequence_length"]
        return self.generate_random_floats(
            min_value=-1,
            max_value=1,
            shape=(batch_size, sequence_length),
        )

    def input_features(self):
        """Build ``input_features`` shaped ``(batch_size, feature_size, nb_max_frames)``.

        Values come from ``generate_random_floats`` with bounds -1 and 1.
        """
        self.assert_not_missing_shapes(["batch_size", "feature_size", "nb_max_frames"])
        batch_size = self.shapes["batch_size"]
        feature_size = self.shapes["feature_size"]
        frame_count = self.shapes["nb_max_frames"]
        return self.generate_random_floats(
            min_value=-1,
            max_value=1,
            shape=(batch_size, feature_size, frame_count),
        )
class TextClassificationGenerator(TextGenerator):
    """Dummy inputs (and optional labels) for the text-classification task."""

    def labels(self):
        """Build one integer label per batch row.

        ``max_value`` is ``shapes["num_labels"]`` or, when absent,
        ``DEFAULT_NUM_LABELS``.
        """
        self.assert_not_missing_shapes(["batch_size"])
        label_bound = self.shapes.get("num_labels", DEFAULT_NUM_LABELS)
        batch_size = self.shapes["batch_size"]
        return self.generate_random_integers(min_value=0, max_value=label_bound, shape=(batch_size,))

    def __call__(self):
        """Return the dict of generated inputs keyed by input name.

        ``token_type_ids`` and ``position_ids`` are added only when the matching
        ``requires_*`` check passes; ``labels`` only when ``self.with_labels``
        is truthy.
        """
        batch = {
            "input_ids": self.input_ids(),
            "attention_mask": self.attention_mask(),
        }

        if self.requires_token_type_ids():
            batch["token_type_ids"] = self.token_type_ids()

        if self.requires_position_ids():
            batch["position_ids"] = self.position_ids()

        if self.with_labels:
            batch["labels"] = self.labels()

        return batch
class TokenClassificationGenerator(TextGenerator):
    """Dummy inputs (and optional per-token labels) for token classification."""

    def labels(self):
        """Build integer labels shaped ``(batch_size, sequence_length)``.

        ``max_value`` is ``shapes["num_labels"]`` or, when absent,
        ``DEFAULT_NUM_LABELS``.
        """
        self.assert_not_missing_shapes(["batch_size", "sequence_length"])
        label_bound = self.shapes.get("num_labels", DEFAULT_NUM_LABELS)
        batch_size = self.shapes["batch_size"]
        sequence_length = self.shapes["sequence_length"]
        return self.generate_random_integers(
            min_value=0,
            max_value=label_bound,
            shape=(batch_size, sequence_length),
        )

    def __call__(self):
        """Return the dict of generated inputs keyed by input name."""
        batch = {}
        batch["input_ids"] = self.input_ids()
        batch["attention_mask"] = self.attention_mask()

        # Optional inputs, checked and generated in this fixed order.
        optional_inputs = (
            ("token_type_ids", self.requires_token_type_ids, self.token_type_ids),
            ("position_ids", self.requires_position_ids, self.position_ids),
        )
        for key, is_required, build in optional_inputs:
            if is_required():
                batch[key] = build()

        if self.with_labels:
            batch["labels"] = self.labels()

        return batch
class TextGenerationGenerator(TextGenerator):
    """Dummy inputs (and optional labels) for the text-generation task."""

    def __call__(self):
        """Return ``input_ids`` and ``attention_mask``, plus ``labels`` if enabled.

        When ``self.with_labels`` is truthy, ``labels`` is produced by a second,
        separate call to ``input_ids()`` rather than by reusing the first result.
        """
        batch = {
            "input_ids": self.input_ids(),
            "attention_mask": self.attention_mask(),
        }

        if self.with_labels:
            batch["labels"] = self.input_ids()

        return batch
class Text2TextGenerationGenerator(TextGenerator):
    """Dummy inputs (and optional labels) for the text2text-generation task."""

    def __call__(self):
        """Return ``input_ids`` and ``attention_mask``, plus ``labels`` if enabled.

        ``labels`` comes from a fresh ``input_ids()`` call, made only when
        ``self.with_labels`` is truthy.
        """
        encoder_ids = self.input_ids()
        mask = self.attention_mask()
        batch = {"input_ids": encoder_ids, "attention_mask": mask}

        if self.with_labels:
            batch["labels"] = self.input_ids()

        return batch
class QuestionAnsweringGenerator(TextGenerator):
    """Dummy inputs (and optional answer-span targets) for question answering."""

    def _span_positions(self):
        """Build one integer per batch row with ``max_value`` set to ``sequence_length``."""
        self.assert_not_missing_shapes(["batch_size", "sequence_length"])
        sequence_length = self.shapes["sequence_length"]
        batch_size = self.shapes["batch_size"]
        return self.generate_random_integers(
            min_value=0,
            max_value=sequence_length,
            shape=(batch_size,),
        )

    def start_positions(self):
        """Build the ``start_positions`` target, one integer per batch row."""
        return self._span_positions()

    def end_positions(self):
        """Build the ``end_positions`` target, one integer per batch row.

        Generated independently of ``start_positions``; no ordering between
        the two is enforced.
        """
        return self._span_positions()

    def __call__(self):
        """Return the dict of generated inputs keyed by input name.

        NOTE(review): ``token_type_ids`` is always included here, whereas the
        sibling generators gate it on ``requires_token_type_ids()`` — confirm
        this is intentional.
        """
        batch = {
            "input_ids": self.input_ids(),
            "attention_mask": self.attention_mask(),
            "token_type_ids": self.token_type_ids(),
        }

        if self.with_labels:
            batch["start_positions"] = self.start_positions()
            batch["end_positions"] = self.end_positions()

        return batch
class MaskedLanguageModelingGenerator(TextGenerator):
    """Dummy inputs (and optional labels) for the fill-mask task."""

    def __call__(self):
        """Return the dict of generated inputs keyed by input name.

        Optional ``token_type_ids`` / ``position_ids`` follow the ``requires_*``
        checks; ``labels`` is a fresh ``input_ids()`` draw made only when
        ``self.with_labels`` is truthy.
        """
        batch = {
            "input_ids": self.input_ids(),
            "attention_mask": self.attention_mask(),
        }

        needs_token_types = self.requires_token_type_ids()
        if needs_token_types:
            batch["token_type_ids"] = self.token_type_ids()

        needs_positions = self.requires_position_ids()
        if needs_positions:
            batch["position_ids"] = self.position_ids()

        if self.with_labels:
            batch["labels"] = self.input_ids()

        return batch
class MultipleChoiceGenerator(TextGenerator):
    """Dummy inputs for multiple choice, with an extra ``num_choices`` axis.

    The text builders are overridden so that every generated input has shape
    ``(batch_size, num_choices, sequence_length)``.
    """

    def input_ids(self):
        """Build ``input_ids`` shaped ``(batch_size, num_choices, sequence_length)``."""
        self.assert_not_missing_shapes(["batch_size", "num_choices", "sequence_length"])
        upper_bound = self.shapes.get("vocab_size", DEFAULT_VOCAB_SIZE)
        choice_shape = (
            self.shapes["batch_size"],
            self.shapes["num_choices"],
            self.shapes["sequence_length"],
        )
        return self.generate_random_integers(min_value=0, max_value=upper_bound, shape=choice_shape)

    def attention_mask(self):
        """Build an all-ones mask shaped ``(batch_size, num_choices, sequence_length)``."""
        self.assert_not_missing_shapes(["batch_size", "num_choices", "sequence_length"])
        choice_shape = (
            self.shapes["batch_size"],
            self.shapes["num_choices"],
            self.shapes["sequence_length"],
        )
        # A constant 1 everywhere means no position is masked out (no sparsity).
        return self.generate_constant_integers(value=1, shape=choice_shape)

    def token_type_ids(self):
        """Build ``token_type_ids`` shaped ``(batch_size, num_choices, sequence_length)``."""
        self.assert_not_missing_shapes(["batch_size", "num_choices", "sequence_length"])
        upper_bound = self.shapes.get("type_vocab_size", DEFAULT_TYPE_VOCAB_SIZE)
        choice_shape = (
            self.shapes["batch_size"],
            self.shapes["num_choices"],
            self.shapes["sequence_length"],
        )
        return self.generate_random_integers(min_value=0, max_value=upper_bound, shape=choice_shape)

    def labels(self):
        """Build one integer per batch row with ``max_value`` set to ``num_choices``."""
        self.assert_not_missing_shapes(["batch_size", "num_choices"])
        choice_count = self.shapes["num_choices"]
        batch_size = self.shapes["batch_size"]
        return self.generate_random_integers(min_value=0, max_value=choice_count, shape=(batch_size,))

    def __call__(self):
        """Return the dict of generated inputs keyed by input name.

        NOTE(review): the target is stored under the singular key ``"label"``,
        unlike the ``"labels"`` key used by the other generators in this file —
        confirm the consumer expects this spelling.
        """
        batch = {
            "input_ids": self.input_ids(),
            "attention_mask": self.attention_mask(),
        }

        if self.requires_token_type_ids():
            batch["token_type_ids"] = self.token_type_ids()

        if self.with_labels:
            batch["label"] = self.labels()

        return batch
class ImageClassificationGenerator(ImageGenerator):
    """Dummy inputs (and optional labels) for the image-classification task."""

    def labels(self):
        """Build one integer label per batch row.

        ``max_value`` is ``shapes["num_labels"]`` or, when absent,
        ``DEFAULT_NUM_LABELS``.
        """
        self.assert_not_missing_shapes(["batch_size"])
        label_bound = self.shapes.get("num_labels", DEFAULT_NUM_LABELS)
        batch_size = self.shapes["batch_size"]
        return self.generate_random_integers(min_value=0, max_value=label_bound, shape=(batch_size,))

    def __call__(self):
        """Return ``pixel_values``, plus ``labels`` when ``self.with_labels`` is truthy."""
        batch = {"pixel_values": self.pixel_values()}

        if self.with_labels:
            batch["labels"] = self.labels()

        return batch
class ObjectDetectionGenerator(ImageGenerator):
    """Dummy inputs (and optional per-image targets) for object detection."""

    def labels(self):
        """Build a list with one target dict per batch element.

        Each dict holds ``"class_labels"`` (``num_queries`` integers whose
        ``max_value`` is ``shapes["num_labels"]`` or ``DEFAULT_NUM_LABELS``)
        and ``"boxes"`` (floats shaped ``(num_queries, 4)``, bounds -1 and 1).
        """
        self.assert_not_missing_shapes(["batch_size", "num_queries"])

        targets = []
        for _ in range(self.shapes["batch_size"]):
            query_count = self.shapes["num_queries"]
            # Class labels are generated before boxes for every element.
            class_labels = self.generate_random_integers(
                min_value=0,
                max_value=self.shapes.get("num_labels", DEFAULT_NUM_LABELS),
                shape=(query_count,),
            )
            boxes = self.generate_random_floats(min_value=-1, max_value=1, shape=(query_count, 4))
            targets.append({"class_labels": class_labels, "boxes": boxes})

        return targets

    def __call__(self):
        """Return ``pixel_values``, plus ``labels`` when ``self.with_labels`` is truthy."""
        batch = {"pixel_values": self.pixel_values()}

        if self.with_labels:
            batch["labels"] = self.labels()

        return batch
class SemanticSegmentationGenerator(ImageGenerator):
    """Dummy inputs (and optional per-pixel labels) for semantic segmentation."""

    def labels(self):
        """Build integer labels shaped ``(batch_size, height, width)``.

        ``max_value`` is ``shapes["num_labels"]`` or, when absent,
        ``DEFAULT_NUM_LABELS``.
        """
        self.assert_not_missing_shapes(["batch_size", "height", "width"])
        label_bound = self.shapes.get("num_labels", DEFAULT_NUM_LABELS)
        mask_shape = tuple(self.shapes[dim] for dim in ("batch_size", "height", "width"))
        return self.generate_random_integers(min_value=0, max_value=label_bound, shape=mask_shape)

    def __call__(self):
        """Return ``pixel_values``, plus ``labels`` when ``self.with_labels`` is truthy."""
        batch = {"pixel_values": self.pixel_values()}

        if self.with_labels:
            batch["labels"] = self.labels()

        return batch
class AudioClassificationGenerator(AudioGenerator):
    """Dummy inputs (and optional labels) for audio classification."""

    def labels(self):
        """Build one integer label per batch row.

        ``max_value`` is ``shapes["num_labels"]`` or, when absent,
        ``DEFAULT_NUM_LABELS``.
        """
        self.assert_not_missing_shapes(["batch_size"])
        label_bound = self.shapes.get("num_labels", DEFAULT_NUM_LABELS)
        batch_size = self.shapes["batch_size"]
        return self.generate_random_integers(
            min_value=0,
            max_value=label_bound,
            shape=(batch_size,),
        )

    def __call__(self):
        """Return ``input_values``, plus ``labels`` when ``self.with_labels`` is truthy."""
        batch = {"input_values": self.input_values()}

        if self.with_labels:
            batch["labels"] = self.labels()

        return batch
class AutomaticSpeechRecognitionGenerator(AudioGenerator):
    """Dummy inputs (and optional token labels) for automatic speech recognition."""

    def labels(self):
        """Build integer labels shaped ``(batch_size, sequence_length)``.

        ``max_value`` is ``shapes["vocab_size"]`` when that entry is present
        and truthy, otherwise ``DEFAULT_VOCAB_SIZE``.
        """
        self.assert_not_missing_shapes(["batch_size", "sequence_length"])
        return self.generate_random_integers(
            min_value=0,
            # `.get` keeps the fallback reachable when "vocab_size" is absent
            # (plain indexing would raise KeyError before `or` is evaluated),
            # while `or` still covers an explicit None/0 entry. The fallback is
            # the vocabulary default, not the token-type one.
            max_value=self.shapes.get("vocab_size") or DEFAULT_VOCAB_SIZE,
            shape=(self.shapes["batch_size"], self.shapes["sequence_length"]),
        )

    def __call__(self):
        """Return ``input_values``, plus ``labels`` when ``self.with_labels`` is truthy."""
        dummy = {}
        dummy["input_values"] = self.input_values()

        if self.with_labels:
            dummy["labels"] = self.labels()

        return dummy
class PromptGenerator(BaseGenerator):
    """Dummy prompt inputs, one generated string per batch element."""

    def prompt(self):
        """Build the prompts via ``generate_random_strings`` with ``num_seq=batch_size``."""
        self.assert_not_missing_shapes(["batch_size"])
        prompt_count = self.shapes["batch_size"]
        return self.generate_random_strings(num_seq=prompt_count)

    def __call__(self):
        """Return a dict holding the generated prompts under ``"prompt"``."""
        return {"prompt": self.prompt()}
class FeatureExtractionGenerator(TextGenerator, ImageGenerator):
    """Dummy inputs for feature extraction on text and/or image models.

    Which inputs are produced depends on which shape entries are set:
    ``sequence_length`` enables the text inputs, ``height`` the image input.
    """

    def __call__(self):
        """Return the dict of generated inputs keyed by input name."""
        features = {}

        has_text = self.shapes.get("sequence_length", None) is not None
        if has_text:
            features["input_ids"] = self.input_ids()
            features["attention_mask"] = self.attention_mask()

            # These builders need `sequence_length`, so they only make sense
            # inside the text branch.
            if self.requires_token_type_ids():
                features["token_type_ids"] = self.token_type_ids()

            if self.requires_position_ids():
                features["position_ids"] = self.position_ids()

        has_image = self.shapes.get("height", None) is not None
        if has_image:
            features["pixel_values"] = self.pixel_values()

        return features
class ImageTextToTextGenerator(TextGenerator, ImageGenerator):
    """Dummy text plus image inputs (and optional labels) for image-text-to-text."""

    def __call__(self):
        """Return ``input_ids``, ``attention_mask`` and ``pixel_values``.

        When ``self.with_labels`` is truthy, ``labels`` is added from a fresh
        ``input_ids()`` call.
        """
        batch = {
            "input_ids": self.input_ids(),
            "attention_mask": self.attention_mask(),
            "pixel_values": self.pixel_values(),
        }

        if self.with_labels:
            batch["labels"] = self.input_ids()

        return batch
# Generators for transformers model tasks, keyed by task name.
_TRANSFORMERS_TASK_GENERATORS = {
    "feature-extraction": FeatureExtractionGenerator,
    "text-classification": TextClassificationGenerator,
    "token-classification": TokenClassificationGenerator,
    "text-generation": TextGenerationGenerator,
    "text2text-generation": Text2TextGenerationGenerator,
    "question-answering": QuestionAnsweringGenerator,
    "fill-mask": MaskedLanguageModelingGenerator,
    "multiple-choice": MultipleChoiceGenerator,
    "image-classification": ImageClassificationGenerator,
    "object-detection": ObjectDetectionGenerator,
    "semantic-segmentation": SemanticSegmentationGenerator,
    "image-text-to-text": ImageTextToTextGenerator,
}

# Generators for diffusers pipeline tasks, keyed by task name.
_DIFFUSERS_TASK_GENERATORS = {
    "text-to-image": PromptGenerator,
}

# Public lookup table: task name -> generator class. Insertion order is the
# transformers tasks first, then the diffusers tasks.
# NOTE(review): AudioClassificationGenerator and
# AutomaticSpeechRecognitionGenerator are defined in this module but have no
# entry here — confirm whether that is intentional.
TASKS_TO_GENERATORS = {
    **_TRANSFORMERS_TASK_GENERATORS,
    **_DIFFUSERS_TASK_GENERATORS,
}