optimum/neuron/utils/input_generators.py (207 lines of code) (raw):
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dummy input generation classes."""
from typing import TYPE_CHECKING, Optional
import torch
from optimum.utils import (
DTYPE_MAPPER,
DummyAudioInputGenerator,
DummyInputGenerator,
NormalizedTextConfig,
NormalizedVisionConfig,
)
if TYPE_CHECKING:
from .argument_utils import ImageEncoderArguments
class DummyBeamValuesGenerator(DummyInputGenerator):
    """
    Generates dummy beam search inputs.

    Each supported input is a 1-D tensor holding `num_beams` elements.
    """

    SUPPORTED_INPUT_NAMES = (
        "beam_idx",
        "beam_scores",
    )

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedTextConfig,
        num_beams: int = 1,
        **kwargs,
    ):
        # `normalized_config` and the extra keyword arguments are accepted but not used here.
        self.num_beams = num_beams
        self.task = task

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """
        Builds the dummy tensor associated with `input_name`.

        Args:
            input_name (`str`): Either `"beam_idx"` or `"beam_scores"`.
            framework (`str`): Not used; PyTorch tensors are always returned.
            int_dtype (`str`): Key resolved through `DTYPE_MAPPER.pt` for `"beam_idx"`.
            float_dtype (`str`): Key resolved through `DTYPE_MAPPER.pt` for `"beam_scores"`.

        Returns:
            `"beam_idx"` gives the consecutive indices `0 .. num_beams - 1`, `"beam_scores"` gives
            `num_beams` zeros, and any other name gives `None`.
        """
        if input_name == "beam_scores":
            score_dtype = DTYPE_MAPPER.pt(float_dtype)
            return torch.zeros((self.num_beams,), dtype=score_dtype)
        if input_name == "beam_idx":
            index_dtype = DTYPE_MAPPER.pt(int_dtype)
            return torch.arange(self.num_beams, dtype=index_dtype)
        return None
class WhisperDummyTextInputGenerator(DummyInputGenerator):
    """
    Generates dummy inputs for Whisper decoder.
    """

    SUPPORTED_INPUT_NAMES = (
        "decoder_input_ids",
        "encoder_hidden_states",
    )

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedTextConfig,
        batch_size: int,
        sequence_length: int = 1,
        **kwargs,
    ):
        self.task = task
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.vocab_size = normalized_config.vocab_size
        self.normalized_config = normalized_config

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """
        Builds the dummy tensor associated with `input_name`.

        Args:
            input_name (`str`): Either `"decoder_input_ids"` or `"encoder_hidden_states"`.
            framework (`str`): Forwarded to the random tensor helpers.
            int_dtype (`str`): Integer dtype key used for `"decoder_input_ids"`.
            float_dtype (`str`): Float dtype key used for `"encoder_hidden_states"`.

        Returns:
            - `"decoder_input_ids"`: shape `(batch_size, sequence_length)`. When `sequence_length == 1`
              the tensor is filled with `decoder_start_token_id`; otherwise it holds random ids
              drawn with `min_value=0` and `max_value=vocab_size`.
            - `"encoder_hidden_states"`: random floats of shape
              `(batch_size, max_source_positions, hidden_size)`.
            - `None` for any other name.
        """
        if input_name == "decoder_input_ids":
            if self.sequence_length == 1:
                # Honor `int_dtype` here as well so both branches return the requested integer type
                # (identical to the previous hard-coded `torch.long` for the default "int64").
                return torch.full(
                    (self.batch_size, 1),
                    self.normalized_config.decoder_start_token_id,
                    dtype=DTYPE_MAPPER.pt(int_dtype),
                )
            shape = (self.batch_size, self.sequence_length)
            return self.random_int_tensor(
                shape, max_value=self.vocab_size, min_value=0, framework=framework, dtype=int_dtype
            )
        if input_name == "encoder_hidden_states":
            shape = (self.batch_size, self.normalized_config.max_source_positions, self.normalized_config.hidden_size)
            # NOTE(review): `max_value=vocab_size` is kept as-is; it looks unusual for hidden states,
            # confirm whether the value range matters for the callers of this generator.
            return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype)
        return None
class DummyMaskedPosGenerator(DummyInputGenerator):
    """
    Generates dummy patch masks of shape `(batch_size, (image_size // patch_size) ** 2)`.
    """

    SUPPORTED_INPUT_NAMES = ("masked_pos", "bool_masked_pos")

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int,
        **kwargs,
    ):
        self.task = task
        self.batch_size = batch_size
        # Both attributes fall back to `None` when the config does not define them.
        self.image_size = getattr(normalized_config, "image_size", None)
        self.patch_size = getattr(normalized_config, "patch_size", None)

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """
        Draws a random 0/1 mask and returns it in the form requested by `input_name`.

        Args:
            input_name (`str`): `"masked_pos"` for the integer mask, `"bool_masked_pos"` for its
                boolean conversion.
            framework (`str`), int_dtype (`str`), float_dtype (`str`): Not used by this generator.

        Returns:
            The mask tensor, or `None` when `input_name` is neither of the two supported names.
        """
        patches_per_side = self.image_size // self.patch_size
        # The mask is drawn before the name is inspected, whatever `input_name` is.
        mask = torch.randint(low=0, high=2, size=(self.batch_size, patches_per_side**2))
        if input_name == "bool_masked_pos":
            return mask.bool()
        return mask if input_name == "masked_pos" else None
class DummyControNetInputGenerator(DummyInputGenerator):
    """
    Generates dummy ControlNet inputs, as well as the ControlNet outputs that are passed to the UNet.
    """

    SUPPORTED_INPUT_NAMES = (
        # ControlNet inputs
        "timestep",
        "encoder_hidden_states",  # depending on the hidden_size of text encoder
        "controlnet_cond",
        "conditioning_scale",
        # ControlNet outputs -> UNet inputs
        "down_block_additional_residuals",
        "mid_block_additional_residual",
    )

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedTextConfig,
        batch_size: int,
        sequence_length: Optional[int] = None,
        num_channels: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        vae_scale_factor: Optional[int] = None,
        encoder_hidden_size: Optional[int] = None,
        **kwargs,
    ):
        self.task = task
        self.normalized_config = normalized_config
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.num_channels = num_channels
        self.height = height
        self.width = width
        self.vae_scale_factor = vae_scale_factor
        # Hidden size of the text encoder, used for the last dimension of `encoder_hidden_states`.
        self.text_encoder_hidden_size = encoder_hidden_size

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """
        Builds the dummy value associated with `input_name`.

        Args:
            input_name (`str`): One of `SUPPORTED_INPUT_NAMES`.
            framework (`str`): Forwarded to the random tensor helpers.
            int_dtype (`str`): Integer dtype key, used for `"timestep"`.
            float_dtype (`str`): Float dtype key, used for the random float tensors.

        Returns:
            - `"timestep"`: random integers of shape `[batch_size]`, drawn with `max_value=999`.
            - `"encoder_hidden_states"`: random floats of shape
              `(batch_size, sequence_length, encoder_hidden_size)`.
            - `"controlnet_cond"`: random floats of shape
              `(batch_size, conditioning_channels, height * vae_scale_factor, width * vae_scale_factor)`.
            - `"conditioning_scale"`: `torch.tensor([1.0])`.
            - `"down_block_additional_residuals"`: a tuple of random float tensors, one per residual.
            - `"mid_block_additional_residual"`: a single random float tensor.
            - `None` for any other name.
        """
        if input_name == "timestep":
            shape = [self.batch_size]
            return self.random_int_tensor(shape, max_value=999, framework=framework, dtype=int_dtype)
        elif input_name == "encoder_hidden_states":
            shape = (self.batch_size, self.sequence_length, self.text_encoder_hidden_size)
            return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
        elif input_name == "controlnet_cond":
            num_channels = getattr(
                self.normalized_config, "conditioning_channels", 3
            )  # num_channels = 3 since `do_convert_rgb=True`
            shape = (
                self.batch_size,
                num_channels,
                self.height * self.vae_scale_factor,
                self.width * self.vae_scale_factor,
            )
            return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
        elif input_name == "conditioning_scale":
            # NOTE(review): always a default-dtype tensor; `float_dtype` is not applied here.
            return torch.tensor([1.0])
        elif input_name == "down_block_additional_residuals":
            config = self.normalized_config
            height = self.height
            width = self.width

            # First residual: full resolution, channel count of the first block.
            sample_shape = (self.batch_size, config.block_out_channels[0], height, width)
            down_block_res_samples = (self.random_float_tensor(sample_shape, framework=framework, dtype=float_dtype),)

            last_block_idx = len(config.down_block_types) - 1
            for idx, _ in enumerate(config.down_block_types):
                res_samples = ()
                # One residual per layer of the block, at the current resolution.
                shape = (self.batch_size, config.block_out_channels[idx], height, width)
                for _ in range(config.layers_per_block):
                    res_samples += (self.random_float_tensor(shape, framework=framework, dtype=float_dtype),)
                if idx != last_block_idx:
                    # add output of downsampler: every block but the last halves the resolution
                    height = height // 2
                    width = width // 2
                    shape = (self.batch_size, config.block_out_channels[idx], height, width)
                    res_samples += (self.random_float_tensor(shape, framework=framework, dtype=float_dtype),)
                down_block_res_samples += res_samples
            return down_block_res_samples
        elif input_name == "mid_block_additional_residual":
            # NOTE(review): the downsampling factor is derived from the number of
            # `CrossAttnDownBlock2D` entries, whereas `down_block_additional_residuals` halves the
            # resolution after every block except the last one. Confirm both agree for all
            # supported configurations.
            num_cross_attn_blocks = self.normalized_config.down_block_types.count("CrossAttnDownBlock2D")
            out_channels = self.normalized_config.block_out_channels[-1]
            shape = (
                self.batch_size,
                out_channels,
                self.height // 2**num_cross_attn_blocks,
                self.width // 2**num_cross_attn_blocks,
            )
            return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
        return None
class DummyIPAdapterInputGenerator(DummyInputGenerator):
    """
    Generates the extra dummy UNet inputs used with IP-Adapter.

    All shapes are derived from `batch_size` and the `image_encoder_shapes` argument.
    """

    SUPPORTED_INPUT_NAMES = (
        # Unet extra inputs
        "image_embeds",  # If `unet.encoder_hid_proj.image_projection_layers` are instances of `IPAdapterFullImageProjection`, eg. sd.
        "image_enc_hidden_states",  # If `unet.encoder_hid_proj.image_projection_layers` are instances of `ImageProjection`, eg. sdxl.
        "ip_adapter_masks",
    )

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedTextConfig,
        batch_size: int,
        image_encoder_shapes: Optional["ImageEncoderArguments"] = None,
        **kwargs,
    ):
        self.task = task
        self.normalized_config = normalized_config
        self.batch_size = batch_size
        # `generate` reads `sequence_length`, `hidden_size` and `projection_dim` from this object.
        self.image_encoder_shapes = image_encoder_shapes

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """
        Builds the dummy tensor associated with `input_name`.

        Args:
            input_name (`str`): One of `SUPPORTED_INPUT_NAMES`.
            framework (`str`): Forwarded to the random tensor helpers.
            int_dtype (`str`): Integer dtype key, used for `"ip_adapter_masks"`.
            float_dtype (`str`): Float dtype key, used for the two embedding inputs.

        Returns:
            - `"image_enc_hidden_states"`: random floats of shape
              `[batch_size, 1, sequence_length, hidden_size]`.
            - `"image_embeds"`: random floats of shape `[batch_size, 1, projection_dim]`.
            - `"ip_adapter_masks"`: random integers of shape
              `[batch_size, 1, sequence_length, hidden_size]`.
            - `None` for any other name.
        """
        if input_name == "image_enc_hidden_states":
            shape = [
                self.batch_size,
                1,
                self.image_encoder_shapes.sequence_length,
                self.image_encoder_shapes.hidden_size,
            ]
            return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
        elif input_name == "image_embeds":
            shape = [self.batch_size, 1, self.image_encoder_shapes.projection_dim]
            return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
        elif input_name == "ip_adapter_masks":
            shape = [
                self.batch_size,
                1,
                self.image_encoder_shapes.sequence_length,
                self.image_encoder_shapes.hidden_size,
            ]
            # `max_value` is now passed explicitly, like every other `random_int_tensor` call in
            # this file; the previous call omitted it. An upper bound of 2 yields values in {0, 1},
            # assuming the masks are binary and the bound is exclusive -- TODO confirm.
            return self.random_int_tensor(shape, max_value=2, framework=framework, dtype=int_dtype)
        return None
# copied from https://github.com/huggingface/optimum/blob/171020c775cec6ff77826c3f5f5e5c1498b23f81/optimum/exporters/onnx/model_configs.py#L1363C1-L1368C111
class ASTDummyAudioInputGenerator(DummyAudioInputGenerator):
    """
    Audio dummy input generator that overrides how `"input_values"` is produced.
    """

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """
        Returns random floats of shape `[batch_size, max_length, num_mel_bins]`, drawn with
        `min_value=-1` and `max_value=1`, for `"input_values"`; every other input name is
        delegated to the parent class.
        """
        # The dimensions are read from the config up front, whichever input is requested.
        spectrogram_dims = [self.batch_size, self.normalized_config.max_length, self.normalized_config.num_mel_bins]
        if input_name != "input_values":
            return super().generate(input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype)
        return self.random_float_tensor(
            spectrogram_dims, min_value=-1, max_value=1, framework=framework, dtype=float_dtype
        )