# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import logging as log
import math
import types
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import PreTrainedModel, TFPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
from transformers.utils import is_tf_available
from optimum.exporters.onnx.base import OnnxConfig
from optimum.exporters.onnx.model_patcher import (
UNSUPPORTED_OPS_PATCHING_SPEC,
DecoderModelPatcher,
ModelPatcher,
PatchingSpec,
Seq2SeqModelPatcher,
override_arguments,
)
from optimum.intel.utils.import_utils import (
_openvino_version,
_torch_version,
_transformers_version,
is_diffusers_version,
is_openvino_version,
is_torch_version,
is_transformers_version,
)
if TYPE_CHECKING:
from transformers.cache_utils import Cache
from transformers.modeling_utils import PreTrainedModel
from optimum.exporters.onnx.config import OnnxConfig
if is_tf_available():
from transformers.modeling_tf_utils import TFPreTrainedModel
def ov_compatible_repeat_interleave(input_tensor, repeats, dim=None, output_size=None):
"""
Custom implementation of torch.repeat_interleave without using torch.repeat_interleave.
Args:
input_tensor (torch.Tensor): The input tensor.
repeats (int or torch.Tensor): The number of repetitions for each element.
dim (int, optional): The dimension along which to repeat. Defaults to None.
Returns:
torch.Tensor: The repeated tensor.
"""
result = torch.repeat_interleave(input_tensor, repeats=repeats, dim=dim)
return result
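# Illustrative sketch (kept in comments so the module has no import-time side effects):
# once a ModelPatcher applies UNSUPPORTED_OPS_PATCHING_SPEC, tensor.repeat_interleave
# dispatches through the wrapper above while keeping the usual eager semantics:
#     x = torch.tensor([1, 2, 3])
#     x.repeat_interleave(2)                         # tensor([1, 1, 2, 2, 3, 3])
#     x.repeat_interleave(torch.tensor([1, 0, 2]))   # tensor([1, 3, 3])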
def patch_unsupported_ops():
spec_idx = -1
for idx, spec in enumerate(UNSUPPORTED_OPS_PATCHING_SPEC):
if spec.name == "repeat_interleave":
spec_idx = idx
break
    repeat_interleave_spec = PatchingSpec(
        torch.Tensor, "repeat_interleave", ov_compatible_repeat_interleave, torch.Tensor.repeat_interleave
    )
    if spec_idx != -1:
        UNSUPPORTED_OPS_PATCHING_SPEC[spec_idx] = repeat_interleave_spec
    else:
        UNSUPPORTED_OPS_PATCHING_SPEC.append(repeat_interleave_spec)
BETTERTRANSFORMER_IGNORE = [
"codegen",
]
# gpt_neo has SDPA support as of transformers 4.45
if is_transformers_version(">=", "4.44.99"):
BETTERTRANSFORMER_IGNORE.append("gpt_neo")
patch_unsupported_ops()
def patch_model_with_bettertransformer(model):
COLOR_RED = "\033[1;31m"
COLOR_RESET = "\033[0m"
    # check that the model has not been patched yet
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
return model
if is_transformers_version("<", "4.36") or is_torch_version("<", "2.1.1"):
log.warning(
COLOR_RED
+ "[WARNING] For good performance with stateful models, transformers>=4.36.2 and PyTorch>=2.1.1 are required. "
f"This Python environment has Transformers {_transformers_version} and PyTorch {_torch_version}. "
"Consider upgrading PyTorch and Transformers, for example by running "
"`pip install --upgrade --upgrade-strategy eager optimum[openvino]`, and export the model again"
+ COLOR_RESET
)
if (
getattr(model.config, "model_type") in {"gpt_bigcode", "llama", "gemma"}
and is_transformers_version(">=", "4.38")
and is_openvino_version("<", "2024.1.0-14612")
):
# display commit-id only when a nightly/prerelease of OpenVINO is installed.
display_version = (
_openvino_version.split("-")[0] if is_openvino_version("<=", "2024.0.0-14509") else _openvino_version
)
log.warning(
COLOR_RED
+ f"[WARNING] Stateful models are not supported for Llama, Gemma and GPTBigCode with Transformers "
f"{_transformers_version} and OpenVINO {display_version}. For good performance, consider using a nightly OpenVINO build: "
"https://docs.openvino.ai/2024/get-started/install-openvino.html. For gpt-bigcode and llama models, "
"it is also an option to downgrade transformers: `pip install transformers==4.37.2`" + COLOR_RESET
)
# model already has required SDPA implementation
if getattr(model, "_supports_sdpa", False) and getattr(model.config, "_attn_implementation", "eager") == "sdpa":
return model
if model.config.model_type in BETTERTRANSFORMER_IGNORE:
return model
try:
model = model.to_bettertransformer()
except Exception as e:
        log.warning(
            f"Cannot apply model.to_bettertransformer because of the exception:\n{e}."
            " Using the model with stateful=True may be ineffective if the model does not use"
            " torch.nn.functional.scaled_dot_product_attention"
        )
return model
return model
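# Illustrative usage sketch (hypothetical model id, not executed here):
#     from transformers import AutoModelForCausalLM
#     model = AutoModelForCausalLM.from_pretrained("my-org/my-model")  # hypothetical id
#     model = patch_model_with_bettertransformer(model)
# The model is returned unchanged when it already runs SDPA or its model_type is in
# BETTERTRANSFORMER_IGNORE; otherwise model.to_bettertransformer() is attempted.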
def patch_update_causal_mask(
model, transformers_version, inner_model_name="model", patch_fn=None, patch_extrnal_model=False
):
if is_transformers_version(">=", transformers_version):
inner_model = getattr(model, inner_model_name, None) if not patch_extrnal_model else model
if inner_model is not None:
if hasattr(inner_model, "_update_causal_mask"):
inner_model._orig_update_causal_mask = inner_model._update_causal_mask
patch_fn = patch_fn or _llama_gemma_update_causal_mask
inner_model._update_causal_mask = types.MethodType(patch_fn, inner_model)
def unpatch_update_causal_mask(model, inner_model_name="model", patch_extrnal_model=False):
inner_model = getattr(model, inner_model_name, None) if not patch_extrnal_model else model
if inner_model is not None and hasattr(inner_model, "_orig_update_causal_mask"):
inner_model._update_causal_mask = inner_model._orig_update_causal_mask
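# Illustrative pairing sketch (assuming a llama-style model whose decoder lives in `.model`):
#     patch_update_causal_mask(model, "4.39.0")  # swaps in _llama_gemma_update_causal_mask
#     ...                                        # trace/export the model here
#     unpatch_update_causal_mask(model)          # restores the original method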
# initializing the cached sin/cos tables in bf16/fp16 leads to accuracy loss;
# reinitialize them in float32 before export
def _reinitialize_cos_sin_cached_fp32(rotary_emb):
if rotary_emb.cos_cached.dtype != torch.float32:
rotary_emb._set_cos_sin_cache(
seq_len=rotary_emb.max_position_embeddings, device=rotary_emb.inv_freq.device, dtype=torch.float32
)
def _mixtral_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
    # One hot encode the selected experts to create an expert mask
    # this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
if is_transformers_version("<", "4.37.0"):
current_hidden_states = expert_layer(current_state, routing_weights[top_x, idx, None])
else:
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
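# Illustrative sketch of the expert-mask indexing above (toy sizes, commented out):
# with num_experts=4, top_k=2 and batch*seq=3 tokens,
#     selected_experts = torch.tensor([[0, 2], [1, 2], [0, 3]])            # (tokens, top_k)
#     expert_mask = F.one_hot(selected_experts, num_classes=4).permute(2, 1, 0)
# expert_mask has shape (4, 2, 3); expert_mask[e] marks the (top_k slot, token) pairs
# routed to expert e, so torch.where(expert_mask[e]) returns exactly those indices.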
class MixtralModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
patch_update_causal_mask(self._model, "4.42.0")
for layer in self._model.model.layers:
layer.block_sparse_moe._unpatched_forward = layer.block_sparse_moe.forward
layer.block_sparse_moe.forward = types.MethodType(
_mixtral_sparse_moe_block_forward, layer.block_sparse_moe
)
if is_transformers_version("<", "4.44.99"):
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.model, "_orig_update_causal_mask"):
self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
for layer in self._model.model.layers:
layer.block_sparse_moe.forward = layer.block_sparse_moe._unpatched_forward
class ArcticModelPatcher(MixtralModelPatcher):
def __enter__(self):
        # the model initializes some weights for matrix multiplication in bfloat16, which leads to dtype inconsistency
try:
self._model.to(torch.float32)
except Exception:
pass
super().__enter__()
def _chatglm_transformer_forward(
self,
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size, seq_length = input_ids.shape
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)
if getattr(self, "pre_seq_len", None) is not None:
if past_key_values is None:
past_key_values = self.get_prompt(
batch_size=batch_size,
device=input_ids.device,
dtype=inputs_embeds.dtype,
)
if attention_mask is not None:
attention_mask = torch.cat(
[
attention_mask.new_ones((batch_size, self.pre_seq_len)),
attention_mask,
],
dim=-1,
)
if full_attention_mask is None:
if past_key_values is not None:
full_attention_mask = torch.ones(
batch_size,
seq_length,
seq_length,
device=input_ids.device,
dtype=torch.float,
) * float("-inf")
full_attention_mask.triu_(diagonal=1)
past_length = 0
if past_key_values:
past_length = past_key_values[0][0].shape[0]
if past_length:
full_attention_mask = torch.cat(
(
torch.zeros(batch_size, seq_length, past_length, device=input_ids.device),
full_attention_mask,
),
dim=-1,
)
full_attention_mask.unsqueeze_(1)
# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
if position_ids is not None:
rotary_pos_emb = rotary_pos_emb[position_ids]
else:
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
# Run encoder.
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds,
full_attention_mask,
rotary_pos_emb=rotary_pos_emb,
kv_caches=past_key_values,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
def _chatglm2_get_context_layer(query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor):
mask = torch.zeros((query_layer.shape[-2], key_layer.shape[-2]), dtype=query_layer.dtype)
if query_layer.shape[2] == key_layer.shape[2]:
tmp_mask = torch.ones((query_layer.shape[-2], key_layer.shape[-2]), dtype=torch.bool).triu(diagonal=1)
mask.masked_fill_(tmp_mask, float("-inf"))
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer, key_layer, value_layer, attn_mask=mask
)
return context_layer
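# Illustrative sketch of the mask built above for a 4-token prefill (query and key
# lengths are equal, so positions above the diagonal are filled with -inf):
#     [[0., -inf, -inf, -inf],
#      [0.,   0., -inf, -inf],
#      [0.,   0.,   0., -inf],
#      [0.,   0.,   0.,   0.]]
# During single-token decode the lengths differ, the mask stays all zeros and the
# query attends to the whole cached sequence.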
def _chatglm2_core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask):
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
if attention_mask is None:
context_layer = _chatglm2_get_context_layer(query_layer, key_layer, value_layer)
else:
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer, key_layer, value_layer, attention_mask
)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)
return context_layer
def _glm4_core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask):
causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32)
causal_mask.masked_fill_(attention_mask, float("-inf"))
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, causal_mask)
context_layer = context_layer.transpose(1, 2).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)
return context_layer
class ChatGLMModelPatcher(DecoderModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)
self.is_v4 = hasattr(self._model.config, "rope_ratio")
def __enter__(self):
super().__enter__()
if not self.is_v4:
self._model.transformer._orig_forward = self._model.transformer.forward
self._model.transformer.forward = types.MethodType(_chatglm_transformer_forward, self._model.transformer)
for block in self._model.transformer.encoder.layers:
block.self_attention.core_attention._orig_forward = block.self_attention.core_attention.forward
block.self_attention.core_attention.forward = types.MethodType(
_chatglm2_core_attention_forward if not self.is_v4 else _glm4_core_attention_forward,
block.self_attention.core_attention,
)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.transformer, "_orig_forward"):
self._model.transformer.forward = self._model.transformer._orig_forward
for block in self._model.transformer.encoder.layers:
block.self_attention.core_attention.forward = block.self_attention.core_attention._orig_forward
# adapted from
# https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965
# https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058
def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None:
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
# in order to dispatch on Flash Attention 2.
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
# difference with original modeling
    # using the minimum of a dtype with larger bandwidth (float32) may lead to overflow
    # during execution on platforms with default lower precision (bfloat16, float16)
min_dtype = torch.finfo(torch.float16).min
sequence_length = input_tensor.shape[1]
# difference with original modeling
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
if past_seen_tokens is not None:
current_length = past_seen_tokens + sequence_length + 1
            # TODO: remove after transformers >= v4.40.0 is supported
else:
current_length = cache_position[-1] + 1
target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length
# difference with original modeling
causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
offset = cache_position[0]
else:
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
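# Worked example of the mask construction above (illustrative): with
# sequence_length=3, target_length=5 and cache_position=[2, 3, 4], the
# triu(diagonal=1) fill followed by the arange > cache_position comparison yields
#     [[0, 0, 0, m, m],
#      [0, 0, 0, 0, m],
#      [0, 0, 0, 0, 0]]
# where m = torch.finfo(torch.float16).min, i.e. each query token sees the cache
# and itself but no future positions.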
# adapted from https://github.com/huggingface/transformers/blob/f4014e75db0190792b3feeccfc5dc5b5f9f0ce7b/src/transformers/models/llama/modeling_llama.py#L1036
def _llama_gemma_update_causal_mask_latest(
self,
attention_mask,
input_tensor,
cache_position,
past_key_values,
output_attentions,
):
from transformers.cache_utils import StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
# difference with original modeling
    # using the minimum of a dtype with larger bandwidth (float32) may lead to overflow
    # during execution on platforms with default lower precision (bfloat16, float16)
min_dtype = torch.finfo(torch.float16).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
# difference with original modeling
causal_mask = (
torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
# TODO: deprecate _llama_gemma_update_causal_mask_legacy when transformers>=4.41.0 is the minimum supported version
if is_transformers_version(">", "4.40.2"):
_llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_latest
else:
_llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_legacy
def llama_gemma_rotary_emb_forward(self, x, position_ids, seq_len=None):
    # adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L104
_seq_len = torch.max(position_ids) + 1 if seq_len is None else seq_len
if _seq_len > self.embed_positions.shape[0]:
if seq_len is None:
return self._orig_forward(x, position_ids)
else:
return self._orig_forward(x, position_ids, seq_len)
sincos = self.embed_positions[position_ids]
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
return cos, sin
def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000, inv_freq=None) -> torch.Tensor:
    # adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L101
if inv_freq is None:
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1)
return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)
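# Illustrative sketch: for num_pos=8, dim=4 the returned table has shape (8, 8);
# the first half of the last axis holds sin values and the second half cos values,
# matching the split performed in llama_gemma_rotary_emb_forward above:
#     table = create_sinusoidal_positions(8, 4)
#     sin, cos = torch.split(table, table.shape[-1] // 2, dim=-1)  # each (8, 4)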
def register_sin_cos_buffer(model):
max_positions = model.config.max_position_embeddings
    # cos/sin tables for rotary position embeddings also have accuracy issues in bf16
    # and are inefficient to recompute on each step, so use precomputed values
rotary_emb = model.model.layers[0].self_attn.rotary_emb
dim, base = None, None
inv_freq = getattr(rotary_emb, "inv_freq", None)
if inv_freq is None:
base = rotary_emb.base
dim = rotary_emb.dim
embed_positions = create_sinusoidal_positions(max_positions, dim, base, inv_freq)
for layer in model.model.layers:
layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions)
layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward
layer.self_attn.rotary_emb.forward = types.MethodType(
llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb
)
class LlamaModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
        # llama/gemma have accuracy issues in bf16 with transformers >= 4.39;
        # fill the causal mask in a slightly different way to avoid overflow on some platforms
patch_update_causal_mask(self._model, "4.39.0", "model" if hasattr(self._model, "model") else "transformer")
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
unpatch_update_causal_mask(self._model, "model" if hasattr(self._model, "model") else "transformer")
# copied from https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548 to unblock export with transformers 4.42
def _mistral_update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: "Cache",
use_cache: bool,
output_attentions: bool,
):
from transformers.cache_utils import SlidingWindowCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self._attn_implementation == "flash_attention_2":
if attention_mask is not None and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
# cache_position must be valid here no matter which cache we use
past_seen_tokens = cache_position[0] if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(torch.float16).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache
if using_sliding_window_cache:
target_length = max(sequence_length, self.config.sliding_window)
# StaticCache
elif using_static_cache:
target_length = past_key_values.get_max_length()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if self.config.sliding_window is not None:
if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
exclude_mask = exclude_mask.bitwise_or(
torch.arange(target_length, device=device)
<= (cache_position.reshape(-1, 1) - self.config.sliding_window)
)
causal_mask *= exclude_mask
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
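# Illustrative sketch of the sliding-window exclusion above: with sliding_window=3,
# target_length=6 and cache_position=[3, 4, 5], a key position j is masked for query
# position i when j > i (future) or j <= i - 3 (outside the window), so query 5
# attends only to keys 3, 4 and 5.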
class MistralModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
# apply fix https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548
self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
self._model.model._update_causal_mask = types.MethodType(_mistral_update_causal_mask, self._model.model)
else:
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "rotary_emb"):
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.model, "_orig_update_causal_mask"):
self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "rotary_emb") and hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
SUPPORT_SDPA = is_torch_version(">", "2.1.0")
def _qwen_rotate_half(x):
from einops import rearrange
x = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def _qwen_apply_rotary_pos_emb(t, freqs):
    cos, sin = freqs
    rot_dim = cos.shape[-1]
t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
t_ = t_.float()
t_pass_ = t_pass_.float()
t_ = (t_ * cos) + (_qwen_rotate_half(t_) * sin)
return torch.cat((t_, t_pass_), dim=-1).type_as(t)
def _qwen_quantize_cache_v(fdata, bits, qmax, qmin):
    # (b, s, head, h_dim) -> (b, head, s, h_dim)
qtype = torch.uint8
device = fdata.device
shape = fdata.shape
fdata_cal = torch.flatten(fdata, 2)
fmax = torch.amax(fdata_cal, dim=-1, keepdim=True)
fmin = torch.amin(fdata_cal, dim=-1, keepdim=True)
# Compute params
if qmax.device != fmax.device:
qmax = qmax.to(device)
qmin = qmin.to(device)
scale = (fmax - fmin) / (qmax - qmin)
zero = qmin - fmin / scale
scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
# Quantize
res_data = fdata / scale + zero
qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
return qdata.contiguous(), scale, zero
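# Illustrative round-trip sketch for the asymmetric uint8 quantization above: the
# helper returns (qdata, scale, zero) such that the cache is approximately recovered as
#     fdata ≈ (qdata.float() - zero) * scale
# with qmin/qmax typically torch.tensor(0.0) and torch.tensor(255.0) for 8-bit caches.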
def _qwen_attention_forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
):
mixed_x_layer = self.c_attn(hidden_states)
query, key, value = mixed_x_layer.split(self.split_size, dim=2)
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
if rotary_pos_emb_list is not None:
cur_len = query.shape[1]
if len(rotary_pos_emb_list) == 1:
rotary_pos_emb = rotary_pos_emb_list[0]
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
rotary_pos_emb = (rotary_pos_emb,) * 2
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
query = _qwen_apply_rotary_pos_emb(query, q_pos_emb)
key = _qwen_apply_rotary_pos_emb(key, k_pos_emb)
else:
query_list = []
key_list = []
for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
rotary_pos_emb = (rotary_pos_emb,) * 2
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
query_list += [_qwen_apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)]
key_list += [_qwen_apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)]
query = torch.cat(query_list, dim=0)
key = torch.cat(key_list, dim=0)
if self.use_cache_quantization:
key = _qwen_quantize_cache_v(key.permute(0, 2, 1, 3), bits=8, qmin=self.cache_qmin, qmax=self.cache_qmax)
value = _qwen_quantize_cache_v(value.permute(0, 2, 1, 3), bits=8, qmin=self.cache_qmin, qmax=self.cache_qmax)
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
if self.use_cache_quantization:
# use_cache_quantization:
# present=((q_key,key_scale,key_zero_point),
# (q_value,value_scale,value_zero_point))
key = (
torch.cat((past_key[0], key[0]), dim=2),
torch.cat((past_key[1], key[1]), dim=2),
torch.cat((past_key[2], key[2]), dim=2),
)
value = (
torch.cat((past_value[0], value[0]), dim=2),
torch.cat((past_value[1], value[1]), dim=2),
torch.cat((past_value[2], value[2]), dim=2),
)
else:
# not use_cache_quantization:
# present=(key,value)
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)
if use_cache:
present = (key, value)
else:
present = None
if self.use_logn_attn and not self.training:
if self.use_cache_quantization:
seq_start = key[0].size(2) - query.size(1)
seq_end = key[0].size(2)
else:
seq_start = key.size(1) - query.size(1)
seq_end = key.size(1)
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
query = query * logn_tensor.expand_as(query)
if self.use_flash_attn and not self.is_fp32 and query.is_cuda:
q, k, v = query, key, value
attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
else:
registered_causal_mask = torch.tril(
torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
).view(1, 1, key.size(1), key.size(1))
query = query.permute(0, 2, 1, 3)
if not self.use_cache_quantization:
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
if not self.use_cache_quantization and SUPPORT_SDPA:
            # For performance, use the constant tril buffer to generate the causal mask
causal_mask = self.bias[:, :, key.size(-2) - query.size(-2) : key.size(-2), : key.size(-2)]
if attention_mask is not None:
attention_mask = attention_mask.expand(-1, -1, query.size(2), -1).masked_fill(
~causal_mask, torch.finfo(query.dtype).min
)
else:
attention_mask = causal_mask
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
attn_weight = None
else:
attn_output, attn_weight = self._attn(query, key, value, registered_causal_mask, attention_mask, head_mask)
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(context_layer)
outputs = (attn_output, present)
if output_attentions:
if self.use_flash_attn and not self.is_fp32:
raise ValueError("Cannot output attentions while using flash-attn")
else:
outputs += (attn_weight,)
return outputs
class QwenModelPatcher(DecoderModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)
self.original_fp16 = model.config.fp16
self.original_bf16 = model.config.bf16
model.config.bf16 = False
model.config.fp16 = False
if self.original_fp16 or self.original_bf16:
            # GPTQ models do not support casting to a different dtype
try:
model.to(torch.float32)
except Exception:
pass
model.transformer.rotary_emb(2048)
def __enter__(self):
super().__enter__()
max_positions = self._model.config.seq_length
for block in self._model.transformer.h:
block.attn._orig_forward = block.attn.forward
            # For performance, register a constant tril buffer used to generate the causal mask
block.attn.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)
block.attn.forward = types.MethodType(_qwen_attention_forward, block.attn)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for block in self._model.transformer.h:
block.attn.forward = block.attn._orig_forward
self._model.config.bf16 = self.original_bf16
self._model.config.fp16 = self.original_fp16
def _baichuan13b_atten_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
proj = self.W_pack(hidden_states)
proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if past_key_value is not None:
# reuse k, v, self_attention
if attention_mask is not None:
attention_mask = attention_mask[:, :, -key_states.shape[-2] :, :]
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
if not output_attentions:
past_key_value = (key_states, value_states) if use_cache else None
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
attn_weights = None
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
if q_len == 1: # inference with cache
if len(attention_mask.size()) == 4:
attention_mask = attention_mask[:, :, -1:, :]
else:
attention_mask = attention_mask[:, -1:, :]
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
# Adapted from https://huggingface.co/baichuan-inc/Baichuan-7B/blob/262c8cb58b6d3615c208d9230baa869fddee2adb/modeling_baichuan.py#L181
def _baichuan7b_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
bsz, q_len, _ = hidden_states.size()
proj = self.W_pack(hidden_states)
proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
if not output_attentions:
attn_weights = None
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=attention_mask, scale=1 / math.sqrt(self.head_dim)
)
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
class BaichuanModelPatcher(DecoderModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)
        # the model initializes some buffers during the first inference; trigger it with a dummy forward pass
if hasattr(self._model.lm_head, "first_flag"):
self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64))
def __enter__(self):
super().__enter__()
        # override the signature to include position_ids
if "position_ids" not in inspect.signature(self._model.forward).parameters:
self._model._orig_forward = self._model.forward
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
position_ids: Optional[torch.LongTensor] = None,
):
return self._orig_forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=past_key_values is not None,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=self.config.return_dict,
)
self._model.forward = types.MethodType(forward, self._model)
for layer in self._model.model.layers:
layer.self_attn._orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_baichuan13b_atten_forward, layer.self_attn)
else:
for layer in self._model.model.layers:
layer.self_attn._orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_baichuan7b_attn_forward, layer.self_attn)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model, "_orig_forward"):
self._model.forward = self._model._orig_forward
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
def _mpt_sdpa_attention_forward(
self,
hidden_states: torch.Tensor,
position_bias: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
):
batch_size, seq_length = hidden_states.shape[:2]
mixed_qkv = self.Wqkv(hidden_states)
query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
if past_key_value is not None:
if len(past_key_value) != 0:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states)
else:
past_key_value = (key_states, value_states)
key_length = key_states.shape[-2]
query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]
attention_mask_sdpa = torch.ones(
(query_states.shape[0], query_states.shape[1], query_states.shape[2], key_states.shape[2]),
dtype=query_states.dtype,
)
if position_bias is not None:
position_bias_query_index = max(0, position_bias.size(1) - query_length)
position_bias_key_index = max(0, position_bias.size(2) - key_length)
position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:]
attention_mask_sdpa += position_bias
attention_mask_sdpa.masked_fill_(attention_mask, torch.finfo(query_states.dtype).min)
context_states = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask_sdpa,
dropout_p=self.attn_dropout_p,
scale=self.softmax_scale,
)
context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
attn_output = self.out_proj(context_states)
return attn_output, None, past_key_value
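# Illustrative sketch of the additive mask assembled above (shapes only, with
# hypothetical q/k lengths): SDPA receives one combined float mask instead of
# separate alibi bias and padding tensors, roughly:
#     base = torch.ones((b, n_heads, q, k), dtype=query_states.dtype)
#     base += position_bias[:, -q:, -k:]                              # alibi-style bias
#     base.masked_fill_(attention_mask, torch.finfo(base.dtype).min)  # padding/causal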
def _mpt_block_forward(
self,
hidden_states: torch.Tensor,
position_bias: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
# hidden_states: [batch_size, seq_length, hidden_size]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.norm_1(hidden_states)
residual = hidden_states
if not output_attentions:
# Self attention.
attn_outputs, attn_weights, past_key_value = self.attn(
layernorm_output,
position_bias=position_bias,
attention_mask=attention_mask,
past_key_value=layer_past,
)
else:
attn_outputs, attn_weights, past_key_value = self.attn._orig_forward(
layernorm_output,
position_bias=position_bias,
attention_mask=attention_mask,
past_key_value=layer_past,
)
hidden_states = self.resid_attn_dropout(attn_outputs) + residual
layernorm_output = self.norm_2(hidden_states)
# Get residual
residual = hidden_states
# MLP.
output = self.ffn(layernorm_output, residual)
outputs = (output,)
if use_cache:
outputs += (past_key_value,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class MPTModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if is_torch_version(">=", "2.1.0"):
for block in self._model.transformer.blocks:
block._orig_forward = block.forward
block.forward = types.MethodType(_mpt_block_forward, block)
block.attn._orig_forward = block.attn.forward
block.attn.forward = types.MethodType(_mpt_sdpa_attention_forward, block.attn)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for block in self._model.transformer.blocks:
if hasattr(block, "_orig_forward"):
block.forward = block._orig_forward
if hasattr(block.attn, "_orig_forward"):
block.attn.forward = block.attn._orig_forward
def _internlm2_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from einops import rearrange
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors."""
if position_ids is not None:
cos = cos[position_ids]
sin = sin[position_ids]
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
bsz, q_len, _ = hidden_states.size()
qkv_states = self.wqkv(hidden_states)
qkv_states = rearrange(
qkv_states,
"b q (h gs d) -> b q h gs d",
gs=2 + self.num_key_value_groups,
d=self.head_dim,
)
query_states = qkv_states[..., : self.num_key_value_groups, :]
query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
key_states = qkv_states[..., -2, :]
value_states = qkv_states[..., -1, :]
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
kv_seq_len = key_states.shape[-2]
is_legacy = not hasattr(self, "layer_idx")
if is_legacy:
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
else:
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": kwargs.get("cache_position")}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if not output_attentions:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
)
attn_weights = None
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.wo(attn_output)
return attn_output, attn_weights, past_key_value
class InternLM2Patcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if is_torch_version(">=", "2.1.0"):
for block in self._model.model.layers:
block.attention._orig_forward = block.attention.forward
block.attention.forward = types.MethodType(_internlm2_attention_forward, block.attention)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for block in self._model.model.layers:
if hasattr(block.attention, "_orig_forward"):
block.attention.forward = block.attention._orig_forward
def phi3_442_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
        layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
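# Illustrative sketch (not executed): the legacy-cache handling above converts the
# per-layer tuple cache into a DynamicCache and back, roughly:
#
#     cache = DynamicCache.from_legacy_cache(past_key_values)  # tuples -> Cache
#     ...                                   # layers update `cache` via cache.update(...)
#     next_cache = cache.to_legacy_cache()  # Cache -> tuples for legacy consumers
#
# which keeps the traced graph compatible with exporters expecting legacy tuples.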
# Adapted from https://github.com/huggingface/transformers/blob/ccdabc5642bf84849af93f591e207dc625c8e1e1/src/transformers/models/phi3/modeling_phi3.py#L729
def _phi3_self_attn_sdpa_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
return self._orig_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
if is_transformers_version(">=", "4.41.0"):
from transformers.models.phi3.modeling_phi3 import apply_rotary_pos_emb, repeat_kv
else:
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
bsz, q_len, _ = hidden_states.size()
qkv = self.qkv_proj(hidden_states)
query_pos = self.num_heads * self.head_dim
query_states = qkv[..., :query_pos]
key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
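# Shape walkthrough for the fused QKV split above (illustrative numbers only, assuming
# a hypothetical config with num_heads=32, num_key_value_heads=32, head_dim=96):
#
#     qkv        -> [bsz, q_len, (32 + 2 * 32) * 96]
#     query_pos  =  32 * 96 = 3072
#     query/key/value -> view + transpose -> [bsz, n_heads, q_len, 96]
#
# i.e. a single qkv_proj matmul replaces three separate q/k/v projections.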
class Phi3ModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
# currently, long RoPE cannot be traced for long context support; disable it to avoid potential accuracy issues
if self._model.config.max_position_embeddings != getattr(
self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
):
self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings
if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
self._model.model._orig_forward = self._model.model.forward
self._model.model.forward = types.MethodType(phi3_442_forward, self._model.model)
# https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113
# init inv_freq for torchscript tracing
# the issue is fixed for phi3 in transformers 4.48, but still visible with trust_remote_code=True (remote-code models have _supports_sdpa = False)
for layer in self._model.model.layers:
if (
is_torch_version(">=", "2.1.0")
and is_transformers_version("<", "4.48.0")
or not getattr(self._model, "_supports_sdpa", False)
):
orig_self_attn_fwd = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn)
layer.self_attn._orig_forward = orig_self_attn_fwd
if (
hasattr(layer.self_attn, "rotary_emb")
and getattr(layer.self_attn.rotary_emb, "inv_freq", None) is None
):
rotary_emb = layer.self_attn.rotary_emb
layer.self_attn.rotary_emb.inv_freq = 1.0 / (
rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.model, "_orig_forward"):
self._model.model.forward = self._model.model._orig_forward
if hasattr(self._model.model, "_orig_update_causal_mask"):
self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
# Modified from https://github.com/huggingface/transformers/blob/v4.50.2/src/transformers/models/phimoe/modeling_phimoe.py#L756
# removed the `continue` operator, which is not tracing-friendly
def _phi_moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
from transformers.models.phimoe.modeling_phimoe import sparsemixer
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.input_jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(
1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
)
hidden_states = hidden_states.view(-1, hidden_dim)
router_logits = self.gate(hidden_states)
routing_weights, selected_experts = sparsemixer(
router_logits,
jitter_eps=self.router_jitter_noise,
training=self.training,
)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
# Difference with original: removed the empty-batch early exit
# if top_x.shape[0] == 0:
#     continue
# (loop interruption depends on input data and may affect torchscript tracing)
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
# However `index_add_` only supports torch tensors for indexing, so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
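# Routing sketch (illustrative): with num_experts=4, top-2 routing and 3 tokens,
# `selected_experts` has shape [3, 2]; one_hot + permute(2, 1, 0) gives an
# `expert_mask` of shape [4, 2, 3], so `torch.where(expert_mask[e])` yields the
# (top-k slot, token index) pairs routed to expert `e`. Keeping the loop body
# unconditional means empty expert batches simply add zero rows via `index_add_`,
# which keeps the traced graph free of data-dependent control flow.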
class PhiMoEModelPatcher(Phi3ModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.model.layers:
layer.block_sparse_moe._orig_forward = layer.block_sparse_moe.forward
layer.block_sparse_moe.forward = types.MethodType(
_phi_moe_sparse_moe_block_forward, layer.block_sparse_moe
)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
layer.block_sparse_moe.forward = layer.block_sparse_moe._orig_forward
def _aquila_self_attn_sdpa_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
if output_attentions:
return self._orig_forward(
hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache
)
bsz, q_len, _ = hidden_states.size()
if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, getattr(self, "num_key_value_heads", self.num_heads), self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, getattr(self, "num_key_value_heads", self.num_heads), self.head_dim
).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
if hasattr(self, "num_key_value_groups"):
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
)
attn_weights = None
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
class AquilaModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.model.layers:
if is_torch_version(">=", "2.1.0"):
orig_self_attn_fwd = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_aquila_self_attn_sdpa_forward, layer.self_attn)
layer.self_attn._orig_forward = orig_self_attn_fwd
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
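# Usage sketch (illustrative; `export_config` and `aquila_model` are placeholder
# names for an OnnxConfig instance and a loaded remote-code Aquila model):
#
#     with AquilaModelPatcher(export_config, aquila_model):
#         ...  # trace/export here; self-attention runs through the SDPA path above
#
# __exit__ restores the original eager attention forwards.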
def _xverse_self_attn_sdpa_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
if output_attentions:
return self._orig_forward(
hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
)
attn_weights = None
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
def _internlm_self_attn_sdpa_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
if output_attentions:
return self._orig_forward(
hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
)
attn_weights = None
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
class XverseModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.model.layers:
if is_torch_version(">=", "2.1.0"):
orig_self_attn_fwd = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_xverse_self_attn_sdpa_forward, layer.self_attn)
layer.self_attn._orig_forward = orig_self_attn_fwd
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
class InternLMModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.model.layers:
if is_torch_version(">=", "2.1.0"):
orig_self_attn_fwd = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_internlm_self_attn_sdpa_forward, layer.self_attn)
layer.self_attn._orig_forward = orig_self_attn_fwd
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
# Adapted from https://github.com/huggingface/optimum/blob/3adbe7c75e3c41c1a3b945cf085e74ece7f8e192/optimum/bettertransformer/models/attention.py#L234
def codegen_wrapped_scaled_dot_product(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
):
batch_size = query.shape[0]
mask_value = torch.finfo(value.dtype).min
mask_value = torch.full([], mask_value, dtype=value.dtype)
# in codegen the query and key are always in fp32 regardless of the dtype of the model
# https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226
query = query.to(value.dtype)
key = key.to(value.dtype)
dropout_p = self.dropout_prob_attn if self.training else 0.0
if batch_size == 1 or self.training:
if query.shape[2] > 1:
# first step of the decoding
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
# in this case, which is the later decoding steps, the `causal_mask` in
# https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195
# is [True, ..., True] so actually not causal
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
else:
query_length, key_length = query.size(-2), key.size(-2)
# causal_mask is always [True, ..., True] otherwise, so executing this
# is unnecessary
if query_length > 1:
if not is_transformers_version(">", "4.44.99"):
causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(
torch.bool
)
causal_mask = torch.where(causal_mask, 0, mask_value)
# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
# we use torch.min to avoid having tensor(-inf)
attention_mask = torch.min(causal_mask, attention_mask)
else:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)
return sdpa_result, None
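# Dispatch sketch for the single-batch SDPA branch above (illustrative):
#     prefill (query.shape[2] > 1)  -> is_causal=True, no explicit mask
#     decode  (query.shape[2] == 1) -> is_causal=False, since the single query row
#                                      attends to the whole cache anyway
# while the batched branch builds an explicit additive mask from causal_mask
# and attention_mask instead.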
# copied from https://github.com/huggingface/optimum/blob/2112e99122d7f23a1da1a9d263fef64301050ea7/optimum/bettertransformer/models/attention.py#L168
# for preserving backward compatibility between outdated codegen remote code and newer transformers
def _codegen_wrapped_scaled_dot_product_legacy(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
):
if head_mask is not None:
raise ValueError("`head_mask` input argument is not supported")
batch_size = query.shape[0]
mask_value = torch.finfo(value.dtype).min
mask_value = torch.full([], mask_value, dtype=value.dtype)
if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, -1, -1] < -1:
raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")
# in codegen the query and key are always in fp32 regardless of the dtype of the model
# https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226
query = query.to(value.dtype)
key = key.to(value.dtype)
dropout_p = self.dropout_prob_attn if self.training else 0.0
if batch_size == 1 or self.training:
if query.shape[2] > 1:
# first step of the decoding
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
# in this case, which is the later decoding steps, the `causal_mask` in
# https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195
# is [True, ..., True] so actually not causal
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
else:
query_length, key_length = query.size(-2), key.size(-2)
# causal_mask is always [True, ..., True] otherwise, so executing this is unnecessary
if query_length > 1:
causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
causal_mask = torch.where(causal_mask, 0, mask_value)
# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
# we use torch.min to avoid having tensor(-inf)
attention_mask = torch.min(causal_mask, attention_mask)
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)
return sdpa_result, None
class CodeGenModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
attn_fn = codegen_wrapped_scaled_dot_product
if is_torch_version(">=", "2.1.0") and is_transformers_version(">=", "4.45"):
# in transformers 4.45 the causal_mask const buffer was removed from the model
# if it still exists, it means legacy remote code was loaded
if hasattr(self._model.transformer.h[0].attn, "causal_mask"):
attn_fn = _codegen_wrapped_scaled_dot_product_legacy
for layer in self._model.transformer.h:
if is_torch_version(">=", "2.1.0") and not self._model.config.output_attentions:
orig_self_attn_fwd = layer.attn._attn
layer.attn._attn = types.MethodType(attn_fn, layer.attn)
layer.attn._orig_attn = orig_self_attn_fwd
patch_update_causal_mask(self._model, "4.45.0", "transformer")
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
unpatch_update_causal_mask(self._model, "transformer")
for layer in self._model.transformer.h:
if hasattr(layer.attn, "_orig_attn"):
layer.attn._attn = layer.attn._orig_attn
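# Usage sketch (illustrative; names are placeholders): the patcher selects the
# legacy `_attn` replacement automatically when the loaded remote code still
# carries the `causal_mask` buffer that transformers 4.45 removed:
#
#     with CodeGenModelPatcher(export_config, codegen_model):
#         ...  # export with the SDPA-based attention patched in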
# Adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L763
def _dbrx_experts_forward(
self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor
):
bsz, q_len, hidden_size = x.shape
x = x.view(-1, hidden_size)
out = torch.zeros_like(x)
expert_mask = torch.nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
# Chunk experts at once to avoid storing full parameter multiple times in autograd
w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
self.moe_num_experts, dim=0
)
v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
self.moe_num_experts, dim=0
)
w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
self.moe_num_experts, dim=0
)
w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked]
v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
for expert_idx in range(0, self.moe_num_experts):
topk_idx, token_idx = torch.where(expert_mask[expert_idx])
# Difference with original: removed the empty-batch early exit
# if token_idx.shape[0] == 0:
# continue
# loop interruption depends on input data and may affect torchscript tracing
token_list = token_idx
topk_list = topk_idx
expert_tokens = x[None, token_list].reshape(-1, hidden_size)
expert_out = (
self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx])
* top_weights[token_list, topk_list, None]
)
out.index_add_(0, token_idx, expert_out)
out = out.reshape(bsz, q_len, hidden_size)
return out
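# Shape sketch for the expert dispatch above (illustrative, assuming a hypothetical
# config with moe_num_experts=16 and moe_top_k=4): `top_experts` is [bsz * q_len, 4],
# one_hot + permute(2, 1, 0) gives `expert_mask` of [16, 4, bsz * q_len], and each
# expert gathers its tokens via `torch.where(expert_mask[expert_idx])` before
# scattering the results back with `index_add_`. As noted above, the empty-batch
# `continue` is removed so the loop stays data-independent for tracing.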
# Adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L1228
def _dbrx_update_causal_mask_legacy(
self, attention_mask: Optional[torch.Tensor], input_tensor: torch.Tensor, cache_position: torch.Tensor
) -> Optional[torch.Tensor]:
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
dtype, device = input_tensor.dtype, input_tensor.device
# difference with original modeling
# using the minimum of a dtype with larger bandwidth (float32) may lead to overflow
# during execution on platforms with lower default precision (bfloat16, float16)
min_dtype = torch.finfo(torch.float16).min
sequence_length = input_tensor.shape[1]
if hasattr(self.blocks[0].norm_attn_norm.attn, "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
)
# difference with original modeling
# removed target_length = int(target_length).
# Casting to int leads to constant folding during tracing, which makes it impossible to use the model with sequences of different lengths
causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
offset = cache_position[0]
else:
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
):
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing = (
torch.jit.is_tracing()
or isinstance(input_tensor, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
if not is_tracing and torch.any(attention_mask != 1):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
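# Worked mask example (illustrative): for sequence_length=2, target_length=4 and
# cache_position=[2, 3] (two new tokens after two cached ones), the triu +
# cache_position construction above yields, with m = torch.finfo(torch.float16).min:
#
#     [[0, 0, 0, m],
#      [0, 0, 0, 0]]
#
# i.e. each new token attends to all cached tokens and to new tokens up to itself.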
# Adapted from https://github.com/huggingface/transformers/blob/1b3dba9417eebe16b7c206d1dfca6a4c7f11dbec/src/transformers/models/dbrx/modeling_dbrx.py#L1204
def _dbrx_update_causal_mask_latest(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: "Cache",
output_attentions: bool,
):
from transformers.cache_utils import StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
# difference with original modeling
# using the minimum of a dtype with larger bandwidth (float32) may lead to overflow
# during execution on platforms with lower default precision (bfloat16, float16)
min_dtype = torch.finfo(torch.float16).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
# difference with original modeling
causal_mask = (
torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
if is_transformers_version(">", "4.40.2"):
_dbrx_update_causal_mask = _dbrx_update_causal_mask_latest
else:
_dbrx_update_causal_mask = _dbrx_update_causal_mask_legacy
class DBRXModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
# dbrx has some accuracy issues with bf16 on transformers >= 4.40;
# fill the causal mask in a slightly different way to avoid overflow on some platforms
self._model.transformer._orig_update_causal_mask = self._model.transformer._update_causal_mask
self._model.transformer._update_causal_mask = types.MethodType(
_dbrx_update_causal_mask, self._model.transformer
)
# starting from transformers 4.41 the issue is also observable in the sin/cos calculation for rotary_emb
patch_rope_sin_cos = is_transformers_version(">=", "4.41.0")
inv_freq = getattr(self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb, "inv_freq")
dim, base = None, None
if inv_freq is None:
dim = self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb.dim
base = self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb.base
max_positions = self._model.config.max_seq_len
if patch_rope_sin_cos:
embed_positions = create_sinusoidal_positions(max_positions, dim, base, inv_freq)
for block in self._model.transformer.blocks:
rotary_emb = block.norm_attn_norm.attn.rotary_emb
# initialize inv_freq for torchscript tracing
if rotary_emb.inv_freq is None:
inv_freq = 1.0 / (
rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
)
rotary_emb.inv_freq = inv_freq
if patch_rope_sin_cos:
rotary_emb.register_buffer("embed_positions", embed_positions)
rotary_emb._orig_forward = rotary_emb.forward
rotary_emb.forward = types.MethodType(llama_gemma_rotary_emb_forward, rotary_emb)
# remove continue-operator from iteration loop over experts
block.ffn.experts._orig_forward = block.ffn.experts.forward
block.ffn.experts.forward = types.MethodType(_dbrx_experts_forward, block.ffn.experts)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.transformer._update_causal_mask = self._model.transformer._orig_update_causal_mask
for block in self._model.transformer.blocks:
block.ffn.experts.forward = block.ffn.experts._orig_forward
if hasattr(block.norm_attn_norm.attn.rotary_emb, "_orig_forward"):
block.norm_attn_norm.attn.rotary_emb.forward = block.norm_attn_norm.attn.rotary_emb._orig_forward
# Adapted from https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/persimmon/modeling_persimmon.py#L264
def _persimmon_self_attn_sdpa_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
from transformers.models.persimmon.modeling_persimmon import apply_rotary_pos_emb
if output_attentions:
return self._orig_forward(
hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache
)
bsz, q_len, _ = hidden_states.size()
# [batch_size, seq_length, 3 x hidden_size]
fused_qkv = self.query_key_value(hidden_states)
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_states, key_states, value_states) = self._split_heads(fused_qkv)
if self.qk_layernorm:
query_states = self.q_layernorm(query_states)
key_states = self.k_layernorm(key_states)
# [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim]
query_states = query_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
if is_transformers_version("<", "4.44.99"):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
else:
if position_embeddings is None:
log.warning(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
if is_transformers_version("<", "4.44.99"):
rotary_ndims = self.rotary_emb.dim
else:
rotary_ndims = self.rotary_ndims
# Partial rotary embedding
query_rot, query_pass = (
query_states[..., :rotary_ndims],
query_states[..., rotary_ndims:],
)
key_rot, key_pass = (
key_states[..., :rotary_ndims],
key_states[..., rotary_ndims:],
)
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
# [batch_size, seq_length, num_heads, head_dim]
query_states = torch.cat((query_rot, query_pass), dim=-1)
key_states = torch.cat((key_rot, key_pass), dim=-1)
if past_key_value is not None:
# Specific to RoPE models with partial rotation
cache_kwargs = {
"sin": sin,
"cos": cos,
"partial_rotation_size": rotary_ndims,
"cache_position": cache_position,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
causal_mask = attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
causal_mask,
scale=1 / math.sqrt(self.head_dim),
dropout_p=self.attention_dropout.p,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.dense(attn_output)
return attn_output, None, past_key_value
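# Partial-rotary sketch (illustrative, assuming head_dim=64 and
# partial_rotary_factor=0.5, so rotary_ndims=32): only the first 32 channels of
# each head receive RoPE,
#
#     query_rot, query_pass = query_states[..., :32], query_states[..., 32:]
#     query_states = torch.cat((rotated(query_rot), query_pass), dim=-1)
#
# (`rotated` stands in for apply_rotary_pos_emb), and `partial_rotation_size`
# tells the cache which slice was rotated.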
class PersimmonModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
patch_update_causal_mask(self._model, "4.42.0")
for layer in self._model.model.layers:
if is_torch_version(">=", "2.1.0"):
orig_self_attn_fwd = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_persimmon_self_attn_sdpa_forward, layer.self_attn)
layer.self_attn._orig_forward = orig_self_attn_fwd
if is_transformers_version("<", "4.44.99"):
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
unpatch_update_causal_mask(self._model)
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
def _jais_attn_forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
position_bias: Optional[torch.FloatTensor] = None,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `JAISAttention(..., is_cross_attention=True)`."
)
query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
present = (key, value)
else:
present = None
if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(
query, key, value, attention_mask, head_mask, position_bias
)
else:
# Difference with original: override the attention implementation with SDPA when output_attentions is False
if not output_attentions:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask, position_bias)
else:
attn_output, attn_weights = self._orig_attn(query, key, value, attention_mask, head_mask, position_bias)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs
def _jais_attn(self, query, key, value, attention_mask=None, head_mask=None, position_bias=None):
scale = 1.0
if self.scale_attn_weights:
scale = 1 / self.head_dim**self.attn_scale_power
# Layer-wise attention scaling
if self.scale_attn_by_inverse_layer_idx:
scale = scale / float(self.layer_idx + 1)
query_length = query.size(-2)
attention_mask_sdpa = torch.ones(
(query.shape[0], query.shape[1], query.shape[2], key.shape[2]),
dtype=query.dtype,
)
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(torch.float16).min
attention_mask_sdpa.masked_fill_(~causal_mask, mask_value)
if attention_mask is not None:
# Apply the attention mask
attention_mask_sdpa = attention_mask_sdpa + attention_mask
if position_bias is not None:
attention_mask_sdpa += position_bias.type_as(attention_mask_sdpa).unsqueeze(0)
# Mask heads if we want to
if head_mask is not None:
attention_mask_sdpa = attention_mask_sdpa * head_mask
attn_output = F.scaled_dot_product_attention(
query, key, value, attention_mask_sdpa, dropout_p=self.attn_dropout.p, scale=scale
)
return attn_output, None
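# Scale sketch (illustrative): JAIS parameterizes the attention scale as
#     scale = 1 / head_dim ** attn_scale_power
# so attn_scale_power=1.0 gives 1/head_dim (muP-style scaling) while 0.5 recovers
# the usual 1/sqrt(head_dim); with scale_attn_by_inverse_layer_idx the scale is
# further divided by (layer_idx + 1).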
class JaisModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.transformer.h:
if is_torch_version(">=", "2.1.0"):
orig_self_attn_fwd = layer.attn._attn
layer.attn._attn = types.MethodType(_jais_attn, layer.attn)
layer.attn._orig_attn = orig_self_attn_fwd
layer.attn._orig_forward = layer.attn.forward
layer.attn.forward = types.MethodType(_jais_attn_forward, layer.attn)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.transformer.h:
if hasattr(layer.attn, "_orig_attn"):
layer.attn._attn = layer.attn._orig_attn
layer.attn.forward = layer.attn._orig_forward
class UpdateCausalMaskModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
patch_update_causal_mask(self._model, "4.42.0")
if (
hasattr(self._model, "model")
and hasattr(self._model.model, "layers")
and hasattr(self._model.model.layers[0].self_attn, "rotary_emb")
and hasattr(self._model.model.layers[0].self_attn.rotary_emb, "_set_cos_sin_cache")
):
for layer in self._model.model.layers:
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
unpatch_update_causal_mask(self._model)
class RotaryEmbPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version("<", "4.44.99"):
for layer in self._model.model.layers:
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
# Adapted from https://github.com/huggingface/transformers/blob/31f9a289a6207be6cae746e009d8e0db523be203/src/transformers/models/falcon/modeling_falcon.py#L1138
def _falcon_prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
# difference from original: allow providing min_dtype as a parameter
min_dtype = torch.finfo(dtype).min if "min_dtype" not in kwargs else kwargs["min_dtype"]
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
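# Padding-merge sketch (illustrative): for a 2D attention_mask like [[1, 1, 0]],
# adding it to the causal mask makes positions where both terms are zero
# (causally visible but padded) sum to 0; those positions are then filled with
# min_dtype, folding causal and padding masking into one additive 4D mask.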
def _falcon_update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: "Cache",
output_attentions: bool,
head_mask: torch.Tensor,
alibi: torch.Tensor,
):
# copied from https://github.com/huggingface/transformers/blob/a30c865f991dfec9452cc64bd9a97bfbb96be036/src/transformers/models/falcon/modeling_falcon.py#L1130
from transformers.cache_utils import StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not using_static_cache
and not output_attentions
and head_mask is None
and alibi is None
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
# difference from original: replace torch.finfo(dtype).min with the float16 minimum to prevent overflow during fp16/bf16 execution
min_dtype = torch.finfo(torch.float16).min
batch_size, sequence_length, _ = input_tensor.shape
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = _falcon_prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
# We take care to integrate alibi bias in the causal_mask here
if head_mask is None and alibi is not None:
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
causal_mask = torch.masked_fill(
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
causal_mask < -1,
min_dtype,
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
class FalconModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version("<", "4.44.99"):
for layer in self._model.transformer.h:
_reinitialize_cos_sin_cached_fp32(layer.self_attention.rotary_emb)
else:
patch_update_causal_mask(self._model, "4.45.0", "transformer", _falcon_update_causal_mask)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
unpatch_update_causal_mask(self._model, "transformer")
class GptNeoxModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version("<", "4.44.99"):
for layer in self._model.gpt_neox.layers:
_reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb)
else:
patch_update_causal_mask(self._model, "4.45.0", "gpt_neox")
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
unpatch_update_causal_mask(self._model, "gpt_neox")
# Adapted from https://github.com/huggingface/optimum/blob/v1.24.0/optimum/bettertransformer/models/attention.py#L96
def _gptj_attn(self, query, key, value, attention_mask=None, head_mask=None):
if head_mask is not None:
return self._orig_attn(query, key, value, attention_mask, head_mask)
batch_size = query.shape[0]
mask_value = torch.finfo(value.dtype).min
mask_value = torch.full([], mask_value, dtype=value.dtype)
# in gpt-neo-x and gpt-j the query and keys are always in fp32
# thus we need to cast them to the value dtype
if getattr(self, "downcast_qk", False):
query = query.to(value.dtype)
key = key.to(value.dtype)
if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, -1, -1] < -1:
return self._orig_attn(query, key, value, attention_mask, head_mask)
dropout_p = self.dropout_prob_attn if self.training else 0.0
if batch_size == 1 or self.training:
if query.shape[2] > 1:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
else:
query_length, key_length = query.size(-2), key.size(-2)
# causal_mask is always [True, ..., True] otherwise, so executing this
# is unnecessary
if query_length > 1:
if not is_transformers_version(">=", "4.44.99"):
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
causal_mask = torch.where(causal_mask, 0, mask_value)
# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
if attention_mask is not None:
attention_mask = causal_mask + attention_mask
else:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)
# in gpt-neo-x and gpt-j the query and keys are always in fp32
# thus we need to cast them to the value dtype
if getattr(self, "downcast_qk", False):
sdpa_result = sdpa_result.to(value.dtype)
return sdpa_result, None
def gptj_attn_forward(
self,
hidden_states: torch.FloatTensor,
layer_past: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
):
if output_attentions:
self._attn = self._orig_attn
return self._orig_forward(
hidden_states,
layer_past,
attention_mask,
position_ids,
head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
class GptJModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
patch_update_causal_mask(self._model, "4.45.0", "transformer")
if is_transformers_version(">=", "4.49"):
self._model.config._orig_attn_implementation = self._model.config._attn_implementation
self._model.config._attn_implementation = "sdpa"
for block in self._model.transformer.h:
block.attn._orig_forward = block.attn.forward
block.attn.forward = types.MethodType(gptj_attn_forward, block.attn)
block.attn._orig_attn = block.attn._attn
block.attn._attn = types.MethodType(_gptj_attn, block.attn)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
unpatch_update_causal_mask(self._model, "transformer")
if is_transformers_version(">=", "4.49"):
self._model.config._attn_implementation = self._model.config._orig_attn_implementation
for block in self._model.transformer.h:
block.attn.forward = block.attn._orig_forward
block.attn._attn = block.attn._orig_attn
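# Usage sketch (illustrative; `export_config` / `gptj_model` are placeholders):
#
#     with GptJModelPatcher(export_config, gptj_model):
#         ...  # trace/export; attention runs through the SDPA path above
#
# On transformers >= 4.49 the patcher also pins config._attn_implementation to
# "sdpa" for the duration of the context and restores the original afterwards.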
class GptNeoxJapaneseModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version("<", "4.44.99"):
for layer in self._model.gpt_neox_japanese.layers:
_reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb)
else:
patch_update_causal_mask(self._model, "4.45.0", "gpt_neox_japanese")
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
unpatch_update_causal_mask(self._model, "gpt_neox_japanese")
# Adopted from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py#L721
def _bloom_attn_forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past=None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
):
from transformers.models.bloom.modeling_bloom import dropout_add
if head_mask is not None or output_attentions:
return self._orig_forward(
hidden_states,
residual,
alibi,
attention_mask,
layer_past=layer_past,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
batch_size, q_length, _ = hidden_states.shape
# [batch_size, seq_length, 3 x hidden_size]
fused_qkv = self.query_key_value(hidden_states)
# 3 x [batch_size, num_heads, seq_length, head_dim]
query_layer, key_layer, value_layer = self._reshape(fused_qkv)
if layer_past is not None:
cache_kwargs = {"cache_position": cache_position}
key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
if attention_mask is not None: # no matter the length, we just slice it
kv_length = cache_position[-1] + 1 # cache position is 0-indexed while length should start from 1
causal_mask = attention_mask[:, :, :, :kv_length]
alibi = torch.masked_fill(alibi, causal_mask.bool(), torch.finfo(alibi.dtype).min)
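# after masked_fill, alibi carries both the ALiBi positional bias and the causal/padding
# mask, so it can be passed directly as SDPA's attn_mask below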
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=alibi,
dropout_p=self.dropout_prob_attn if self.training else 0.0,
)
# Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim]
context_layer = context_layer.transpose(1, 2)
context_layer = context_layer.reshape(batch_size, q_length, self.hidden_size)
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
outputs = (output_tensor, layer_past)
return outputs
class BloomModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.49.0"):
self._model.config._orig_attn_implementation = self._model.config._attn_implementation
self._model.config._attn_implementation = "sdpa"
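# switching the config to "sdpa" (restored on exit) makes transformers' attention-mask
# preparation match what the patched forward below expects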
for block in self._model.transformer.h:
block.self_attention._orig_forward = block.self_attention.forward
block.self_attention.forward = types.MethodType(_bloom_attn_forward, block.self_attention)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if is_transformers_version(">=", "4.49.0"):
self._model.config._attn_implementation = self._model.config._orig_attn_implementation
for block in self._model.transformer.h:
block.self_attention.forward = block.self_attention._orig_forward
def _gpt_neo_attn_forward(
self,
hidden_states,
attention_mask=None,
layer_past=None,
head_mask=None,
use_cache=False,
output_attentions=False,
cache_position=None,
):
if output_attentions:
self._attn = self._orig_attn
return self._orig_forward(
hidden_states,
attention_mask=attention_mask,
layer_past=layer_past,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
# Adopted from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py#L185
def _gpt_neo_attn_sdpa(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
):
batch_size = query.shape[0]
mask_value = torch.finfo(torch.float16).min
mask_value = torch.full([], mask_value, dtype=value.dtype)
dropout_p = float(self.config.attention_dropout) if self.training else 0.0
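# GPT-Neo's eager attention does not scale scores by 1/sqrt(head_dim),
# hence scale=1.0 on the SDPA calls below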
if (batch_size == 1 or self.training) and self.attention_type == "global":
if query.shape[2] > 1:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=1.0
)
else:
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
causal_mask = torch.where(causal_mask, 0, mask_value)
if batch_size > 1:
# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
if attention_mask is not None:
attention_mask = causal_mask + attention_mask
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False, scale=1.0
)
return sdpa_result, None
class GptNeoModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.45.0") and is_torch_version(">=", "2.1.0"):
self._model.config._orig_attn_implementation = self._model.config._attn_implementation
self._model.config._attn_implementation = "sdpa"
for layer in self._model.transformer.h:
self_attn = layer.attn.attention
self_attn._orig_attn = self_attn._attn
self_attn._attn = types.MethodType(_gpt_neo_attn_sdpa, self_attn)
self_attn._orig_forward = self_attn.forward
self_attn.forward = types.MethodType(_gpt_neo_attn_forward, self_attn)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.config, "_orig_attn_implementation"):
self._model.config._attn_implementation = self._model.config._orig_attn_implementation
for layer in self._model.transformer.h:
layer.attn.attention.forward = layer.attn.attention._orig_forward
layer.attn.attention._attn = layer.attn.attention._orig_attn
class Gemma2ModelPatcher(LlamaModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)
@functools.wraps(self.orig_forward)
def patched_forward(*args, **kwargs):
from transformers.cache_utils import DynamicCache
signature = inspect.signature(self.orig_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
return_legacy_cache = False
pkv_in_args = False
legacy_pkv = None
if "past_key_values" in kwargs:
legacy_pkv = kwargs.pop("past_key_values", None)
sign_names = list(signature.parameters.keys())
pkv_argument_index = sign_names.index("past_key_values")
cache_position_index = sign_names.index("cache_position") if "cache_position" in sign_names else -1
input_ids_index = sign_names.index("input_ids" if "input_ids" in sign_names else "inputs_embeds")
if legacy_pkv is None and len(args) > pkv_argument_index:
legacy_pkv = args[pkv_argument_index]
pkv_in_args = True
if legacy_pkv is not None:
pkv = DynamicCache.from_legacy_cache(legacy_pkv)
return_legacy_cache = True
if not pkv_in_args:
kwargs["past_key_values"] = pkv
else:
args[pkv_argument_index] = pkv
if (
return_legacy_cache
and cache_position_index != -1
# cache_position was provided neither positionally nor as a keyword
and (cache_position_index >= len(args) and "cache_position" not in kwargs)
):
past_seen_tokens = legacy_pkv[0][0].shape[-2]
input_ids = args[input_ids_index] if "input_ids" not in kwargs else kwargs["input_ids"]
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + input_ids.shape[1], device=input_ids.device
)
kwargs["cache_position"] = cache_position
outputs = self.orig_forward(*args, **kwargs)
if return_legacy_cache:
outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
return outputs
self.patched_forward = patched_forward
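# Round-trip sketch of the cache conversion performed above (shapes assumed):
#     legacy = ((k_0, v_0), (k_1, v_1), ...)            # per-layer (key, value) pairs
#     cache = DynamicCache.from_legacy_cache(legacy)    # wrapped for the model call
#     legacy_again = cache.to_legacy_cache()            # unwrapped again on the way out
# keeping past_key_values in the legacy tuple format at the exported graph boundary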
def _decilm_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# DeciLM contains a bug in its attention calculation when past_key_values is not None
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
bsz, q_len, _ = hidden_states.size()
if self.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states, is_causal=attention_mask is None, attn_mask=attention_mask
)
# modified: the original implementation is missing .transpose(1, 2)
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
attn_weights = None
return attn_output, attn_weights, past_key_value
class DeciLMModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.model.layers:
layer.self_attn._orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_decilm_attn_forward, layer.self_attn)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
layer.self_attn.forward = layer.self_attn._orig_forward
class IBertModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)
if getattr(self._model, "ibert", None) is not None:
embeddings = self._model.ibert.embeddings
else:
embeddings = self._model.embeddings
# the model initializes some buffers during the first inference, which may break tracing
if getattr(embeddings.LayerNorm, "dim_sqrt", None) is None:
self._model(torch.ones([1, 1], dtype=torch.long))
class InternVLChatImageEmbeddingModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward
model.forward = model.extract_feature
if model.vision_model.encoder.layers[0].attn.use_flash_attn:
for layer in model.vision_model.encoder.layers:
layer.attn._orig_use_flash_attn = layer.attn.use_flash_attn
layer.attn.use_flash_attn = False
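# flash attention kernels are not traceable for export, so they are disabled here
# and restored from _orig_use_flash_attn on exit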
super().__init__(config, model, model_kwargs)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
if hasattr(self._model.vision_model.encoder.layers[0].attn, "_orig_use_flash_attn"):
for layer in self._model.vision_model.encoder.layers:
layer.attn.use_flash_attn = layer.attn._orig_use_flash_attn
class InternVL2ChatLangModelPatcher(DecoderModelPatcher):
def __init__(
self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any]
):
model_type = model.config.model_type
patcher_for_model_type = {
"llama": LlamaModelPatcher,
"qwen2": UpdateCausalMaskModelPatcher,
"phi3": Phi3ModelPatcher,
"internlm2": InternLM2Patcher,
}
self._internal_patcher = None
self._patched_forward = None
internal_patcher_cls = patcher_for_model_type.get(model_type)
if internal_patcher_cls is not None:
self._internal_patcher = internal_patcher_cls(config, model, model_kwargs)
self._patched_forward = self._internal_patcher.patched_forward
super().__init__(config, model, model_kwargs)
@property
def patched_forward(self):
if self._internal_patcher is not None:
return self._internal_patcher.patched_forward
return self._patched_forward
@patched_forward.setter
def patched_forward(self, fn):
self._patched_forward = fn
if self._internal_patcher is not None:
self._internal_patcher.patched_forward = fn
def __enter__(self):
if is_torch_version(">=", "2.1.0"):
if (
self._model.config.model_type in ["qwen2", "llama"]
and self._model.config._attn_implementation != "sdpa"
):
self._model.config._orig_attn_implementation = self._model.config._attn_implementation
self._model.config._attn_implementation = "sdpa"
if self._model.config.model_type == "qwen2" and is_transformers_version("<", "4.48"):
from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES
sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
for layer in self._model.model.layers:
layer.self_attn._orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
if self._model.config.model_type == "llama" and is_transformers_version("<", "4.47"):
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
sdpa_attn = LLAMA_ATTENTION_CLASSES["sdpa"]
for layer in self._model.model.layers:
layer.self_attn._orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
if self._internal_patcher is not None:
return self._internal_patcher.__enter__()
return super().__enter__()
def __exit__(self, exc_type, exc_value, traceback):
if self._internal_patcher:
self._internal_patcher.__exit__(exc_type, exc_value, traceback)
else:
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.config, "_orig_attn_implementation"):
self._model.config._attn_implementation = self._model.config._orig_attn_implementation
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
def llava_vision_embed_forward(self, pixel_values):
# copied from https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llava/modeling_llava.py#L428-L441
# these changes do not alter the original behavior; they only pack the model subcomponents' inference together,
# which lets us avoid memory overhead and code-level handling of their intermediate results
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# this is not memory efficient at all: output_hidden_states=True keeps all the hidden states
selected_image_feature = image_outputs.hidden_states[self.config.vision_feature_layer]
if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
def llava_next_video_vision_embed_forward(self, pixel_values):
# copied from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/llava_next_video/modeling_llava_next_video.py#L519
# these changes do not alter the original behavior; they only pack the model subcomponents' inference together,
# which lets us avoid memory overhead and code-level handling of their intermediate results
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
vision_feature_layer = self.config.vision_feature_layer
if isinstance(vision_feature_layer, int):
selected_image_feature = image_features.hidden_states[vision_feature_layer]
else:
hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_image_feature = torch.cat(hs_pool, dim=-1)
if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
return selected_image_feature
# Modified from https://huggingface.co/microsoft/maira-2/blob/main/modeling_maira2.py#L68
def maira_vision_embed_forward(self, pixel_values):
vision_feature_select_strategy = self.config.vision_feature_select_strategy
vision_feature_layer = self.config.vision_feature_layer
return self.get_image_features(pixel_values, vision_feature_layer, vision_feature_select_strategy)
class LlavaImageEmbeddingModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward
model.forward = types.MethodType(llava_vision_embed_forward, model)
super().__init__(config, model, model_kwargs)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
class MairaImageEmbeddingModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward
model.forward = types.MethodType(maira_vision_embed_forward, model)
super().__init__(config, model, model_kwargs)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
class LlavaNextVideoImageEmbeddingModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward
model.forward = types.MethodType(llava_next_video_vision_embed_forward, model)
super().__init__(config, model, model_kwargs)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
def _embednb_forward(self, ids: torch.Tensor) -> torch.Tensor:
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
out = pos.unsqueeze(-1) * omega.unsqueeze(0).unsqueeze(0)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()
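# each (position, frequency) pair yields a 2x2 rotation matrix, giving an output of
# shape (batch, seq, dim // 2, 2, 2) as expected by the downstream RoPE consumer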
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
class FluxTransfromerModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
if is_diffusers_version("<", "0.31.0"):
self._model.pos_embed._orig_forward = self._model.pos_embed.forward
self._model.pos_embed.forward = types.MethodType(_embednb_forward, self._model.pos_embed)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.pos_embed, "_orig_forward"):
self._model.pos_embed.forward = self._model.pos_embed._orig_forward
def _minicpmv_resampler_forward(self, image_feature, pos_embed, key_padding_mask):
bs = image_feature.shape[0]
image_feature = self.kv_proj(image_feature) # B * L * D
image_feature = self.ln_kv(image_feature).permute(1, 0, 2) # L * B * D
q = self.ln_q(self.query) # Q * D
q_bs = q.unsqueeze(1).repeat(1, bs, 1)
# attn(query: Q x B x D, key/value: L x B x D) -> out: Q x B x D
out = self.attn(q_bs, image_feature + pos_embed, image_feature, key_padding_mask=key_padding_mask)[0]
x = out.permute(1, 0, 2) # B * Q * D
x = self.ln_post(x)
x = x @ self.proj
return x
def _minicpmv_siglip_vis_embed_forward(
self,
pixel_values: torch.FloatTensor,
patch_attention_mask: torch.BoolTensor,
tgt_sizes: Optional[torch.IntTensor] = None,
position_ids: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
if position_ids is None:
batch_size = pixel_values.size(0)
max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
position_ids = torch.full(
size=(
batch_size,
max_nb_patches_h * max_nb_patches_w,
),
fill_value=0,
)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
if tgt_sizes is not None:
nb_patches_h = tgt_sizes[batch_idx][0]
nb_patches_w = tgt_sizes[batch_idx][1]
else:
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
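# bucketize maps fractional patch coordinates onto the fixed num_patches_per_side grid,
# so images with arbitrary aspect ratios reuse the square position-embedding table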
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings
def _minicpmv_siglip_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, is_causal=attention_mask is None
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None
def _minicpmv_siglip_transformer_forward(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
tgt_sizes: Optional[torch.IntTensor] = None,
position_ids: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size = pixel_values.size(0)
if patch_attention_mask is None:
patch_attention_mask = torch.ones(
size=(
batch_size,
pixel_values.size(2) // self.config.patch_size,
pixel_values.size(3) // self.config.patch_size,
),
dtype=torch.bool,
device=pixel_values.device,
)
hidden_states = self.embeddings(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
tgt_sizes=tgt_sizes,
position_ids=position_ids,
)
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
attention_mask = (
_prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
if not self._use_flash_attention_2
else patch_attention_mask
)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
if not return_dict:
return (last_hidden_state, None) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=None,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class MiniCPMVResamplerModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward
model.forward = types.MethodType(_minicpmv_resampler_forward, model)
super().__init__(config, model, model_kwargs)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
class MiniCPMVImageEmbeddingsModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward
model.forward = types.MethodType(_minicpmv_siglip_transformer_forward, model)
super().__init__(config, model, model_kwargs)
def __enter__(self):
super().__enter__()
self._model.embeddings._orig_forward = self._model.embeddings.forward
self._model.embeddings.forward = types.MethodType(_minicpmv_siglip_vis_embed_forward, self._model.embeddings)
if is_torch_version(">=", "2.0.0"):
for layer in self._model.encoder.layers:
layer.self_attn._orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_minicpmv_siglip_attn_forward, layer.self_attn)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
self._model.embeddings.forward = self._model.embeddings._orig_forward
if is_torch_version(">=", "2.0.0"):
for layer in self._model.encoder.layers:
layer.self_attn.forward = layer.self_attn._orig_forward
class LlavaQwen2ImageEmbeddingsModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward
model.forward = model.encode_images
super().__init__(config, model, model_kwargs)
if not self._model.get_vision_tower().is_loaded:
self._model.get_vision_tower().load_model()
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
class InputEmbeddingPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward
def forward(self, input):
return self.__orig_forward(input)
model.forward = types.MethodType(forward, model)
super().__init__(config, model, model_kwargs)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
def phi3_vision_embeddings_forward(self, pixel_values: torch.FloatTensor):
return self.get_img_features(pixel_values)
class Phi3VisionImageEmbeddingsPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward
model.forward = types.MethodType(phi3_vision_embeddings_forward, model)
super().__init__(config, model, model_kwargs)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
def minicpm3_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value=None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
orig_dtype = k.dtype
cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
q_fp32 = q.to(dtype=torch.float32, device=q.device)
k_fp32 = k.to(dtype=torch.float32, device=k.device)
q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
if output_attentions:
return self._orig_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
bsz, q_len, _ = hidden_states.shape
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(hidden_states.shape[0], hidden_states.shape[1], self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_pe = k_pe.view(hidden_states.shape[0], hidden_states.shape[1], 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(hidden_states.shape[0], hidden_states.shape[1], self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
# Difference with the original code: k_pe.new_empty creates a constant tensor in TorchScript, so torch.concat is used instead
query_states = torch.concat([q_nope, q_pe], dim=-1)
# query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
# query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
# key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
# key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(hidden_states.shape[0], hidden_states.shape[1], self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
class MiniCPM3Patcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for block in self._model.model.layers:
block.self_attn._orig_forward = block.self_attn.forward
block.self_attn.forward = types.MethodType(minicpm3_attn_forward, block.self_attn)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for block in self._model.model.layers:
block.self_attn.forward = block.self_attn._orig_forward
class DeepseekPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
self_attn = {
"deepseek_v3": deepseek_v3_attn_forward,
"deepseek_v2": deepseek_v2_attn_forward,
"deepseek": minicpm3_attn_forward,
}
self_attn_fwd = self_attn.get(self._model.config.model_type)
for block in self._model.model.layers:
if self_attn_fwd is not None:
block.self_attn._orig_forward = block.self_attn.forward
block.self_attn.forward = types.MethodType(self_attn_fwd, block.self_attn)
if hasattr(block.mlp, "moe_infer"):
block.mlp._orig_moe_infer = block.mlp.moe_infer
block.mlp.moe_infer = types.MethodType(deepseek_moe_infer, block.mlp)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for block in self._model.model.layers:
# attention is only patched for known model types, so restore it conditionally
if hasattr(block.self_attn, "_orig_forward"):
block.self_attn.forward = block.self_attn._orig_forward
if hasattr(block.mlp, "_orig_moe_infer"):
block.mlp.moe_infer = block.mlp._orig_moe_infer
def deepseek_v3_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value=None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# modified from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L751
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
orig_dtype = k.dtype
cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
q_fp32 = q.to(dtype=torch.float32, device=q.device)
k_fp32 = k.to(dtype=torch.float32, device=k.device)
q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
if output_attentions:
return self._orig_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
bsz, q_len, _ = hidden_states.size()
if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
# Difference with the original code: k_pe.new_empty creates a constant tensor in TorchScript, so torch.concat is used instead
query_states = torch.concat([q_nope, q_pe], dim=-1)
# query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
# query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
# key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
# key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def deepseek_v2_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value=None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# modified from https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py#L806
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
b, h, s, d = q.shape
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
b, h, s, d = k.shape
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
if output_attentions:
return self._orig_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
bsz, q_len, _ = hidden_states.shape
if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
# Difference with the original code: k_pe.new_empty creates a constant tensor in TorchScript, so torch.concat is used instead
query_states = torch.concat([q_nope, q_pe], dim=-1)
# query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
# query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
# key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
# key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def deepseek_moe_infer(self, x, topk_ids, topk_weight):
cnts = torch.zeros((topk_ids.shape[0], len(self.experts)))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0).to(torch.long)
idxs = torch.argsort(topk_ids.view(-1))
sorted_tokens = x[idxs // topk_ids.shape[1]]
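# argsort groups the flattened token->expert assignments so all tokens routed to the
# same expert are contiguous; idxs // topk_ids.shape[1] recovers each source token row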
outputs = []
start_idx = torch.tensor(0, dtype=torch.long)
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
# difference with the original: no longer skips an expert when its num_tokens is zero
expert_id = i + self.ep_rank * self.experts_per_rank
expert = self.experts[expert_id]
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = expert(tokens_for_this_expert)
outputs.append(expert_out)
start_idx = end_idx
# difference with the original: removed the torch.new_empty usage for the case when outputs is empty
outs = torch.cat(outputs, dim=0)
new_x = torch.zeros_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.to(topk_weight.dtype)
.mul_(topk_weight.unsqueeze(dim=-1))
.sum(dim=1)
.to(new_x.dtype)
)
return final_out
class Qwen2VLLanguageModelPatcher(DecoderModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
model.__orig_forward = model.forward
def forward_wrap(
self,
attention_mask,
position_ids=None,
past_key_values=None,
inputs_embeds=None,
input_ids=None,
use_cache=True,
):
from transformers.cache_utils import DynamicCache
new_past_key_values = DynamicCache.from_legacy_cache(past_key_values)
result = self.__orig_forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=new_past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
)
if past_key_values is not None:
result["past_key_values"] = result["past_key_values"].to_legacy_cache()
return result
model.forward = types.MethodType(forward_wrap, model)
super().__init__(config, model, model_kwargs)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
def patch_qwen2vl_vision_blocks(model, force_new_behaviour=False):
if not force_new_behaviour and is_transformers_version("<=", "4.48.99"):
# Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L390
# added an attention_mask input instead of the internal mask calculation (unsupported by tracing due to a loop with dynamic length)
def sdpa_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
rotary_pos_emb: torch.Tensor = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_rotary_pos_emb_vision
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if is_transformers_version(">=", "4.49"):
if position_embeddings is None:
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos().float()
sin = emb.sin().float()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
else:
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
# Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L430
# added attention_mask input propagation to self.attn
def block_forward(self, hidden_states, attention_mask, rotary_pos_emb) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states), attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
else:
# Modified from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L391
# added an attention_mask input instead of the internal mask calculation (unsupported by tracing due to a loop with dynamic length)
def sdpa_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
rotary_pos_emb: torch.Tensor = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
orig_q_dtype = q.dtype
orig_k_dtype = k.dtype
q, k = q.float(), k.float()
cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)
return q_embed, k_embed
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is None:
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos().float()
sin = emb.sin().float()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
# Modified from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L446
# added attention_mask input propagation to self.attn
def block_forward(
self,
hidden_states,
attention_mask,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
for block in model.blocks:
block._orig_forward = block.forward
block.forward = types.MethodType(block_forward, block)
block.attn._orig_forward = block.attn.forward
block.attn.forward = types.MethodType(sdpa_attn_forward, block.attn)
class Qwen2VLVisionEmbMergerPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
model.__orig_forward = model.forward
# Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1118
# added an attention_mask input instead of cu_lens and its internal mask calculation (unsupported by tracing due to a loop with dynamic length)
# separated the patch_embed and rot_pos_emb calls so they can run as part of another model
def image_embed_forward(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor
) -> torch.Tensor:
for blk in self.blocks:
hidden_states = blk(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)
return self.merger(hidden_states)
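# attention_mask and rotary_pos_emb are expected to be precomputed outside the exported
# graph (e.g. from cu_seqlens), keeping the traced graph free of dynamic-length loops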
model.forward = types.MethodType(image_embed_forward, model)
super().__init__(config, model, model_kwargs)
def __enter__(self):
patch_qwen2vl_vision_blocks(self._model)
super().__enter__()
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
for block in self._model.blocks:
block.forward = block._orig_forward
block.attn.forward = block.attn._orig_forward
class Qwen2_5_VLVisionEmbMergerPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any] = None,
):
model.__orig_forward = model.forward
# Modified from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L405
# added attention_mask and window_attention_mask inputs in place of the internal cu_lens and window_cu_lens processing
# (unsupported by tracing due to a loop over dynamic lengths)
# separated the patch_embed and rot_pos_emb calls so they can be performed as part of another model
def image_embed_forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
window_attention_mask: torch.Tensor,
window_index: torch.Tensor,
rotary_pos_emb: torch.Tensor,
) -> torch.Tensor:
seq_len = hidden_states.shape[0]
hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
hidden_states = hidden_states[window_index, :, :]
hidden_states = hidden_states.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
for layer_num, blk in enumerate(self.blocks):
if layer_num in self.fullatt_block_indexes:
attention_mask_now = attention_mask
else:
attention_mask_now = window_attention_mask
hidden_states = blk(
hidden_states, attention_mask=attention_mask_now, position_embeddings=position_embeddings
)
hidden_states = self.merger(hidden_states)
reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :]
return hidden_states
model.forward = types.MethodType(image_embed_forward, model)
super().__init__(config, model, model_kwargs)
def __enter__(self):
patch_qwen2vl_vision_blocks(self._model, force_new_behaviour=True)
super().__enter__()
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
for block in self._model.blocks:
block.forward = block._orig_forward
block.attn.forward = block.attn._orig_forward
# copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L321
def _granite_moe_topk_gating_forward(self, hidden_states):
# compute the top_k routing decision
logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts]
top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k]
top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k]
# compute number of input given to each expert
zeros = torch.zeros(
[top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
) # [num_tokens, num_experts]
gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts]
expert_size = gates.long().sum(0) # [num_experts,]
# difference from the original: removed `expert_size = expert_size.tolist()` because it traces incorrectly
# sort and group input tokens according to expert assignment
top_k_experts = top_k_indices.flatten() # [num_tokens * top_k]
_, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k]
batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k]
# gather the gate values for grouped input tokens
top_k_gates = top_k_gates.flatten() # [num_tokens * top_k]
batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k]
return index_sorted_experts, batch_index, batch_gates, expert_size, logits
# copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L281
def _granite_moe_parallel_experts_forward(self, inputs, expert_size):
output_list = []
# differences from the original:
# 1) after the gating patch, expert_size is a tensor rather than a list of ints, which rules out the original inputs.split(expert_size)
# 2) expert input splits are taken one by one via index_start:next_index slicing instead of being precomputed once before the loop
index_start = torch.tensor(0, dtype=torch.int64)
for i in range(self.num_experts):
next_index = index_start + expert_size[i]
output_list.append(F.linear(inputs[index_start:next_index], self.weight[i]))
index_start = next_index
results = torch.cat(output_list, dim=0)
return results
class GraniteMoEModelPatcher(LlamaModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.model.layers:
block_sparse_moe = layer.block_sparse_moe
block_sparse_moe.router._orig_forward = block_sparse_moe.router.forward
block_sparse_moe.router.forward = types.MethodType(
_granite_moe_topk_gating_forward, block_sparse_moe.router
)
block_sparse_moe.input_linear._orig_forward = block_sparse_moe.input_linear.forward
block_sparse_moe.input_linear.forward = types.MethodType(
_granite_moe_parallel_experts_forward, block_sparse_moe.input_linear
)
block_sparse_moe.output_linear._orig_forward = block_sparse_moe.output_linear.forward
block_sparse_moe.output_linear.forward = types.MethodType(
_granite_moe_parallel_experts_forward, block_sparse_moe.output_linear
)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
block_sparse_moe = layer.block_sparse_moe
block_sparse_moe.router.forward = block_sparse_moe.router._orig_forward
block_sparse_moe.input_linear.forward = block_sparse_moe.input_linear._orig_forward
block_sparse_moe.output_linear.forward = block_sparse_moe.output_linear._orig_forward
# copied from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L401
def gpt_bigcode_attn(self, query, key, value, attention_mask=None, head_mask=None):
if head_mask is not None:
# The super dispatch is done in the forward.
raise ValueError("PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository.")
scale = None
if not self.scale_attn_weights:
scale = 1
# MQA models: (batch_size, query_length, num_heads * head_dim)
# MHA models: (batch_size, num_heads, query_length, head_dim)
query_shape = query.shape
batch_size = query_shape[0]
if self.multi_query:
query_length = query_shape[1]
# SDPA requires the dimension [..., sequence_length, head_dim].
query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
# Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
key = key.unsqueeze(1)
value = value.unsqueeze(1)
# Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend
# and flash attention backend (No available kernel. Aborting execution.) from the shapes
# query = [batch_size, num_heads, query_length, head_dim]
# key = [batch_size, 1, past_length, head_dim]
# value = [batch_size, 1, past_length, head_dim]
#
# torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check.
if is_torch_version(">=", "2.2.0"):
key = key.expand(-1, self.num_heads, -1, -1)
value = value.expand(-1, self.num_heads, -1, -1)
else:
query_length = query_shape[-1]
# See the comment above.
if query.device.type == "cuda" and attention_mask is not None:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
# create a causal mask in case query_length == 1.
is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
# difference from the original: when model weights are loaded in their original format, the transformer.wte dtype may differ from the query dtype
if attention_mask is not None:
attention_mask = attention_mask.to(query.dtype)
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=self.attn_pdrop if self.training else 0.0,
is_causal=is_causal,
scale=scale,
)
if self.multi_query:
# (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
sdpa_result = sdpa_result.transpose(1, 2)
# Reshape is kind of expensive here, as it does a memory copy,
# but I did not manage to do away with it (logits do not match when using view)
# (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
sdpa_result = sdpa_result.reshape(query_shape)
return sdpa_result, None
class GptBigCodeModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
for layer in self._model.transformer.h:
layer.attn._orig_attn = layer.attn._attn
layer.attn._attn = types.MethodType(gpt_bigcode_attn, layer.attn)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
for layer in self._model.transformer.h:
layer.attn._attn = layer.attn._orig_attn
class StatefulSeq2SeqDecoderPatcher(Seq2SeqModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
model.__orig_forward = model.forward
@functools.wraps(model.__orig_forward)
def patched_forward(*args, **kwargs):
from transformers.cache_utils import EncoderDecoderCache
signature = inspect.signature(self.orig_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
return_legacy_cache = False
pkv_in_args = False
legacy_pkv = None
if "past_key_values" in kwargs:
legacy_pkv = kwargs.pop("past_key_values", None)
sign_names = list(signature.parameters.keys())
pkv_argument_index = sign_names.index("past_key_values")
if legacy_pkv is None and len(args) > pkv_argument_index:
legacy_pkv = args[pkv_argument_index]
pkv_in_args = True
if legacy_pkv is not None:
if isinstance(legacy_pkv, EncoderDecoderCache):
legacy_pkv = legacy_pkv.to_legacy_cache()
only_self_cache = [cache_item[:2] for cache_item in legacy_pkv]
pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache)
return_legacy_cache = True
if not pkv_in_args:
kwargs["past_key_values"] = pkv
else:
args[pkv_argument_index] = pkv
outputs = model.__orig_forward(*args, **kwargs)
if return_legacy_cache:
outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
return outputs
model.forward = patched_forward
super().__init__(config, model, model_kwargs)
class SanaTextEncoderModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
patch_update_causal_mask(self._model, "4.39.0", None, patch_extrnal_model=True)
if self._model.config._attn_implementation != "sdpa":
self._model.config._orig_attn_implementation = self._model.config._attn_implementation
self._model.config._attn_implementation = "sdpa"
if is_transformers_version("<", "4.47.0"):
from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
for layer in self._model.layers:
layer.self_attn._orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
unpatch_update_causal_mask(self._model, None, True)
if hasattr(self._model.config, "_orig_attn_implementation"):
self._model.config._attn_implementation = self._model.config._orig_attn_implementation
for layer in self._model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
class MiniCPMModelPatcher(DecoderModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
for layer in model.model.layers:
if hasattr(layer, "scale_depth"):
layer.self_attn.o_proj.to(torch.float32)
layer.mlp.down_proj.to(torch.float32)
super().__init__(config, model, model_kwargs)
class CommonImageEmbeddingsModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/got_ocr2/modeling_got_ocr2.py#L835
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1321
if hasattr(model, "model") and hasattr(model.model, "get_image_features"):
model.forward = model.model.get_image_features
else:
model.forward = model.get_image_features
super().__init__(config, model, model_kwargs)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1147
def _gemma3_mm_update_causal_mask(
self, attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training: bool = False
):
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted
# form and requires no inversion or slicing.
return attention_mask
min_dtype = torch.finfo(torch.float16).min
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else cache_position[0] + sequence_length + 1
)
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
)
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
# Apply bidirectional mask on images if token type ids are provided
if token_type_ids is not None and sequence_length != 1:
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False  # text tokens: do not change anything
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
causal_mask = causal_mask.clone()
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
token_type_mask, 0.0
)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
# Then apply padding mask (will mask pad tokens)
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
return causal_mask
class Gemma3LMModelPatcher(DecoderModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
model.__orig_forward = model.forward
if is_transformers_version("<", "4.52"):
model._update_causal_mask_mm = types.MethodType(_gemma3_mm_update_causal_mask, model)
else:
model.model._orig_update_causal_mask = model.model._update_causal_mask
model.model._update_causal_mask = types.MethodType(_gemma3_mm_update_causal_mask, model.model)
# differences from the original:
# uses a DynamicCache built from the legacy cache instead of HybridCache
# calculates the causal mask via the multimodal helper
def forward(
self, attention_mask, position_ids, past_key_values, token_type_ids, inputs_embeds, use_cache=True
):
from transformers.cache_utils import DynamicCache
pkv = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values[0][0].shape[-2]
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
forward_kwargs = {}
if is_transformers_version("<", "4.52"):
attention_mask = self._update_causal_mask_mm(
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds
)
else:
forward_kwargs["token_type_ids"] = token_type_ids
result = self.__orig_forward(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=pkv,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
**forward_kwargs,
)
upd_pkv = result["past_key_values"]
result["past_key_values"] = upd_pkv.to_legacy_cache()
return result
model.forward = types.MethodType(forward, model)
super().__init__(config, model, model_kwargs)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
if hasattr(self._model, "model") and hasattr(self._model.model, "_orig_update_causual_mask"):
self._model.model._update_causal_mask = self._model.model._orig_update_causual_mask
class Idefics3ImageEmbeddingsModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-SmolVLM-2/src/transformers/models/idefics3/modeling_idefics3.py#L999-L1005
def get_image_features(self, pixel_values, patch_attention_mask, patch_position_ids):
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
patch_position_ids=patch_position_ids,
).last_hidden_state
# Modality projection & resampling
image_hidden_states = self.connector(image_hidden_states)
return image_hidden_states
model.__orig_forward = model.forward
model.forward = types.MethodType(get_image_features, model)
super().__init__(config, model, model_kwargs)
def __enter__(self):
# The only difference from the original code is that patch_position_ids is taken as an input and propagated into the embeddings,
# instead of being calculated inside from patch_attention_mask: that calculation is not PyTorch-tracing friendly due to a loop over the batch size.
def transformer_forward(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
patch_position_ids: Optional[torch.IntTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import BaseModelOutput
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size = pixel_values.size(0)
if patch_attention_mask is None:
patch_size = self.patch_size
patch_attention_mask = torch.ones(
(
batch_size,
pixel_values.size(2) // patch_size,
pixel_values.size(3) // patch_size,
)
)
patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device)
hidden_states = self.embeddings(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
patch_position_ids=patch_position_ids,
)
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
# The call to `_upad_input` in `_flash_attention_forward` is expensive
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
# we avoid passing the attention_mask, which is equivalent to attending to the full sequence
if not torch.any(~patch_attention_mask):
patch_attention_mask = None
elif not self._use_flash_attention_2:
patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=patch_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
if not return_dict:
return (last_hidden_state,) + encoder_outputs[1:]
return BaseModelOutput(
last_hidden_state=last_hidden_state,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def embeddings_forward(
self,
pixel_values: torch.FloatTensor,
patch_attention_mask: torch.BoolTensor,
patch_position_ids: Optional[torch.IntTensor] = None,
) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
if patch_position_ids is None:
max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
else:
position_ids = patch_position_ids
position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings
def attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if output_attentions:
# the patched module stores the original forward as _orig_forward; zero-argument
# super() is unavailable in a function bound via types.MethodType
return self._orig_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
scale=self.scale,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None
self._model.vision_model._orig_forward = self._model.vision_model.forward
self._model.vision_model.forward = types.MethodType(transformer_forward, self._model.vision_model)
self._model.vision_model.embeddings._orig_forward = self._model.vision_model.embeddings.forward
self._model.vision_model.embeddings.forward = types.MethodType(
embeddings_forward, self._model.vision_model.embeddings
)
for layer in self._model.vision_model.encoder.layers:
layer.self_attn._orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(attn_forward, layer.self_attn)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
self._model.vision_model.forward = self._model.vision_model._orig_forward
self._model.vision_model.embeddings.forward = self._model.vision_model.embeddings._orig_forward
for layer in self._model.vision_model.encoder.layers:
layer.self_attn.forward = layer.self_attn._orig_forward
# Adopted from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/decoder_models.py#L367
def _blenderbot_attn_forward_legacy(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions or layer_head_mask is not None:
return self._orig_forward(
hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions
)
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if is_cross_attention and past_key_value is not None and past_key_value[0].shape[2] == key_value_states.shape[1]:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
query_states = self._shape(query_states, tgt_len, bsz)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=False,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
# Adopted from https://github.com/huggingface/transformers/blob/v4.52.3/src/transformers/models/blenderbot/modeling_blenderbot.py#L156
def _blenderbot_attn_forward_new(
self,
hidden_states: torch.Tensor,
key_value_states=None,
past_key_value=None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
from transformers.cache_utils import EncoderDecoderCache
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
if output_attentions or layer_head_mask is not None:
return self._orig_forward(
hidden_states,
key_value_states,
past_key_value,
attention_mask,
layer_head_mask,
output_attentions,
cache_position,
)
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
# initialize is_updated so the cross-attention reuse check below cannot raise a
# NameError when past_key_value is not an EncoderDecoderCache
is_updated = False
if past_key_value is not None:
if isinstance(past_key_value, EncoderDecoderCache):
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
curr_past_key_value = past_key_value.cross_attention_cache
else:
curr_past_key_value = past_key_value.self_attention_cache
else:
curr_past_key_value = past_key_value
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = curr_past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
proj_shape = (bsz, self.num_heads, -1, self.head_dim)
# difference from the original: dropped the `* self.scale` factor from the query reshape since scaling is applied inside SDPA
query_states = query_states.reshape(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)
# Difference with original, use SDPA instead of eager attention
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=False,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
if is_transformers_version(">=", "4.52"):
_blenderbot_attn_forward = _blenderbot_attn_forward_new
else:
_blenderbot_attn_forward = _blenderbot_attn_forward_legacy
def modulewise_patch(model, module_cls, patch_forward):
for _, module in model.named_children():
if isinstance(module, module_cls):
module._orig_forward = module.forward
module.forward = types.MethodType(patch_forward, module)
else:
if len(list(module.children())) > 0:
modulewise_patch(module, module_cls, patch_forward)
def modulewise_unpatch(model, module_cls):
for _, module in model.named_children():
if isinstance(module, module_cls):
if hasattr(module, "_orig_forward"):
module.forward = module._orig_forward
else:
if len(list(module.children())) > 0:
modulewise_unpatch(module, module_cls)
class BlenderbotModelPatcher(Seq2SeqModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.49.0"):
from transformers.models.blenderbot.modeling_blenderbot import BlenderbotAttention
modulewise_patch(self._model, BlenderbotAttention, _blenderbot_attn_forward)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if is_transformers_version(">=", "4.49.0"):
from transformers.models.blenderbot.modeling_blenderbot import BlenderbotAttention
modulewise_unpatch(self._model, BlenderbotAttention)
class BlenderbotSmallModelPatcher(Seq2SeqModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.49.0"):
from transformers.models.blenderbot_small.modeling_blenderbot_small import BlenderbotSmallAttention
modulewise_patch(self._model, BlenderbotSmallAttention, _blenderbot_attn_forward)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if is_transformers_version(">=", "4.49.0"):
from transformers.models.blenderbot_small.modeling_blenderbot_small import BlenderbotSmallAttention
modulewise_unpatch(self._model, BlenderbotSmallAttention)
class BlenderbotStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, BlenderbotModelPatcher):
pass
class BlenderbotSmallStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, BlenderbotSmallModelPatcher):
pass
class PegasusModelPatcher(Seq2SeqModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.49.0"):
from transformers.models.pegasus.modeling_pegasus import PegasusAttention
modulewise_patch(self._model, PegasusAttention, _blenderbot_attn_forward)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if is_transformers_version(">=", "4.49.0"):
from transformers.models.pegasus.modeling_pegasus import PegasusAttention
modulewise_unpatch(self._model, PegasusAttention)
# Copied from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py#L596
# No modifications; with transformers>=4.52.0 the new implementation of this method breaks tracing
def _qwen2moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
shared_expert_output = self.shared_expert(hidden_states)
shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
final_hidden_states = final_hidden_states + shared_expert_output
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
class Qwen2MoEPatcher(UpdateCausalMaskModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.52.0"):
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
modulewise_patch(self._model, Qwen2MoeSparseMoeBlock, _qwen2moe_sparse_block_forward)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if is_transformers_version(">=", "4.52.0"):
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
modulewise_unpatch(self._model, Qwen2MoeSparseMoeBlock)
class PegasusStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, PegasusModelPatcher):
pass
class MarianModelPatcher(Seq2SeqModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.49.0"):
from transformers.models.marian.modeling_marian import MarianAttention
modulewise_patch(self._model, MarianAttention, _blenderbot_attn_forward)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if is_transformers_version(">=", "4.49.0"):
from transformers.models.marian.modeling_marian import MarianAttention
modulewise_unpatch(self._model, MarianAttention)
class MarianStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, MarianModelPatcher):
pass
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/speecht5/modeling_speecht5.py#L698
# this patch avoids a PyTorch frontend issue where the speaker_embeddings input tensor
# and an intermediate tensor end up with the same tensor name
def speecht5_decoder_prenet_forward(
self,
input_values: torch.Tensor,
speaker_embeddings: Optional[torch.Tensor] = None,
):
inputs_embeds = input_values
for layer in self.layers:
inputs_embeds = torch.nn.functional.relu(layer(inputs_embeds))
inputs_embeds = self._consistent_dropout(inputs_embeds, self.config.speech_decoder_prenet_dropout)
inputs_embeds = self.final_layer(inputs_embeds)
inputs_embeds = self.encode_positions(inputs_embeds)
if speaker_embeddings is not None:
# patch for the PyTorch frontend issue described above:
# distinct variable names keep the input and intermediate tensors from sharing a tensor name in the traced model
speaker_embeddings_norm = torch.nn.functional.normalize(speaker_embeddings)
speaker_embeddings_unsqueeze = speaker_embeddings_norm.unsqueeze(1).expand(-1, inputs_embeds.size(1), -1)
inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings_unsqueeze], dim=-1)
inputs_embeds = torch.nn.functional.relu(self.speaker_embeds_layer(inputs_embeds))
return inputs_embeds
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/speecht5/modeling_speecht5.py#L993
# this patch works around a CPU plugin issue that appears on the 16th iteration of token generation:
# the values computed by the decoder self-attention, attn_output = torch.bmm(attn_probs, value_states), become incorrect
def speecht5_attention_forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
position_bias: Optional[torch.Tensor] = None,
output_attentions: bool = False,
serialize: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
# relative attention bias
if position_bias is not None:
reshape_q = query_states.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0, 1)
rel_pos_bias = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
rel_pos_bias = rel_pos_bias.transpose(0, 1).view(
bsz * self.num_heads, position_bias.size(0), position_bias.size(1)
)
attn_weights += rel_pos_bias
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
# patch to work around the CPU plugin issue described above:
# starting from the 16th token-generation iteration, the values computed by the decoder self-attention become incorrect,
# so a tiny epsilon is added to attn_probs before the bmm
eps = 1e-30
attn_output = torch.bmm(attn_probs + eps, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/speecht5/modeling_speecht5.py#L1175
# this patch avoids incorrect tracing:
# the cross-attention cached key/values, positions 3,4 of the present_key_value tuple, are computed from encoder_hidden_states
def speecht5_decoder_layer_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
serialize: bool = False,
):
residual = hidden_states
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
serialize=serialize,
)
hidden_states = self.dropout(hidden_states)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
# patch to avoid incorrect tracing:
# the cross-attention cached key/values (positions 3,4 of present_key_value)
# are computed from encoder_hidden_states, so reuse them only when they are actually present
if past_key_value is not None and len(past_key_value) > 3:
cross_attn_past_key_value = past_key_value[-2:]
else:
cross_attn_past_key_value = None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = self.dropout(hidden_states)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
# Fully Connected
hidden_states = hidden_states + self.feed_forward(hidden_states)
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
if use_cache:
outputs += (present_key_value,)
return outputs
class OVSpeechT5ModelPatcher(ModelPatcher):
def __enter__(self):
if self.real_config._behavior != "vocoder":
setattr(self._model, self.orig_forward_name, self.patched_forward)
if self.real_config._behavior == "decoder":
self._model.speecht5.decoder.prenet.__orig_forward = self._model.speecht5.decoder.prenet.forward
self._model.speecht5.decoder.prenet.forward = types.MethodType(
speecht5_decoder_prenet_forward, self._model.speecht5.decoder.prenet
)
for layer in self._model.speecht5.decoder.wrapped_decoder.layers:
layer.__orig_forward = layer.forward
layer.forward = types.MethodType(speecht5_decoder_layer_forward, layer)
layer.self_attn.__orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(speecht5_attention_forward, layer.self_attn)
def __exit__(self, exc_type, exc_value, traceback):
if self.real_config._behavior != "vocoder":
setattr(self._model, self.orig_forward_name, self.orig_forward)
if self.real_config._behavior == "decoder":
self._model.speecht5.decoder.prenet.forward = types.MethodType(
self._model.speecht5.decoder.prenet.__orig_forward, self._model.speecht5.decoder.prenet
)
for layer in self._model.speecht5.decoder.wrapped_decoder.layers:
layer.forward = types.MethodType(layer.__orig_forward, layer)
layer.self_attn.forward = types.MethodType(layer.self_attn.__orig_forward, layer.self_attn)
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)
def patched_encoder_forward(
input_ids: torch.FloatTensor = None,
):
encoder_attention_mask = torch.ones_like(input_ids)
hidden_states = self._model.prenet(input_ids)
encoder_out = self._model.wrapped_encoder(
hidden_states=hidden_states,
attention_mask=encoder_attention_mask,
return_dict=True,
)
# downsample encoder attention mask
if isinstance(model, SpeechT5EncoderWithSpeechPrenet):
encoder_attention_mask = model.prenet._get_feature_vector_attention_mask(
encoder_out[0].shape[1], encoder_attention_mask
)
result = {
"encoder_outputs": encoder_out.last_hidden_state,
"encoder_attention_mask": encoder_attention_mask,
}
return result
def patched_decoder_forward(
inputs_embeds=None,
speaker_embeddings=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
):
return_legacy_cache = False
if past_key_values is not None:
only_self_cache = [cache_item[:2] for cache_item in past_key_values]
past_key_values = only_self_cache
return_legacy_cache = True
output_sequence = inputs_embeds
output_cross_attentions = False
bsz = output_sequence.size(0)
# Run the decoder prenet on the entire output sequence.
decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings)
# Run the decoder layers on the last element of the prenet output.
decoder_out = model.speecht5.decoder.wrapped_decoder(
hidden_states=decoder_hidden_states[:, -1:],
attention_mask=None,
encoder_hidden_states=encoder_hidden_states[0],
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=True,
output_attentions=output_cross_attentions,
return_dict=True,
)
# if output_cross_attentions:
# cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0))
last_decoder_output = decoder_out.last_hidden_state.squeeze(1)
# Predict the new mel spectrum for this step in the sequence.
spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output)
spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins)
# Extend the output sequence with the new mel spectrum.
new_spectrogram = spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins)
output_sequence_out = torch.cat((output_sequence, new_spectrogram), dim=1)
# Predict the probability that this is the stop token.
prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output))
if return_legacy_cache:
only_self_cache = [cache_item[:2] for cache_item in decoder_out.past_key_values]
past_key_values = only_self_cache
result = {
"output_sequence_out": output_sequence_out,
"spectrum": spectrum,
"prob": prob,
"past_key_values": past_key_values,
}
return result
def patched_postnet_forward(raw_spectrogram: torch.FloatTensor):
raw_spectrogram = raw_spectrogram.transpose(0, 1).flatten(1, 2)
spectrogram = model.speech_decoder_postnet.postnet(raw_spectrogram)
result = {"postnet_spectrogram": spectrogram}
return result
def patched_vocoder_forward(spectrogram: torch.FloatTensor):
waveform = model(spectrogram)
result = {"waveform": waveform}
return result
if self.real_config._behavior == "encoder":
self.patched_forward = patched_encoder_forward
elif self.real_config._behavior == "decoder":
self.patched_forward = patched_decoder_forward
elif self.real_config._behavior == "postnet":
self.patched_forward = patched_postnet_forward
elif self.real_config._behavior == "vocoder":
self.patched_forward = patched_vocoder_forward
else:
raise ValueError("Unknown ")
self.orig_forward = self.patched_forward
class Phi4MMLanguageModelPatcher(DecoderModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
if hasattr(model.config, "vision_lora") and model.config.vision_lora is not None:
model.set_lora_adapter("vision")
if hasattr(model.config, "speech_lora") and model.config.speech_lora is not None:
model.set_lora_adapter("speech")
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L2156-L2178
# moved audio and vision features processing outside model
def lm_forward(self, inputs_embeds, attention_mask, position_ids, past_key_values, use_cache=True):
from transformers.cache_utils import DynamicCache
pkv = DynamicCache.from_legacy_cache(past_key_values)
outputs = self.model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=use_cache,
past_key_values=pkv,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states)
return (logits, outputs.past_key_values.to_legacy_cache())
model.__orig_forward = model.forward
model.forward = types.MethodType(lm_forward, model)
super().__init__(config, model, model_kwargs)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
class Phi4MMAudioForwardEmbeddingsPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L1121
def forward(self, audio_input):
if hasattr(self, "_forward_embeddings_code"):
audio_input, masks = self._forward_embeddings_core(audio_input, None)
else:
audio_input, masks = self.embed(audio_input, None)
return audio_input
model.__orig_forward = model.forward
model.forward = types.MethodType(forward, model)
super().__init__(config, model, model_kwargs)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
class Phi4MMAudioEncoderPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L1201-L1212
def forward(self, audio_feature, audio_mask):
if hasattr(self, "init_relative_attention_bias"):
relative_attention_bias = self.init_relative_attention_bias(audio_feature)
_simplified_path = self.extra_layer_output_idx == -1 and relative_attention_bias is None
if _simplified_path:
audio_feature, *_ = self.encoders(audio_feature, None, None, audio_mask)
else:
for layer in self.encoders:
audio_feature, _, _, _ = layer(
audio_feature,
None,
None,
audio_mask,
relative_attention_bias=relative_attention_bias,
)
else:
relative_attention_bias = self.relative_attention_bias_layer(audio_feature)
attention_mask = audio_mask.unsqueeze(1) + relative_attention_bias
for layer in self.encoders:
audio_feature = layer(audio_feature, attention_mask)
return audio_feature
model.__orig_forward = model.forward
model.forward = types.MethodType(forward, model)
super().__init__(config, model, model_kwargs)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
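# Shape sketch for the mask / bias fusion in the audio encoder forward above (sizes are
# assumptions, not read from the model): an additive (batch, seq, seq) audio mask gains
# a head axis via unsqueeze(1), then broadcasts against a per-head relative attention
# bias to produce one (batch, heads, seq, seq) additive attention mask.
def _sketch_audio_mask_broadcast():
    batch, heads, seq = 2, 4, 7
    audio_mask = torch.zeros(batch, seq, seq)
    relative_attention_bias = torch.zeros(1, heads, seq, seq)
    attention_mask = audio_mask.unsqueeze(1) + relative_attention_bias
    assert attention_mask.shape == (batch, heads, seq, seq)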
class Phi4MMVisionEmbeddingsPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
def get_img_features_legacy(
self, pixel_values: torch.FloatTensor, patch_attention_mask=None, patch_position_ids=None
) -> torch.FloatTensor:
LAYER_IDX = self.layer_idx
TYPE_FEATURE = self.type_feature
if self.freeze_img_processor:
with torch.no_grad():
if patch_attention_mask is not None:
img_processor_output = self.img_processor(
pixel_values,
output_hidden_states=True,
patch_attention_mask=patch_attention_mask,
position_ids=patch_position_ids,
)
else:
img_processor_output = self.img_processor(
pixel_values, output_hidden_states=True, position_ids=patch_position_ids
)
img_feature = img_processor_output.hidden_states[LAYER_IDX]
else:
if patch_attention_mask is not None:
img_processor_output = self.img_processor(
pixel_values,
output_hidden_states=True,
patch_attention_mask=patch_attention_mask,
position_ids=patch_position_ids,
)
else:
img_processor_output = self.img_processor(
pixel_values, output_hidden_states=True, position_ids=patch_position_ids
)
img_feature = img_processor_output.hidden_states[LAYER_IDX]
if TYPE_FEATURE == "patch":
patch_feature = img_feature
if self.image_token_compression is not None:
# reshape to 2D tensor
width = int(math.sqrt(patch_feature.size(1)))
patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1))
# convert to NCHW
patch_feature = patch_feature.permute(0, 3, 1, 2)
if getattr(self, "img_processor_padding", None) is not None:
patch_feature = self.img_processor_padding(patch_feature)
patch_feature = self.image_token_compression(patch_feature)
# convert to NHWC
patch_feature = patch_feature.permute(0, 2, 3, 1)
patch_feature = patch_feature.view(
-1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1)
)
elif getattr(self, "img_processor_padding", None) is not None:
width = int(math.sqrt(patch_feature.size(1)))
patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1))
# convert to NCHW
patch_feature = patch_feature.permute(0, 3, 1, 2)
patch_feature = self.img_processor_padding(patch_feature)
# convert to NHWC
patch_feature = patch_feature.permute(0, 2, 3, 1)
patch_feature = patch_feature.view(
-1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1)
)
return patch_feature
if TYPE_FEATURE == "cls_patch":
if self.image_token_compression is not None:
# reshape to 2D tensor
patch_feature = img_feature[:, 1:]
cls_feature = img_feature[:, 0]
width = int(math.sqrt(patch_feature.size(1)))
patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1))
patch_feature = self.image_token_compression(patch_feature)
patch_feature = patch_feature.view(-1, patch_feature.size(-2) * patch_feature.size(-1))
img_feature = torch.cat([cls_feature, patch_feature], dim=1)
return img_feature
# Adapted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L649
# added the possibility to provide patch_position_ids
def get_img_features(
self, pixel_values: torch.FloatTensor, patch_attention_mask=None, patch_position_ids=None
):
img_processor_output = self.img_processor(
pixel_values,
patch_attention_mask=patch_attention_mask,
output_hidden_states=True,
position_ids=patch_position_ids,
)
img_feature = img_processor_output.hidden_states[self.layer_idx]
patch_feature = img_feature
# reshape to 2D tensor
width = int(math.sqrt(patch_feature.size(1)))
patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1))
# convert to NCHW
patch_feature = patch_feature.permute(0, 3, 1, 2)
if getattr(self, "img_processor_padding", None) is not None:
patch_feature = self.img_processor_padding(patch_feature)
patch_feature = self.image_token_compression(patch_feature)
# convert to NHWC
patch_feature = patch_feature.permute(0, 2, 3, 1)
patch_feature = patch_feature.view(
-1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1)
)
return patch_feature
model.__orig_forward = model.forward
model.forward = types.MethodType(
get_img_features_legacy if hasattr(model, "type_feature") else get_img_features, model
)
super().__init__(config, model, model_kwargs)
def __enter__(self):
super().__enter__()
# Adapted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L563
# added the possibility to calculate position_ids outside the model
def transformer_fwd(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
position_ids: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size = pixel_values.size(0)
if patch_attention_mask is None:
patch_attention_mask = torch.ones(
size=(
batch_size,
pixel_values.size(2) // self.config.patch_size,
pixel_values.size(3) // self.config.patch_size,
),
dtype=torch.bool,
device=pixel_values.device,
)
hidden_states = self.embeddings(
pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, position_ids=position_ids
)
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
# The call to `_upad_input` in `_flash_attention_forward` is expensive
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
if not torch.any(~patch_attention_mask):
attention_mask = None
else:
attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
pooled_output = self.head(
hidden_state=last_hidden_state,
attention_mask=patch_attention_mask,
)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
# Adapted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L76
# uses SDPA instead of eager attention
def attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# fall back to the original forward saved as _orig_forward before patching;
# zero-argument super() does not work in a function bound via types.MethodType
return self._orig_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# the vision encoder attends bidirectionally, so the causal SDPA fast path never applies
is_causal = False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None
# Adapted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L488
# moved the position_ids calculation outside of the model
def embd_forward(
self,
pixel_values: torch.FloatTensor,
patch_attention_mask: torch.BoolTensor,
position_ids: torch.FloatTensor = None,
) -> torch.Tensor:
batch_size = pixel_values.size(0)
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
if position_ids is None:
max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
position_ids = torch.full(
size=(
batch_size,
max_nb_patches_h * max_nb_patches_w,
),
fill_value=0,
)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings
if (
getattr(self._model.img_processor.encoder.layers[0].self_attn.config, "_attn_implementation", "eager")
!= "sdpa"
):
for layer in self._model.img_processor.encoder.layers:
layer.self_attn._orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(attn_forward, layer.self_attn)
self._model.img_processor._orig_forward = self._model.img_processor.forward
self._model.img_processor.forward = types.MethodType(transformer_fwd, self._model.img_processor)
self._model.img_processor.embeddings._orig_forward = self._model.img_processor.embeddings.forward
self._model.img_processor.embeddings.forward = types.MethodType(
embd_forward, self._model.img_processor.embeddings
)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
for layer in self._model.img_processor.encoder.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
self._model.img_processor.forward = self._model.img_processor._orig_forward
self._model.img_processor.embeddings.forward = self._model.img_processor.embeddings._orig_forward
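# Sketch of the bucketized position ids built in embd_forward above: fractional patch
# coordinates are binned onto a fixed grid of num_patches_per_side buckets, so an image
# with fewer actual patches still indexes the full learned position table (values below
# are hypothetical).
def _sketch_patch_position_ids():
    num_patches_per_side = 4
    nb_patches = 2  # actual patches along one axis for this image
    boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
    fractional_coords = torch.arange(0, 1 - 1e-6, 1 / nb_patches)  # tensor([0.0, 0.5])
    buckets = torch.bucketize(fractional_coords, boundaries, right=True)
    assert buckets.tolist() == [0, 2]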
class Llama4ImageEmbeddingsModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward
# Adapted from https://github.com/huggingface/transformers/blob/v4.51.0/src/transformers/models/llama4/modeling_llama4.py#L1732-L1741
def get_image_embeddings(self, pixel_values):
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=self.config.vision_config.vision_feature_layer,
vision_feature_select_strategy=self.config.vision_config.vision_feature_select_strategy,
)
vision_flat = image_features.view(-1, image_features.size(-1))
projected_vision_flat = self.multi_modal_projector(vision_flat)
return projected_vision_flat
model.forward = types.MethodType(get_image_embeddings, model)
super().__init__(config, model, model_kwargs)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
# modified from https://github.com/huggingface/transformers/blob/v4.51.0/src/transformers/models/llama4/modeling_llama4.py#L229
# use real cos / sin instead of complex
def llama4_rope_forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
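# Shape sketch for the patched rope above (hypothetical sizes): inv_freq holds
# head_dim // 2 frequencies, freqs is (batch, seq, head_dim // 2), and concatenating it
# with itself yields cos / sin of shape (batch, seq, head_dim) in the "first half /
# second half" layout that rotate_half expects, replacing the complex-valued original.
def _sketch_rope_shapes():
    batch, seq, head_dim = 1, 6, 8
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seq).float(), inv_freq)[None]  # (1, seq, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    assert emb.cos().shape == (batch, seq, head_dim)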
# use real cos / sin instead of complex
# Modified from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/llama4/modeling_llama4.py#L247
# Based on https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py#L292
# the native DeepSeek apply-rotary-emb works the same way as the llama4 one
def llama4_apply_rotary_emb(
xq: torch.Tensor, xk: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
from transformers.models.llama.modeling_llama import rotate_half
xq_ = xq.float()
xk_ = xk.float()
cos = cos.unsqueeze(2)
sin = sin.unsqueeze(2)
b, h, s, d = xq_.shape
xq_ = xq_.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
b, h, s, d = xk_.shape
xk_ = xk_.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
q_embed = (xq_ * cos) + (rotate_half(xq_) * sin)
k_embed = (xk_ * cos) + (rotate_half(xk_) * sin)
return q_embed.type_as(xq), k_embed.type_as(xk)
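# The view / transpose / reshape above de-interleaves the last dimension: the original
# llama4 rope stores (real, imag) pairs interleaved, while rotate_half expects all
# "real" components in the first half and all "imag" components in the second. Tiny
# demonstration on a 4-element head dimension:
def _sketch_deinterleave():
    x = torch.tensor([0.0, 1.0, 2.0, 3.0]).view(1, 1, 1, 4)
    b, h, s, d = x.shape
    y = x.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
    assert y.flatten().tolist() == [0.0, 2.0, 1.0, 3.0]  # evens first, then odds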
# https://github.com/huggingface/transformers/blob/v4.51.0/src/transformers/models/llama4/modeling_llama4.py#L329
# use real cos / sin instead of complex
def llama4_attn_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value=None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
from transformers.models.llama4.modeling_llama4 import ALL_ATTENTION_FUNCTIONS, eager_attention_forward
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape)
key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
if self.use_rope: # the 16E model skips rope for long context on certain layers
cos, sin = position_embeddings[0], position_embeddings[1]
query_states, key_states = llama4_apply_rotary_emb(
query_states, key_states, cos.to(query_states.device), sin.to(query_states.device)
)
if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm
query_states = self.qk_norm(query_states)
key_states = self.qk_norm(key_states)
# apply temperature tuning (https://arxiv.org/abs/2501.19399) to the no-RoPE layers
if self.attn_temperature_tuning and not self.use_rope:
attn_scales = (
torch.log(torch.floor((cache_position.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
)
attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand((*input_shape, 1, 1)) # batch size > 1
query_states = (query_states * attn_scales).to(query_states.dtype)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
attention_interface = eager_attention_forward
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
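# Numeric sketch of the no-RoPE temperature scaling above (floor_scale and attn_scale
# are made-up values, not taken from a real config): the scale grows with the log of the
# position bucket, so queries at larger cache positions are progressively sharpened.
def _sketch_attn_temperature():
    cache_position = torch.arange(4)
    floor_scale, attn_scale = 2.0, 0.1  # hypothetical config values
    attn_scales = (
        torch.log(torch.floor((cache_position.float() + 1.0) / floor_scale) + 1.0) * attn_scale + 1.0
    )
    assert attn_scales[0] == 1.0 and bool(attn_scales[-1] > attn_scales[0])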
# modified from https://github.com/huggingface/transformers/blob/v4.51.0/src/transformers/models/llama4/modeling_llama4.py#L157
# routed_out.view(-1, hidden_dim) was removed from the scatter_add_ call due to an OpenVINO transformations issue
def llama4_moe_forward(self, hidden_states):
batch, seq_len, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_dim)
router_logits = self.router(hidden_states).transpose(0, 1)
tokens_per_expert = batch * seq_len
router_top_value, router_indices = torch.topk(router_logits.transpose(0, 1), self.top_k, dim=1)
router_scores = (
torch.full_like(router_logits.transpose(0, 1), float("-inf"))
.scatter_(1, router_indices, router_top_value)
.transpose(0, 1)
)
# make sure non-top-k tokens are set to -inf before going through the sigmoid
# build a tensor indexing every single hidden state (this could arguably be a registered buffer)
router_indices = (
torch.arange(tokens_per_expert, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1)
)
router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
router_indices = router_indices.reshape(-1, 1).expand(-1, hidden_dim)
routed_in = torch.gather(
input=hidden_states,
dim=0,
index=router_indices,
).to(hidden_states.device)
# we gather inputs corresponding to each expert based on the router indices
routed_in = routed_in * router_scores.reshape(-1, 1)
routed_out = self.experts(routed_in)
out = self.shared_expert(hidden_states)
# now that expert computation is done, scatter-add the results back, mirroring the earlier gather;
# all experts run on all tokens, which is faster than a per-expert loop (though compute bound)
# and scales much better with expert parallelism
out.scatter_add_(dim=0, index=router_indices, src=routed_out)
return out, router_scores
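# Minimal sketch of the gather / scatter_add_ routing used above, with tiny hypothetical
# sizes: token rows are gathered with indices expanded over hidden_dim, processed, then
# accumulated back onto the shared-expert output. `routed_out` reaches scatter_add_
# already flat, since the extra view(-1, hidden_dim) upset OpenVINO's transformations.
def _sketch_moe_gather_scatter():
    tokens, hidden_dim = 3, 2
    hidden = torch.arange(tokens * hidden_dim, dtype=torch.float).view(tokens, hidden_dim)
    index = torch.arange(tokens).reshape(-1, 1).expand(-1, hidden_dim)
    routed_in = torch.gather(hidden, dim=0, index=index)  # identity gather in this toy case
    out = torch.zeros_like(hidden)
    out.scatter_add_(dim=0, index=index, src=routed_in)
    assert torch.equal(out, hidden)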
class Llama4TextModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
self._model.model.rotary_emb._orig_forward = self._model.model.rotary_emb.forward
self._model.model.rotary_emb.forward = types.MethodType(llama4_rope_forward, self._model.model.rotary_emb)
for layer in self._model.model.layers[: self._model.model.config.num_hidden_layers]:
if layer.is_moe_layer:
layer.feed_forward._orig_forward = layer.feed_forward.forward
layer.feed_forward.forward = types.MethodType(llama4_moe_forward, layer.feed_forward)
layer.self_attn._orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(llama4_attn_forward, layer.self_attn)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.model.rotary_emb.forward = self._model.model.rotary_emb._orig_forward
for layer in self._model.model.layers[: self._model.model.config.num_hidden_layers]:
if layer.is_moe_layer:
layer.feed_forward.forward = layer.feed_forward._orig_forward
layer.self_attn.forward = layer.self_attn._orig_forward
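# Usage sketch (the conversion call is hypothetical): like every patcher in this file,
# Llama4TextModelPatcher is a context manager, so the rewritten rope / attention / MoE
# forwards are only live while the model is traced for export:
#
#     with Llama4TextModelPatcher(onnx_config, model, model_kwargs={}):
#         ov_model = convert_model_to_openvino(model)  # hypothetical conversion call
#     # on __exit__ the original forwards are restored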