optimum_benchmark/generators/model_generator.py (210 lines of code) (raw):
import logging
import torch
from .base import BaseGenerator
LOGGER = logging.getLogger("generators")
DEFAULT_VOCAB_SIZE = 2
class IdeficsGenerator(BaseGenerator):
def input_ids(self):
self.assert_not_missing_shapes(["batch_size", "sequence_length", "num_images", "image_token_id"])
text_tokens = self.generate_random_integers(
min_value=0,
max_value=self.shapes.get("vocab_size", DEFAULT_VOCAB_SIZE),
shape=(self.shapes["batch_size"], self.shapes["sequence_length"]),
)
image_tokens = self.generate_constant_integers(
value=self.shapes["image_token_id"],
shape=(self.shapes["batch_size"], self.shapes["num_images"]),
)
return torch.cat((text_tokens, image_tokens), dim=1)
def attention_mask(self):
self.assert_not_missing_shapes(["batch_size", "sequence_length", "num_images"])
return self.generate_constant_integers(
value=1, # no sparsity
shape=(
self.shapes["batch_size"],
self.shapes["sequence_length"] + self.shapes["num_images"],
),
)
def pixel_values(self):
self.assert_not_missing_shapes(["batch_size", "num_images", "num_channels", "height", "width"])
return self.generate_random_floats(
min_value=0,
max_value=1,
shape=(
self.shapes["batch_size"],
self.shapes["num_images"],
self.shapes["num_channels"],
self.shapes["height"],
self.shapes["width"],
),
)
def image_attention_mask(self):
self.assert_not_missing_shapes(["batch_size", "sequence_length", "num_images"])
return self.generate_constant_integers(
value=1, # no sparsity
shape=(
self.shapes["batch_size"],
self.shapes["sequence_length"] + self.shapes["num_images"],
self.shapes["num_images"],
),
)
def __call__(self):
dummy = {}
dummy["input_ids"] = self.input_ids()
dummy["pixel_values"] = self.pixel_values()
dummy["attention_mask"] = self.attention_mask()
dummy["image_attention_mask"] = self.image_attention_mask()
if self.with_labels:
dummy["labels"] = self.input_ids()
return dummy
class Idefics2Generator(BaseGenerator):
def input_ids(self):
self.assert_not_missing_shapes(
["batch_size", "sequence_length", "num_images", "image_seq_len", "image_token_id", "do_image_splitting"]
)
text_tokens = self.generate_random_integers(
min_value=0,
max_value=self.shapes.get("vocab_size", DEFAULT_VOCAB_SIZE),
shape=(self.shapes["batch_size"], self.shapes["sequence_length"]),
)
image_tokens = self.generate_constant_integers(
value=self.shapes["image_token_id"],
shape=(
self.shapes["batch_size"],
self.shapes["num_images"]
* self.shapes["image_seq_len"]
* (5 if self.shapes["do_image_splitting"] else 1),
),
)
return torch.cat((text_tokens, image_tokens), dim=1)
def attention_mask(self):
self.assert_not_missing_shapes(["batch_size", "sequence_length", "num_images", "do_image_splitting"])
return self.generate_constant_integers(
value=1, # no sparsity
shape=(
self.shapes["batch_size"],
self.shapes["sequence_length"]
+ self.shapes["num_images"]
* self.shapes["image_seq_len"]
* (5 if self.shapes["do_image_splitting"] else 1),
),
)
def pixel_values(self):
self.assert_not_missing_shapes(
["batch_size", "num_images", "num_channels", "height", "width", "do_image_splitting"]
)
return self.generate_random_floats(
min_value=0,
max_value=1,
shape=(
self.shapes["batch_size"],
self.shapes["num_images"] * (5 if self.shapes["do_image_splitting"] else 1),
self.shapes["num_channels"],
self.shapes["height"],
self.shapes["width"],
),
)
def pixel_attention_mask(self):
self.assert_not_missing_shapes(["batch_size", "sequence_length", "num_images", "do_image_splitting"])
return self.generate_constant_integers(
value=1, # no sparsity
shape=(
self.shapes["batch_size"],
self.shapes["num_images"] * (5 if self.shapes["do_image_splitting"] else 1),
self.shapes["height"],
self.shapes["width"],
),
)
def __call__(self):
dummy = {}
dummy["input_ids"] = self.input_ids()
dummy["pixel_values"] = self.pixel_values()
dummy["attention_mask"] = self.attention_mask()
dummy["pixel_attention_mask"] = self.pixel_attention_mask()
print("input_ids", dummy["input_ids"].shape)
print("pixel_values", dummy["pixel_values"].shape)
print("attention_mask", dummy["attention_mask"].shape)
print("pixel_attention_mask", dummy["pixel_attention_mask"].shape)
if self.with_labels:
dummy["labels"] = self.input_ids()
return dummy
class Qwen2VLGenerator(BaseGenerator):
def input_ids(self):
self.assert_not_missing_shapes(
[
"batch_size",
"sequence_length",
"num_images",
"num_channels",
"height",
"width",
"patch_size",
"temporal_patch_size",
"spatial_merge_size",
"image_token_id",
]
)
text_tokens = self.generate_random_integers(
min_value=0,
max_value=self.shapes.get("vocab_size", DEFAULT_VOCAB_SIZE),
shape=(
self.shapes["batch_size"],
self.shapes["sequence_length"],
),
)
image_tokens = self.generate_constant_integers(
value=self.shapes["image_token_id"],
shape=(
self.shapes["batch_size"],
int(
self.shapes["num_images"]
* self.shapes["height"]
* self.shapes["width"]
/ self.shapes["temporal_patch_size"]
/ self.shapes["spatial_merge_size"]
/ self.shapes["patch_size"] ** 2
),
),
)
return torch.cat((text_tokens, image_tokens), dim=1)
def pixel_values(self):
self.assert_not_missing_shapes(
["num_images", "num_channels", "height", "width", "patch_size", "temporal_patch_size"]
)
return self.generate_random_floats(
min_value=0,
max_value=1,
shape=(
self.shapes["num_images"]
* int(self.shapes["height"] / self.shapes["patch_size"])
* int(self.shapes["width"] / self.shapes["patch_size"]),
self.shapes["num_channels"]
* self.shapes["patch_size"]
* self.shapes["patch_size"]
* self.shapes["temporal_patch_size"],
),
)
def image_grid_thw(self):
self.assert_not_missing_shapes(["num_images", "height", "width", "patch_size"])
return torch.tensor(
[
[
self.shapes["num_images"],
int(self.shapes["height"] / self.shapes["patch_size"]),
int(self.shapes["width"] / self.shapes["patch_size"]),
]
]
)
def __call__(self):
dummy = {}
dummy["input_ids"] = self.input_ids()
dummy["pixel_values"] = self.pixel_values()
dummy["image_grid_thw"] = self.image_grid_thw()
if self.with_labels:
dummy["labels"] = self.input_ids()
return dummy
MODEL_TYPE_TO_GENERATORS = {
"idefics": IdeficsGenerator,
"idefics2": Idefics2Generator,
"qwen2_vl": Qwen2VLGenerator,
}