optimum/exporters/onnx/model_patcher.py
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import functools
import inspect
import math
import sys
import types
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import transformers
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
from ...utils import is_transformers_version, logging
from ._traceable_cache import TraceableCache
if is_transformers_version(">=", "4.35"):
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
if is_transformers_version(">=", "4.36"):
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa
if is_transformers_version(">=", "4.43") and is_transformers_version("<", "4.48"):
from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention
if is_transformers_version(">=", "4.42"):
from transformers.cache_utils import SlidingWindowCache, StaticCache
if is_transformers_version(">=", "4.48"):
from transformers.cache_utils import DynamicCache, EncoderDecoderCache
from transformers.integrations.sdpa_attention import repeat_kv, sdpa_attention_forward
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
if TYPE_CHECKING:
from transformers import PreTrainedModel, TFPreTrainedModel
from .base import OnnxConfig
logger = logging.get_logger(__name__)
def patch_everywhere(attribute_name: str, patch: Any, module_name_prefix: Optional[str] = None):
"""
Finds all occurrences of `attribute_name` in the loaded modules and patches them with `patch`.
Args:
attribute_name (`str`):
The name of the attribute to patch.
patch (`Any`):
The patch for the attribute.
module_name_prefix (`Optional[str]`, defaults to `None`):
If set, only module names starting with this prefix will be considered for patching.
"""
# sys.modules may be updated while being iterated over, hence the list copy.
for name in list(sys.modules):
module = sys.modules[name]
if module_name_prefix is not None and not name.startswith(module_name_prefix):
continue
if hasattr(module, attribute_name):
setattr(module, attribute_name, patch)
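# Illustrative usage (a minimal sketch; `_patched_fn` is a hypothetical replacement callable):
#
#     patch_everywhere("_prepare_4d_causal_attention_mask_for_sdpa", _patched_fn, module_name_prefix="transformers")
#
# Restoring is done by calling `patch_everywhere` again with the original callable, as in
# `DecoderModelPatcher.__exit__` below.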
def override_arguments(args, kwargs, forward_signature, model_kwargs: Dict[str, Any]):
"""
Override the args and kwargs with the argument values from model_kwargs, using forward_signature (the signature that args and kwargs correspond to) to map positional arguments.
"""
args = list(args)
for argument in model_kwargs:
if argument in forward_signature.parameters:
argument_index = list(forward_signature.parameters.keys()).index(argument)
if argument in kwargs or len(args) <= argument_index:
kwargs[argument] = model_kwargs[argument]
else:
args[argument_index] = model_kwargs[argument]
else:
kwargs[argument] = model_kwargs[argument]
return args, kwargs
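# Illustrative example (a minimal sketch with a hypothetical forward signature):
#
#     def forward(input_ids, attention_mask=None, output_attentions=False): ...
#     signature = inspect.signature(forward)
#     args, kwargs = override_arguments((input_ids,), {}, signature, model_kwargs={"output_attentions": True})
#     # -> args == [input_ids], kwargs == {"output_attentions": True}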
@dataclasses.dataclass
class PatchingSpec:
"""
Data class that holds patching specifications.
Args:
o: Module / object where the op to patch is located
name: Name of the op to monkey patch
custom_op: Custom op that patches the original op
orig_op: Original op that is being patched
op_wrapper: Wrapper (optional) that wraps both the original and custom ops.
It is useful, for instance, for ops that are class methods or static methods.
"""
o: Any
name: str
custom_op: Callable
orig_op: Optional[Callable] = None
op_wrapper: Optional[Callable] = None
# An ONNX-export-compatible version of `tensor.unfold`. Without this, we get:
# torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator Unfold, input size not accessible.
# See https://github.com/pytorch/pytorch/issues/81871 for more information
def onnx_compatible_unfold(input_tensor, dimension, size, step):
"""
Custom implementation of torch.unfold without using torch.unfold.
Args:
input_tensor (torch.Tensor): The input tensor.
dimension (int): The dimension to unfold.
size (int): The size of each slice.
step (int): The step size between slices.
Returns:
torch.Tensor: The unfolded tensor.
"""
# Check if dimension is within the valid range
if not (-input_tensor.dim() <= dimension < input_tensor.dim()):
raise ValueError(
f"Dimension out of range (expected to be in range of [{-input_tensor.dim()}, {input_tensor.dim() - 1}], but got {dimension})"
)
# Normalize negative dimension
dimension = dimension % input_tensor.dim()
# Compute the shape of the unfolded output
input_size = input_tensor.size(dimension)
num_slices = (input_size - size) // step + 1
# Permute dimension to the end for easier indexing
input_tensor = input_tensor.transpose(dimension, -1)
# Extract slices
slices = []
for i in range(num_slices):
start = i * step
end = start + size
slices.append(input_tensor[..., start:end])
# Stack slices and permute dimensions back
result = torch.stack(slices, dim=-2).transpose(dimension, -2)
return result
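# Illustrative sanity check (a minimal sketch, not part of the export path): the reimplementation is
# expected to match the native op for the traced cases, e.g.:
#
#     x = torch.arange(8.0)
#     assert torch.equal(onnx_compatible_unfold(x, 0, 3, 2), x.unfold(0, 3, 2))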
# An ONNX-export-compatible version of `tensor.repeat_interleave`.
# Without this, we get the following error: https://github.com/pytorch/pytorch/issues/145100
# NOTE: This implementation is only necessary for export with dynamo=False (dynamo=True works correctly),
# and can be removed once Optimum switches to dynamo-based exports.
def onnx_compatible_repeat_interleave(input_tensor, repeats, dim=None, output_size=None):
"""
Custom implementation of torch.repeat_interleave without using torch.repeat_interleave.
Args:
input_tensor (torch.Tensor): The input tensor.
repeats (int or torch.Tensor): The number of repetitions for each element.
dim (int, optional): The dimension along which to repeat. Defaults to None.
Returns:
torch.Tensor: The repeated tensor.
"""
if isinstance(repeats, int) or (torch.is_tensor(repeats) and repeats.dim() == 0):
if dim is None:
return input_tensor.flatten().unsqueeze(1).expand(-1, repeats).flatten()
repeats = torch.full((input_tensor.shape[dim],), repeats, dtype=torch.long, device=input_tensor.device)
if dim is None:
return onnx_compatible_repeat_interleave(input_tensor.flatten(), repeats, 0)
if dim != 0:
input_tensor = input_tensor.transpose(0, dim)
# Create expand mask
max_repeats = repeats.max()
expanded = input_tensor.unsqueeze(1).expand(-1, max_repeats, *input_tensor.shape[1:])
mask = torch.arange(max_repeats, device=input_tensor.device) < repeats.unsqueeze(1)
result = expanded[mask]
if dim != 0:
result = result.transpose(0, dim)
return result
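# Illustrative sanity check (a minimal sketch, not part of the export path):
#
#     x = torch.tensor([[1, 2], [3, 4]])
#     assert torch.equal(onnx_compatible_repeat_interleave(x, 2, dim=0), x.repeat_interleave(2, dim=0))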
original_linal_norm = torch.linalg.norm
# Custom implementation of torch.linalg.norm that avoids torch.linalg.matrix_norm, torch.norm and torch.linalg.norm.
def onnx_compatible_linalg_norm(x, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> torch.Tensor:
"""
Custom implementation of torch.linalg.norm not using torch.linalg.matrix_norm, torch.norm or torch.linalg.norm.
It only handles the case of matrix norm with ord=2; otherwise it falls back to the original implementation.
"""
if ord == 2:
if dim is None:
dim = (-2, -1)
norm = torch.sqrt(torch.sum(torch.square(x), dim=dim, keepdim=keepdim))
if dtype is not None:
norm = norm.to(dtype)
if out is not None:
out.copy_(norm)
return norm
return original_linal_norm(x, ord=ord, dim=dim, keepdim=keepdim, dtype=dtype, out=out)
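# Illustrative sanity check (a minimal sketch; only the ord=2 reduction over the last two dimensions
# is reimplemented, i.e. a square root of the sum of squares):
#
#     m = torch.randn(3, 4)
#     torch.testing.assert_close(onnx_compatible_linalg_norm(m, ord=2), torch.sqrt((m * m).sum()))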
UNSUPPORTED_OPS_PATCHING_SPEC = [
PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold),
PatchingSpec(torch.linalg, "norm", onnx_compatible_linalg_norm, original_linal_norm),
PatchingSpec(torch.Tensor, "repeat_interleave", onnx_compatible_repeat_interleave, torch.Tensor.repeat_interleave),
# TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
PatchingSpec(torch.Tensor, "__len__", lambda x: x.shape[0], torch.Tensor.__len__),
]
CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", TraceableCache, transformers.cache_utils.Cache)]
class ModelPatcher:
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
self._model = model
patching_specs = config.PATCHING_SPECS or []
patching_specs.extend(UNSUPPORTED_OPS_PATCHING_SPEC)
patching_specs.extend(CACHE_PATCHING_SPEC)
self._patching_specs = []
for spec in patching_specs:
final_spec = spec
if spec.orig_op is None:
final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))
self._patching_specs.append(final_spec)
self.orig_forward_name = "forward" if hasattr(self._model, "forward") else "call"
self.orig_forward = getattr(self._model, self.orig_forward_name)
self.model_kwargs = model_kwargs if model_kwargs is not None else {}
# TODO: remove this once we get rid of OnnxConfigWithLoss or implement it better.
if config.__class__.__name__ == "OnnxConfigWithLoss":
self.real_config = config._onnx_config
else:
self.real_config = config
allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past
@functools.wraps(self.orig_forward)
def patched_forward(*args, **kwargs):
signature = inspect.signature(self.orig_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
if is_transformers_version(">=", "4.48"):
if "past_key_values" in signature.parameters:
pkv_index = list(signature.parameters.keys()).index("past_key_values")
if (
pkv_index < len(args) # pkv is in args
and isinstance(args[pkv_index], (list, tuple))
and isinstance(args[pkv_index][0], (list, tuple))
):
if len(args[pkv_index][0]) == 2:
args[pkv_index] = DynamicCache.from_legacy_cache(args[pkv_index])
elif len(args[pkv_index][0]) == 4:
args[pkv_index] = EncoderDecoderCache.from_legacy_cache(args[pkv_index])
else:
raise ValueError(
f"past_key_values should have either 2 or 4 elements, but it has {len(args[pkv_index][0])} elements"
)
elif (
"past_key_values" in kwargs # pkv is in kwargs
and isinstance(kwargs["past_key_values"], (list, tuple))
and isinstance(kwargs["past_key_values"][0], (list, tuple))
):
if len(kwargs["past_key_values"][0]) == 2:
kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"])
elif len(kwargs["past_key_values"][0]) == 4:
kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache(
kwargs["past_key_values"]
)
else:
raise ValueError(
f"past_key_values should have either 2 or 4 elements, but it has {len(kwargs['past_key_values'][0])} elements"
)
outputs = self.orig_forward(*args, **kwargs)
# This code block handles the different shapes that `outputs` can take and aligns them with the expected
# `filtered_outputs` format. The output type of a model commonly varies (tensor, list, tuple, etc.). For
# Transformers models, the output is encapsulated in a ModelOutput object that contains the output names
# of the model. For Timm classification models, the output is a plain tensor. By default, it is assumed
# that the output names declared in the ONNX config match the model outputs in order.
filtered_outputs = {}
if isinstance(outputs, dict):
for name, value in outputs.items():
onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
if (
onnx_output_name in config.outputs
or (allow_past_in_outputs and name.startswith("past_key_values"))
or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
):
filtered_outputs[name] = value
elif isinstance(outputs, (list, tuple)):
outputs_list = list(config.outputs.keys())
filtered_outputs = dict(zip(outputs_list, outputs))
else:
if len(config.outputs) > 1:
num_outputs = len(config.outputs)
outputs_str = ", ".join(config.outputs.keys())
raise ValueError(
f"config.outputs should have only one outputs, but it has {num_outputs} keys: {outputs_str}"
)
else:
name = list(config.outputs.keys())[0]
filtered_outputs[name] = outputs
if is_transformers_version(">=", "4.48"):
if isinstance(filtered_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
filtered_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()
return filtered_outputs
self.patched_forward = patched_forward
def patch_ops(self):
for spec in self._patching_specs:
custom_op = spec.custom_op if spec.op_wrapper is None else spec.op_wrapper(spec.custom_op)
setattr(spec.o, spec.name, custom_op)
def restore_ops(self):
for spec in self._patching_specs:
orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
setattr(spec.o, spec.name, orig_op)
def __enter__(self):
self.patch_ops()
setattr(self._model, self.orig_forward_name, self.patched_forward)
def __exit__(self, exc_type, exc_value, traceback):
self.restore_ops()
setattr(self._model, self.orig_forward_name, self.orig_forward)
def __call__(self, *args, **kwargs):
if getattr(self._model, self.orig_forward_name) is self.orig_forward:
logger.warning("Running the non-patched model")
return self._model(*args, **kwargs)
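# Illustrative usage during export (a minimal sketch; `onnx_config`, `model` and `dummy_inputs` are
# assumed to be a matching OnnxConfig instance, the model to export, and its dummy inputs):
#
#     patcher = ModelPatcher(onnx_config, model)
#     with patcher:
#         torch.onnx.export(model, dummy_inputs, "model.onnx")  # the patched forward and ops are active here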
class Seq2SeqModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.48"):
# this is required when gpt2 is used as decoder in any
# encoder-decoder model with cross attention blocks
ALL_ATTENTION_FUNCTIONS["sdpa"] = patched_sdpa_attention_forward
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if is_transformers_version(">=", "4.48"):
ALL_ATTENTION_FUNCTIONS["sdpa"] = sdpa_attention_forward
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)
allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past
# use_cache is set to False by default with pix2struct, so we need to set it to
# True to export with past key values
if model.config.model_type == "pix2struct" and allow_past_in_outputs:
model.config.text_config.use_cache = True
# Re-use the patched forward method from the parent class
self.super_patched_forward = self.patched_forward
@functools.wraps(self.super_patched_forward)
def patched_forward(*args, **kwargs):
signature = inspect.signature(self.super_patched_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
outputs = self.super_patched_forward(*args, **kwargs)
# Filter out the cross-attention past key values output by the decoder that uses the KV cache, as they are constants.
filtered_outputs = {}
for name, value in outputs.items():
onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
if (
onnx_output_name in config.outputs
or (allow_past_in_outputs and name.startswith("past_key_values"))
or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
):
if name != "past_key_values":
if self.real_config._behavior == "decoder" and name == "encoder_last_hidden_state":
# The encoder outputs are not needed in the decoder export.
continue
else:
filtered_outputs[name] = value
else:
if self.real_config._behavior == "monolith" or (
self.real_config._behavior == "decoder"
and (self.real_config.is_merged or not self.real_config.use_past_in_inputs)
):
filtered_outputs[name] = value
elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
# The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
filtered_outputs[name] = tuple([v[:2] for v in value])
return filtered_outputs
self.patched_forward = patched_forward
def patched_sdpa_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
dropout: float = 0.0,
scaling: Optional[float] = None,
is_causal: Optional[bool] = None,
**kwargs,
) -> Tuple[torch.Tensor, None]:
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
# Reference: https://github.com/pytorch/pytorch/issues/112577.
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
if is_causal is None:
is_causal = causal_mask is None and query.shape[2] > 1
# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
# We convert it to a bool for the SDPA kernel that only accepts bools.
if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
is_causal = is_causal.item()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=causal_mask,
dropout_p=dropout,
scale=scaling,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
class VisionEncoderDecoderPatcher(Seq2SeqModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)
use_cache = hasattr(self.real_config, "use_past")
if config._behavior == "decoder" and model.config.decoder.model_type == "trocr" and use_cache:
model.decoder.model.decoder.config.use_cache = True
if is_transformers_version(">=", "4.39"):
def _unmask_unattended_patched(expanded_mask: torch.Tensor, min_dtype: float):
return expanded_mask
else:
def _unmask_unattended_patched(
expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
):
return expanded_mask
def _make_causal_mask_patched(
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: Optional[int] = None,
):
"""
Make causal mask used for bi-directional self-attention.
"""
# NOTE: the patch is installed as a `staticmethod` (see `DecoderModelPatcher.__enter__`) because `self._make_causal_mask` is used elsewhere in the class definition, despite the method being a staticmethod.
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
# add lower triangular sliding window mask if necessary
if sliding_window is not None:
diagonal = past_key_values_length - sliding_window + 1
# NOTE: adding dtype=torch.int64 here for triu to be supported by ORT: https://github.com/microsoft/onnxruntime/issues/16189
context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int64), diagonal=diagonal)
mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
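# Illustrative example (a minimal sketch of the mask layout; not part of the export path):
#
#     mask = _make_causal_mask_patched((1, 3), torch.float32, torch.device("cpu"))
#     # mask has shape (1, 1, 3, 3); mask[0, 0] is 0 on and below the diagonal and
#     # torch.finfo(torch.float32).min above it:
#     # [[0, min, min],
#     #  [0,   0, min],
#     #  [0,   0,   0]]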
# Adapted from _prepare_4d_causal_attention_mask
def _prepare_4d_causal_attention_mask_for_sdpa_patched(
attention_mask: Optional[torch.Tensor],
input_shape: Union[torch.Size, Tuple, List],
inputs_embeds: torch.Tensor,
past_key_values_length: int,
sliding_window: Optional[int] = None,
):
"""
Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
In case no token is masked in the `attention_mask` argument, the original implementation simply sets it to `None` for the cases `query_length == 1` and
`key_value_length == query_length`, and relies instead on SDPA's `is_causal` argument to use causal/non-causal masks,
allowing dispatch to the flash attention kernel (which cannot otherwise be used when a custom `attn_mask` is passed).
"""
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
key_value_length = input_shape[-1] + past_key_values_length
# 4d mask is passed through the layers
if attention_mask is not None:
attention_mask = attn_mask_converter.to_4d(
attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
)
else:
attention_mask = attn_mask_converter.to_causal_4d(
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
# NOTE: For the ONNX export we remove the setting of attention_mask to None in some specific cases, and we do NOT call _unmask_unattended,
# which cannot be exported to ONNX and is very specific to the PyTorch memory-efficient attention backend anyway.
return attention_mask
class DecoderModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.35"):
AttentionMaskConverter._make_causal_mask = staticmethod(_make_causal_mask_patched)
if is_transformers_version(">=", "4.36"):
AttentionMaskConverter._unmask_unattended = staticmethod(_unmask_unattended_patched)
patch_everywhere(
"_prepare_4d_causal_attention_mask_for_sdpa",
_prepare_4d_causal_attention_mask_for_sdpa_patched,
module_name_prefix="transformers",
)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if is_transformers_version(">=", "4.35"):
AttentionMaskConverter._make_causal_mask = staticmethod(self.original_make_causal_mask)
if is_transformers_version(">=", "4.36"):
AttentionMaskConverter._unmask_unattended = staticmethod(self.original_unmask_unattended)
patch_everywhere(
"_prepare_4d_causal_attention_mask_for_sdpa",
self.original_prepare_4d_causal_attention_mask_for_sdpa,
module_name_prefix="transformers",
)
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)
if is_transformers_version(">=", "4.35"):
self.original_make_causal_mask = AttentionMaskConverter._make_causal_mask
if is_transformers_version(">=", "4.36"):
self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended
self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa
def falcon_build_alibi_tensor_patched(
attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
) -> torch.Tensor:
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
)
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
# Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
# NOTE: the .bfloat16() cast is removed here, as the PyTorch ONNX export otherwise casts to complex128 when it is used, resulting in an onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph error.
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
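# Illustrative shape check (a minimal sketch; the patched function mirrors transformers'
# `build_alibi_tensor`, minus the bfloat16 cast mentioned in the NOTE above):
#
#     attention_mask = torch.ones(2, 5, dtype=torch.long)
#     alibi = falcon_build_alibi_tensor_patched(attention_mask, num_heads=8, dtype=torch.float32)
#     assert alibi.shape == (2 * 8, 1, 5)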
class FalconModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
self.patch_ops()
if self.real_config.task == "text-generation":
patch_everywhere(
"build_alibi_tensor",
falcon_build_alibi_tensor_patched,
module_name_prefix="transformers.models.falcon.modeling_falcon",
)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self.restore_ops()
setattr(self._model, self.orig_forward_name, self.orig_forward)
if self.real_config.task == "text-generation":
patch_everywhere(
"build_alibi_tensor",
self.build_alibi_tensor_original,
module_name_prefix="transformers.models.falcon.modeling_falcon",
)
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)
self.build_alibi_tensor_original = transformers.models.falcon.modeling_falcon.build_alibi_tensor
class WavLMModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)
allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past
@functools.wraps(self.orig_forward)
def patched_forward(*args, **kwargs):
model_kwargs = self.model_kwargs
# setting output_attentions=True in the model input to avoid calling torch.nn.functional.scaled_dot_product_attention
# in https://github.com/huggingface/transformers/blob/v4.27.1/src/transformers/models/wavlm/modeling_wavlm.py#L496
# that calls https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/functional.py#L5334
model_kwargs["output_attentions"] = True
signature = inspect.signature(self.orig_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=model_kwargs)
outputs = self.orig_forward(*args, **kwargs)
filtered_outputs = {}
for name, value in outputs.items():
onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
if (
onnx_output_name in config.outputs
or (allow_past_in_outputs and name.startswith("past_key_values"))
or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
):
filtered_outputs[name] = value
return filtered_outputs
self.patched_forward = patched_forward
class MgpstrModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)
@functools.wraps(self.orig_forward)
def patched_forward(*args, **kwargs):
signature = inspect.signature(self.orig_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
# logits is a tuple, so we unpack it and return its elements as separate outputs
char_logits, bpe_logits, wp_logits = self.orig_forward(*args, **kwargs).logits
return {
"char_logits": char_logits,
"bpe_logits": bpe_logits,
"wp_logits": wp_logits,
}
self.patched_forward = patched_forward
class SAMModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)
def patched_forward(
pixel_values=None,
input_points=None,
input_labels=None,
image_embeddings=None,
image_positional_embeddings=None,
return_dict=True,
**kwargs,
):
if config.variant == "monolith":
return self.orig_forward(
pixel_values=pixel_values,
input_points=input_points,
input_labels=input_labels,
image_embeddings=image_embeddings,
return_dict=return_dict,
**kwargs,
)
elif config.variant == "split":
# return_dict = get_argument(args, kwargs, signature, "return_dict")
if config.vision_encoder:
# pixel_values = get_argument(args, kwargs, signature, "pixel_values")
image_positional_embeddings = model.get_image_wide_positional_embeddings()
# repeat with batch size
batch_size = pixel_values.shape[0]
image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
vision_outputs = model.vision_encoder(
pixel_values,
output_attentions=False,
output_hidden_states=False,
return_dict=return_dict,
)
image_embeddings = vision_outputs[0]
if not return_dict:
return (image_embeddings, image_positional_embeddings)
else:
return {
"image_embeddings": image_embeddings,
"image_positional_embeddings": image_positional_embeddings,
}
else:
if input_points is None:
raise ValueError("input_points is required to export the prompt encoder / mask decoder.")
sparse_embeddings, dense_embeddings = model.prompt_encoder(
input_points=input_points,
input_labels=input_labels,
input_boxes=None, # Not supported in the ONNX export
input_masks=None, # Not supported in the ONNX export
)
low_res_masks, iou_predictions, _ = model.mask_decoder(
image_embeddings=image_embeddings,
image_positional_embeddings=image_positional_embeddings,
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=True, # Not supported in the ONNX export
attention_similarity=None, # Not supported in the ONNX export
target_embedding=None, # Not supported in the ONNX export
output_attentions=False,
)
if not return_dict:
return (iou_predictions, low_res_masks)
else:
return {"iou_scores": iou_predictions, "pred_masks": low_res_masks}
self.patched_forward = patched_forward
def patched_speecht5_prenet_forward(
self,
input_values: torch.Tensor,
speaker_embeddings: Optional[torch.Tensor] = None,
):
# Dropout is always applied, even when evaluating. See §2.2 in https://arxiv.org/abs/1712.05884.
inputs_embeds = input_values
for layer in self.layers:
inputs_embeds = torch.nn.functional.relu(layer(inputs_embeds))
# NOTE: we patch the prenet to avoid using torch.nn.functional.dropout, which is exported as a `Dropout` node in the ONNX graph
# and is ignored at inference time by some runtimes, such as ONNX Runtime.
# Reference: https://github.com/microsoft/onnxruntime/issues/9333 & https://github.com/microsoft/onnxruntime/issues/5549
mask = torch.rand(inputs_embeds.shape, device=inputs_embeds.device) > self.config.speech_decoder_prenet_dropout
inputs_embeds = inputs_embeds * mask / (1 - self.config.speech_decoder_prenet_dropout)
# inputs_embeds = nn.functional.dropout(
# inputs_embeds, self.config.speech_decoder_prenet_dropout, training=True
# )
inputs_embeds = self.final_layer(inputs_embeds)
inputs_embeds = self.encode_positions(inputs_embeds)
if speaker_embeddings is not None:
speaker_embeddings = torch.nn.functional.normalize(speaker_embeddings)
speaker_embeddings = speaker_embeddings.unsqueeze(1)
speaker_embeddings = speaker_embeddings.expand(-1, inputs_embeds.size(1), -1)
inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1)
inputs_embeds = torch.nn.functional.relu(self.speaker_embeds_layer(inputs_embeds))
return inputs_embeds
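# The manual mask in the patched prenet reproduces inverted dropout: each element is kept with
# probability (1 - p) and scaled by 1 / (1 - p), so the expected value matches the input.
# Illustrative check (a minimal sketch, not part of the export path):
#
#     p = 0.5
#     x = torch.ones(1000, 16)
#     keep = torch.rand(x.shape) > p
#     assert ((x * keep / (1 - p)).mean() - 1.0).abs() < 0.1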
class SpeechT5ModelPatcher(ModelPatcher):
def __enter__(self):
self.patch_ops()
self._model.speecht5.decoder.prenet.forward = types.MethodType(
patched_speecht5_prenet_forward, self._model.speecht5.decoder.prenet
)
setattr(self._model, self.orig_forward_name, self.patched_forward)
def __exit__(self, exc_type, exc_value, traceback):
self.restore_ops()
setattr(self._model, self.orig_forward_name, self.orig_forward)
self._model.speecht5.decoder.prenet.forward = types.MethodType(
self.original_speecht5_prenet_forward, self._model.speecht5.decoder.prenet
)
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)
self.original_speecht5_prenet_forward = model.speecht5.decoder.prenet.forward
model.vocoder = model_kwargs["vocoder_model"].eval()
def patched_forward(
input_ids=None,
speaker_embeddings=None,
encoder_outputs=None,
past_key_values=None,
output_sequence=None,
spectrogram=None,
encoder_attention_mask=None,
):
use_cache = self.real_config.use_past and self.real_config.variant == "with-past"
if self.real_config._behavior == "encoder":
encoder_attention_mask = torch.ones_like(input_ids)
encoder_out = model.speecht5.encoder(
input_values=input_ids,
attention_mask=encoder_attention_mask,
return_dict=True,
)
# downsample encoder attention mask
if isinstance(model.speecht5.encoder, SpeechT5EncoderWithSpeechPrenet):
encoder_attention_mask = model.speecht5.encoder.prenet._get_feature_vector_attention_mask(
encoder_out[0].shape[1], encoder_attention_mask
)
result = {
"encoder_outputs": encoder_out.last_hidden_state,
"encoder_attention_mask": encoder_attention_mask,
}
elif self.real_config._behavior == "decoder":
# TODO: and self.real_config.use_past_in_inputs
encoder_hidden_states = encoder_outputs[0]
decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings)
# Run the decoder layers on the last element of the prenet output.
decoder_out = model.speecht5.decoder.wrapped_decoder(
hidden_states=decoder_hidden_states[:, -1:],
attention_mask=None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=False,
return_dict=True,
)
last_decoder_output = decoder_out.last_hidden_state[0, -1]
past_key_values = decoder_out.past_key_values
# Predict the new mel spectrum for this step in the sequence.
spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output)
spectrum = spectrum.view(model.config.reduction_factor, model.config.num_mel_bins)
# NOTE: extending the spectrogram is to be handled outside of the ONNX graph.
# spectrogram.append(spectrum)
# Extend the output sequence with the new mel spectrum.
output_sequence = torch.cat(
(output_sequence, spectrum[-1].view(1, 1, model.config.num_mel_bins)), dim=1
)
# Predict the probability that this is the stop token.
prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output))
result = {
"output_sequence_out": output_sequence,
"spectrum": spectrum,
"prob": prob,
"past_key_values": past_key_values,
}
elif self.real_config.is_postnet_and_vocoder:
# NOTE: the following concatenation is expected to be handled outside of the ONNX:
# spectrogram = torch.cat(spectrogram, dim=0).unsqueeze(0)
spectrogram = spectrogram.unsqueeze(0)
spectrogram = model.speech_decoder_postnet.postnet(spectrogram)
spectrogram = spectrogram.squeeze(0)
waveform = model.vocoder(spectrogram)
result = {"waveform": waveform}
else:
raise ValueError("Should not happen")
# Filter out the cross-attention past key values output by the decoder that uses the KV cache, as they are constants.
filtered_outputs = {}
for name, value in result.items():
if name != "past_key_values":
filtered_outputs[name] = value
else:
if self.real_config._behavior == "decoder" and (
self.real_config.is_merged or not self.real_config.use_past_in_inputs
):
filtered_outputs[name] = value
elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
# The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
filtered_outputs[name] = tuple([v[:2] for v in value])
return filtered_outputs
self.patched_forward = patched_forward
class SentenceTransformersTransformerPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
if (
is_transformers_version(">=", "4.42")
and is_transformers_version("<", "4.48")
and self.real_config._config.model_type == "mistral"
):
self._model[0].auto_model._update_causal_mask = types.MethodType(
_update_causal_mask_patched, self._model[0].auto_model
)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if (
is_transformers_version(">=", "4.42")
and is_transformers_version("<", "4.48")
and self.real_config._config.model_type == "mistral"
):
self._model[0].auto_model._update_causal_mask = types.MethodType(
self._update_causal_mask_original, self._model[0].auto_model
)
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)
if (
is_transformers_version(">=", "4.42")
and is_transformers_version("<", "4.48")
and self.real_config._config.model_type == "mistral"
):
self._update_causal_mask_original = self._model[0].auto_model._update_causal_mask
def patched_forward(input_ids, attention_mask):
result = self.orig_forward({"input_ids": input_ids, "attention_mask": attention_mask})
if "input_ids" in result:
del result["input_ids"]
if "attention_mask" in result:
del result["attention_mask"]
if "all_layer_embeddings" in result:
del result["all_layer_embeddings"]
return result
self.patched_forward = patched_forward
class SentenceTransformersCLIPPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)
def patched_forward(input_ids, attention_mask, pixel_values):
vision_outputs = model[0].model.vision_model(pixel_values=pixel_values)
image_embeds = model[0].model.visual_projection(vision_outputs[1])
text_outputs = model[0].model.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
)
text_embeds = model[0].model.text_projection(text_outputs[1])
if len(model) > 1:
image_embeds = model[1:](image_embeds)
text_embeds = model[1:](text_embeds)
return {"text_embeds": text_embeds, "image_embeds": image_embeds}
self.patched_forward = patched_forward
# Triu with a possibly dynamic `diagonal` argument, which torch.triu unfortunately does not support.
def triu_onnx(x, diagonal=0):
l, w = x.shape
arange_rows = torch.arange(l, device=x.device)
arange_cols = torch.arange(w, device=x.device)
mask = arange_cols.expand(l, w)
arange_rows = arange_rows[:, None] + diagonal
mask = mask >= arange_rows
return x.masked_fill(mask == 0, 0)
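# Illustrative sanity check (a minimal sketch; matches torch.triu for a static `diagonal`):
#
#     x = torch.ones(4, 6)
#     assert torch.equal(triu_onnx(x, diagonal=2), torch.triu(x, diagonal=2))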
def patched_build_delay_pattern_mask(self, input_ids: torch.Tensor, pad_token_id: int, max_length: int = None):
# (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
bsz, num_codebooks, seq_len = input_ids.shape
max_length = max_length if max_length is not None else self.generation_config.max_length
input_ids_shifted = torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks
# the delay pattern can only be applied if max_length is large enough; the original early-return fallback is not supported in the ONNX export
if max_length < 2 * channel_codebooks - 1:
raise NotImplementedError("Not supported in ONNX export. Please open an issue in Optimum repository.")
# fill the shifted ids with the prompt entries, offset by the codebook idx
for codebook in range(channel_codebooks):
if self.config.audio_channels == 1:
# mono channel - loop over the codebooks one-by-one
input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
else:
# left/right channels are interleaved in the generated codebooks, so handle one then the other
input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook]
input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1]
# construct a pattern mask that indicates the positions of padding tokens for each codebook
# first fill the upper triangular part (the EOS padding)
# NOTE: We could use torch.bool here, but PyTorch then complains with `The exported ONNX model failed ONNX shape inference.`
# Using int8 leads to `Could not find an implementation for Where`
delay_pattern = triu_onnx(
torch.ones((channel_codebooks, max_length), dtype=torch.int32), diagonal=max_length - channel_codebooks + 1
)
# NOTE: We could use torch.bool here, but PyTorch then complains with `The exported ONNX model failed ONNX shape inference.`
# Using int32 leads to `Could not find an implementation for Trilu`, hence int64 here
# then fill the lower triangular part (the BOS padding)
delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.int64))
delay_pattern = delay_pattern.to(torch.bool)
if self.config.audio_channels == 2:
# for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion
delay_pattern = delay_pattern.repeat_interleave(2, dim=0)
mask = ~delay_pattern.to(input_ids.device)
input_ids = mask * input_ids_shifted + ~mask * pad_token_id
# find the first position to start generating - this is the first place we have the -1 token
# and will always be in the first codebook (since it has no codebook offset)
first_codebook_ids = input_ids[:, 0, :]
start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
# TODO: Is this OK?
first_start_id = start_ids.min()
# (bsz, num_codebooks, seq_len) -> (bsz * num_codebooks, seq_len)
pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
input_ids_edited = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
return {"input_ids_edited": input_ids_edited, "delay_pattern_mask": pattern_mask}
class MusicgenModelPatcher(Seq2SeqModelPatcher):
def __enter__(self):
self.patch_ops()
if self.real_config.model_part == "build_delay_pattern_mask":
# For build_delay_pattern_mask, we need to override the signature too.
self._model.forward = types.MethodType(patched_build_delay_pattern_mask, self._model)
else:
setattr(self._model, self.orig_forward_name, self.patched_forward)
def __exit__(self, exc_type, exc_value, traceback):
self.restore_ops()
if self.real_config.model_part == "build_delay_pattern_mask":
self._model.forward = self.original_decoder_forward
else:
setattr(self._model, self.orig_forward_name, self.orig_forward)
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)
if config.model_part == "build_delay_pattern_mask":
self.original_decoder_forward = self.orig_forward
elif config.model_part == "encodec_decode":
# EncodecModel.forward -> EncodecModel.decode
@functools.wraps(self.orig_forward)
def patched_forward(
input_values: Optional["torch.Tensor"] = None,
padding_mask: Optional["torch.Tensor"] = None,
audio_codes: Optional["torch.Tensor"] = None,
bandwidth: Optional[float] = None,
audio_scales: Optional["torch.Tensor"] = None,
return_dict: Optional[bool] = None,
):
chunk_length = self.real_config._config.audio_encoder.chunk_length
if chunk_length is None:
if audio_scales is not None:
audio_scales = audio_scales[0]
if len(audio_codes) != 1:
raise ValueError(f"Expected one frame, got {len(audio_codes)}")
audio_values = self._model._decode_frame(audio_codes[0], audio_scales)
else:
raise ValueError("Not supported, a meaningful error should have been raised ahead.")
decoded_frames = []
for frame, scale in zip(audio_codes, audio_scales):
frames = self._model._decode_frame(frame, scale)
decoded_frames.append(frames)
audio_values = self._model._linear_overlap_add(decoded_frames, self.config.chunk_stride or 1)
# truncate based on padding mask
if padding_mask is not None and padding_mask.shape[-1] < audio_values.shape[-1]:
audio_values = audio_values[..., : padding_mask.shape[-1]]
return {"audio_values": audio_values}
self.patched_forward = patched_forward
def _update_causal_mask_patched(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values,
use_cache: bool,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self._attn_implementation == "flash_attention_2":
if attention_mask is not None and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
# cache_position must be valid here no matter which cache we use
past_seen_tokens = cache_position[0] if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache
if using_sliding_window_cache:
target_length = max(sequence_length, self.config.sliding_window)
# StaticCache
elif using_static_cache:
target_length = past_key_values.get_max_length()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if self.config.sliding_window is not None:
if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
# ---------------- NOTE: This part is patched -----------------------------
exclude_mask = torch.bitwise_or(
exclude_mask,
torch.arange(target_length, device=device)
<= (cache_position.reshape(-1, 1) - self.config.sliding_window),
)
# ---------------- NOTE: patch end ----------------------------------------
causal_mask *= exclude_mask
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
class MistralModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.42") and is_transformers_version("<", "4.48"):
if hasattr(self._model, "model"):
self._model.model._update_causal_mask = types.MethodType(
_update_causal_mask_patched, self._model.model
)
else:
self._model._update_causal_mask = types.MethodType(_update_causal_mask_patched, self._model)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if is_transformers_version(">=", "4.42") and is_transformers_version("<", "4.48"):
if hasattr(self._model, "model"):
self._model.model._update_causal_mask = types.MethodType(
self._update_causal_mask_original, self._model.model
)
else:
self._model._update_causal_mask = types.MethodType(self._update_causal_mask_original, self._model)
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)
if is_transformers_version(">=", "4.42") and is_transformers_version("<", "4.48"):
if hasattr(self._model, "model"):
self._update_causal_mask_original = self._model.model._update_causal_mask
else:
self._update_causal_mask_original = self._model._update_causal_mask
class CLIPModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.43") and is_transformers_version("<", "4.48"):
self.original_sdpa_forward = CLIPSdpaAttention.forward
CLIPSdpaAttention.forward = CLIPAttention.forward
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if is_transformers_version(">=", "4.43") and is_transformers_version("<", "4.48"):
CLIPSdpaAttention.forward = self.original_sdpa_forward
class VitPoseModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
# Set dataset_index (defaulting to COCO=0), otherwise we will get an error like:
# ValueError: dataset_index must be provided when using multiple experts (num_experts=6). Please provide dataset_index to the forward pass.
if model.config.backbone_config.num_experts > 1:
model_kwargs["dataset_index"] = torch.tensor(0, device=model.device)
super().__init__(config, model, model_kwargs)