optimum/exporters/openvino/model_patcher.py

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect
import logging as log
import math
import types
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import PreTrainedModel, TFPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
from transformers.utils import is_tf_available

from optimum.exporters.onnx.base import OnnxConfig
from optimum.exporters.onnx.model_patcher import (
    UNSUPPORTED_OPS_PATCHING_SPEC,
    DecoderModelPatcher,
    ModelPatcher,
    PatchingSpec,
    Seq2SeqModelPatcher,
    override_arguments,
)
from optimum.intel.utils.import_utils import (
    _openvino_version,
    _torch_version,
    _transformers_version,
    is_diffusers_version,
    is_openvino_version,
    is_torch_version,
    is_transformers_version,
)


if TYPE_CHECKING:
    from transformers.cache_utils import Cache
    from transformers.modeling_utils import PreTrainedModel

    from optimum.exporters.onnx.config import OnnxConfig

    if is_tf_available():
        from transformers.modeling_tf_utils import TFPreTrainedModel


def ov_compatible_repeat_interleave(input_tensor, repeats, dim=None, output_size=None):
    """
    Replacement for `torch.Tensor.repeat_interleave` used during export: it forwards the call
    to the functional `torch.repeat_interleave`.

    Args:
        input_tensor (torch.Tensor): The input tensor.
        repeats (int or torch.Tensor): The number of repetitions for each element.
        dim (int, optional): The dimension along which to repeat. Defaults to None.

    Returns:
        torch.Tensor: The repeated tensor.
    """
    result = torch.repeat_interleave(input_tensor, repeats=repeats, dim=dim)
    return result


def patch_unsupported_ops():
    spec_idx = -1
    for idx, spec in enumerate(UNSUPPORTED_OPS_PATCHING_SPEC):
        if spec.name == "repeat_interleave":
            spec_idx = idx
            break
    repeat_interleave_spec = PatchingSpec(
        torch.Tensor, "repeat_interleave", ov_compatible_repeat_interleave, torch.Tensor.repeat_interleave
    )
    if spec_idx != -1:
        UNSUPPORTED_OPS_PATCHING_SPEC[spec_idx] = repeat_interleave_spec
    else:
        UNSUPPORTED_OPS_PATCHING_SPEC.append(repeat_interleave_spec)


BETTERTRANSFORMER_IGNORE = [
    "codegen",
]

# in transformers 4.45 gpt_neo has SDPA
if is_transformers_version(">=", "4.44.99"):
    BETTERTRANSFORMER_IGNORE.append("gpt_neo")


patch_unsupported_ops()


def patch_model_with_bettertransformer(model):
    COLOR_RED = "\033[1;31m"
    COLOR_RESET = "\033[0m"

    # check that the model has not yet been patched
    if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
        return model

    if is_transformers_version("<", "4.36") or is_torch_version("<", "2.1.1"):
        log.warning(
            COLOR_RED
            + "[WARNING] For good performance with stateful models, transformers>=4.36.2 and PyTorch>=2.1.1 are required. "
            f"This Python environment has Transformers {_transformers_version} and PyTorch {_torch_version}. "
            "Consider upgrading PyTorch and Transformers, for example by running "
            "`pip install --upgrade --upgrade-strategy eager optimum[openvino]`, and export the model again"
            + COLOR_RESET
        )

    if (
        getattr(model.config, "model_type") in {"gpt_bigcode", "llama", "gemma"}
        and is_transformers_version(">=", "4.38")
        and is_openvino_version("<", "2024.1.0-14612")
    ):
        # display commit-id only when a nightly/prerelease of OpenVINO is installed.
        display_version = (
            _openvino_version.split("-")[0] if is_openvino_version("<=", "2024.0.0-14509") else _openvino_version
        )
        log.warning(
            COLOR_RED
            + f"[WARNING] Stateful models are not supported for Llama, Gemma and GPTBigCode with Transformers "
            f"{_transformers_version} and OpenVINO {display_version}. For good performance, consider using a nightly OpenVINO build: "
            "https://docs.openvino.ai/2024/get-started/install-openvino.html. For gpt-bigcode and llama models, "
            "it is also an option to downgrade transformers: `pip install transformers==4.37.2`"
            + COLOR_RESET
        )

    # model already has required SDPA implementation
    if getattr(model, "_supports_sdpa", False) and getattr(model.config, "_attn_implementation", "eager") == "sdpa":
        return model

    if model.config.model_type in BETTERTRANSFORMER_IGNORE:
        return model

    try:
        model = model.to_bettertransformer()
    except Exception as e:
        log.warning(
            f"Cannot apply model.to_bettertransformer because of the exception:\n{e}."
            " Using the model with stateful=True may be ineffective if the model does not use"
            " torch.nn.functional.scaled_dot_product_attention"
        )
        return model

    return model


def patch_update_causal_mask(
    model, transformers_version, inner_model_name="model", patch_fn=None, patch_extrnal_model=False
):
    if is_transformers_version(">=", transformers_version):
        inner_model = getattr(model, inner_model_name, None) if not patch_extrnal_model else model
        if inner_model is not None:
            if hasattr(inner_model, "_update_causal_mask"):
                inner_model._orig_update_causal_mask = inner_model._update_causal_mask
            patch_fn = patch_fn or _llama_gemma_update_causal_mask
            inner_model._update_causal_mask = types.MethodType(patch_fn, inner_model)


def unpatch_update_causal_mask(model, inner_model_name="model", patch_extrnal_model=False):
    inner_model = getattr(model, inner_model_name, None) if not patch_extrnal_model else model
    if inner_model is not None and hasattr(inner_model, "_orig_update_causal_mask"):
        inner_model._update_causal_mask = inner_model._orig_update_causal_mask


# initialization of sin/cos cached in bf16/fp16 leads to accuracy loss
# reinitialize them to be saved in float32 before export
def _reinitialize_cos_sin_cached_fp32(rotary_emb):
    if rotary_emb.cos_cached.dtype != torch.float32:
        rotary_emb._set_cos_sin_cache(
            seq_len=rotary_emb.max_position_embeddings, device=rotary_emb.inv_freq.device, dtype=torch.float32
        )


def _mixtral_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Forward of the Mixtral sparse MoE block, bound in place of the original forward during export."""
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)
    # router_logits: (batch * sequence_length, n_experts)
    router_logits = self.gate(hidden_states)

    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    # we cast back to the input dtype
    routing_weights = routing_weights.to(hidden_states.dtype)

    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
    )

    # One
hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) if is_transformers_version("<", "4.37.0"): current_hidden_states = expert_layer(current_state, routing_weights[top_x, idx, None]) else: current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits class MixtralModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() patch_update_causal_mask(self._model, "4.42.0") for layer in self._model.model.layers: layer.block_sparse_moe._unpatched_forward = layer.block_sparse_moe.forward layer.block_sparse_moe.forward = types.MethodType( _mixtral_sparse_moe_block_forward, layer.block_sparse_moe ) if is_transformers_version("<", "4.44.99"): _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if hasattr(self._model.model, "_orig_update_causal_mask"): self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask for layer in self._model.model.layers: layer.block_sparse_moe.forward = layer.block_sparse_moe._unpatched_forward class ArcticModelPatcher(MixtralModelPatcher): def __enter__(self): # model initialize some weights for matrix multiplication in bfloat16, that lead to inconsistency of dtype try: self._model.to(torch.float32) except Exception: pass super().__enter__() def _chatglm_transformer_forward( self, input_ids, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.BoolTensor] = None, full_attention_mask: Optional[torch.BoolTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size, seq_length = input_ids.shape if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) if getattr(self, "pre_seq_len", None) is not None: if past_key_values is None: past_key_values = self.get_prompt( batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype, ) if attention_mask is not None: attention_mask = torch.cat( [ attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask, ], dim=-1, ) if full_attention_mask is None: if past_key_values is not None: full_attention_mask = torch.ones( 
batch_size, seq_length, seq_length, device=input_ids.device, dtype=torch.float, ) * float("-inf") full_attention_mask.triu_(diagonal=1) past_length = 0 if past_key_values: past_length = past_key_values[0][0].shape[0] if past_length: full_attention_mask = torch.cat( ( torch.zeros(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask, ), dim=-1, ) full_attention_mask.unsqueeze_(1) # Rotary positional embeddings rotary_pos_emb = self.rotary_pos_emb(self.seq_length) if position_ids is not None: rotary_pos_emb = rotary_pos_emb[position_ids] else: rotary_pos_emb = rotary_pos_emb[None, :seq_length] rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() # Run encoder. hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states, ) if not return_dict: return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions, ) def _chatglm2_get_context_layer(query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor): mask = torch.zeros((query_layer.shape[-2], key_layer.shape[-2]), dtype=query_layer.dtype) if query_layer.shape[2] == key_layer.shape[2]: tmp_mask = torch.ones((query_layer.shape[-2], key_layer.shape[-2]), dtype=torch.bool).triu(diagonal=1) mask.masked_fill_(tmp_mask, float("-inf")) context_layer = torch.nn.functional.scaled_dot_product_attention( query_layer, key_layer, value_layer, attn_mask=mask ) return context_layer def _chatglm2_core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask): query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] if attention_mask is None: context_layer = _chatglm2_get_context_layer(query_layer, key_layer, value_layer) else: context_layer = torch.nn.functional.scaled_dot_product_attention( query_layer, key_layer, value_layer, attention_mask ) context_layer = context_layer.permute(2, 0, 1, 3) new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.reshape(*new_context_layer_shape) return context_layer def _glm4_core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask): causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32) causal_mask.masked_fill_(attention_mask, float("-inf")) context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, causal_mask) context_layer = context_layer.transpose(1, 2).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.reshape(*new_context_layer_shape) return context_layer class ChatGLMModelPatcher(DecoderModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any], ): super().__init__(config, model, model_kwargs) self.is_v4 = hasattr(self._model.config, "rope_ratio") def __enter__(self): super().__enter__() if not self.is_v4: self._model.transformer._orig_forward = self._model.transformer.forward self._model.transformer.forward = types.MethodType(_chatglm_transformer_forward, self._model.transformer) for block in 
self._model.transformer.encoder.layers: block.self_attention.core_attention._orig_forward = block.self_attention.core_attention.forward block.self_attention.core_attention.forward = types.MethodType( _chatglm2_core_attention_forward if not self.is_v4 else _glm4_core_attention_forward, block.self_attention.core_attention, ) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if hasattr(self._model.transformer, "_orig_forward"): self._model.transformer.forward = self._model.transformer._orig_forward for block in self._model.transformer.encoder.layers: block.self_attention.core_attention.forward = block.self_attention.core_attention._orig_forward # adopted from # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058 def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None): from transformers.modeling_attn_mask_utils import AttentionMaskConverter if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None: # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, # in order to dispatch on Flash Attention 2. if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens ): return None dtype, device = input_tensor.dtype, input_tensor.device # difference with original modeling # using minimum from dtype with larger bandwith (floa32) may lead to overflow # during execution on platforms with default lower precision (bfloat16, float16) min_dtype = torch.finfo(torch.float16).min sequence_length = input_tensor.shape[1] # difference with original modeling if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache target_length = self.config.max_position_embeddings else: # dynamic cache if past_seen_tokens is not None: current_length = past_seen_tokens + sequence_length + 1 # TODO : remove after support of transformers >= v4.40.0 else: current_length = cache_position[-1] + 1 target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length # difference with original modeling causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) elif attention_mask.dim() == 4: # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. 
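        # Illustrative example (sketch of the mask construction in this function): for
        # sequence_length == 3, target_length == 5 and cache_position == [2, 3, 4], the
        # causal_mask assembled above is
        #
        #     [[0, 0, 0, m, m],
        #      [0, 0, 0, 0, m],
        #      [0, 0, 0, 0, 0]]
        #
        # with m = torch.finfo(torch.float16).min; float16's minimum is used instead of
        # float32's so the mask cannot overflow when the exported model later runs in
        # bfloat16/float16, as the "difference with original modeling" comment explains.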
if attention_mask.shape[-2] < cache_position[0] + sequence_length: offset = cache_position[0] else: offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype causal_mask[ : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] ] = mask_slice if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask # adopted from https://github.com/huggingface/transformers/blob/f4014e75db0190792b3feeccfc5dc5b5f9f0ce7b/src/transformers/models/llama/modeling_llama.py#L1036 def _llama_gemma_update_causal_mask_latest( self, attention_mask, input_tensor, cache_position, past_key_values, output_attentions, ): from transformers.cache_utils import StaticCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. 
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device # difference with original modeling # using minimum from dtype with larger bandwith (floa32) may lead to overflow # during execution on platforms with default lower precision (bfloat16, float16) min_dtype = torch.finfo(torch.float16).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) if attention_mask is not None and attention_mask.dim() == 4: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing if attention_mask.max() != 0: raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") causal_mask = attention_mask else: # difference with original modeling causal_mask = ( torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
# Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask # TODO : deprecate _llama_gemma_update_causal_mask_legacy when transformers>=4.41.0 if is_transformers_version(">", "4.40.2"): _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_latest else: _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_legacy def llama_gemma_rotary_emb_forward(self, x, position_ids, seq_len=None): # adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L104 _seq_len = torch.max(position_ids) + 1 if seq_len is None else seq_len if _seq_len > self.embed_positions.shape[0]: if seq_len is None: return self._orig_forward(x, position_ids) else: return self._orig_forward(x, position_ids, seq_len) sincos = self.embed_positions[position_ids] sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) return cos, sin def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000, inv_freq=None) -> torch.Tensor: # adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L101 if inv_freq is None: inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim)) sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float() emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1) return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1) def register_sin_cos_buffer(model): max_positions = model.config.max_position_embeddings # cos/sin for rotary position embeddings also having issues with bf16 and efficiency due to calculation on each step # use precomputed rotary_emb = model.model.layers[0].self_attn.rotary_emb dim, base = None, None inv_freq = getattr(rotary_emb, "inv_freq", None) if inv_freq is None: base = rotary_emb.base dim = rotary_emb.dim embed_positions = create_sinusoidal_positions(max_positions, dim, base, inv_freq) for layer in model.model.layers: layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions) layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward layer.self_attn.rotary_emb.forward = types.MethodType( llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb ) class LlamaModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() # llama/gemma has some accuracy issues with bf16 with transformers >= 4.39 # fill causal mask in slightly different way for avoid overflow on some platforms patch_update_causal_mask(self._model, "4.39.0", "model" if hasattr(self._model, "model") else "transformer") def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) unpatch_update_causal_mask(self._model, "model" if hasattr(self._model, "model") else "transformer") # copied from https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548 to unblock export with transformers 4.42 def _mistral_update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: "Cache", use_cache: bool, output_attentions: bool, ): from transformers.cache_utils import SlidingWindowCache, StaticCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. 
This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 if self._attn_implementation == "flash_attention_2": if attention_mask is not None and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: raise ValueError( "You are attempting to perform batched generation with padding_side='right'" " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. # cache_position must be valid here no matter which cache we use past_seen_tokens = cache_position[0] if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) if ( self.config._attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache) and not output_attentions ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(torch.float16).min sequence_length = input_tensor.shape[1] # SlidingWindowCache if using_sliding_window_cache: target_length = max(sequence_length, self.config.sliding_window) # StaticCache elif using_static_cache: target_length = past_key_values.get_max_length() # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) if attention_mask is not None and attention_mask.dim() == 4: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing if attention_mask.max() != 0: raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") causal_mask = attention_mask else: causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if self.config.sliding_window is not None: if not using_sliding_window_cache or sequence_length > self.config.sliding_window: exclude_mask = exclude_mask.bitwise_or( torch.arange(target_length, device=device) <= (cache_position.reshape(-1, 1) - self.config.sliding_window) ) causal_mask *= exclude_mask causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = 
causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask class MistralModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"): # apply fix https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548 self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask self._model.model._update_causal_mask = types.MethodType(_mistral_update_causal_mask, self._model.model) else: for layer in self._model.model.layers: if hasattr(layer.self_attn, "rotary_emb"): _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if hasattr(self._model.model, "_orig_update_causal_mask"): self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask for layer in self._model.model.layers: if hasattr(layer.self_attn, "rotary_emb") and hasattr(layer.self_attn.rotary_emb, "_orig_forward"): layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward SUPPORT_SDPA = is_torch_version(">", "2.1.0") def _qwen_rotate_half(x): from einops import rearrange x = rearrange(x, "... (j d) -> ... 
j d", j=2) x1, x2 = x.unbind(dim=-2) return torch.cat((-x2, x1), dim=-1) def _qwen_apply_rotary_pos_emb(t, freqs): cos, sin = freqs rot_dim = freqs[0].shape[-1] cos, sin = freqs t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:] t_ = t_.float() t_pass_ = t_pass_.float() t_ = (t_ * cos) + (_qwen_rotate_half(t_) * sin) return torch.cat((t_, t_pass_), dim=-1).type_as(t) def _qwen_quantize_cache_v(fdata, bits, qmax, qmin): # b, s, head, h-dim->b, head, s, h-dim qtype = torch.uint8 device = fdata.device shape = fdata.shape fdata_cal = torch.flatten(fdata, 2) fmax = torch.amax(fdata_cal, dim=-1, keepdim=True) fmin = torch.amin(fdata_cal, dim=-1, keepdim=True) # Compute params if qmax.device != fmax.device: qmax = qmax.to(device) qmin = qmin.to(device) scale = (fmax - fmin) / (qmax - qmin) zero = qmin - fmin / scale scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous() zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous() # Quantize res_data = fdata / scale + zero qdata = torch.clamp(res_data, qmin, qmax).to(qtype) return qdata.contiguous(), scale, zero def _qwen_attention_forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, layer_past: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, ): mixed_x_layer = self.c_attn(hidden_states) query, key, value = mixed_x_layer.split(self.split_size, dim=2) query = self._split_heads(query, self.num_heads, self.head_dim) key = self._split_heads(key, self.num_heads, self.head_dim) value = self._split_heads(value, self.num_heads, self.head_dim) if rotary_pos_emb_list is not None: cur_len = query.shape[1] if len(rotary_pos_emb_list) == 1: rotary_pos_emb = rotary_pos_emb_list[0] rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] rotary_pos_emb = (rotary_pos_emb,) * 2 q_pos_emb, k_pos_emb = rotary_pos_emb # Slice the pos emb for current inference query = _qwen_apply_rotary_pos_emb(query, q_pos_emb) key = _qwen_apply_rotary_pos_emb(key, k_pos_emb) else: query_list = [] key_list = [] for i, rotary_pos_emb in enumerate(rotary_pos_emb_list): rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] rotary_pos_emb = (rotary_pos_emb,) * 2 q_pos_emb, k_pos_emb = rotary_pos_emb # Slice the pos emb for current inference query_list += [_qwen_apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)] key_list += [_qwen_apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)] query = torch.cat(query_list, dim=0) key = torch.cat(key_list, dim=0) if self.use_cache_quantization: key = _qwen_quantize_cache_v(key.permute(0, 2, 1, 3), bits=8, qmin=self.cache_qmin, qmax=self.cache_qmax) value = _qwen_quantize_cache_v(value.permute(0, 2, 1, 3), bits=8, qmin=self.cache_qmin, qmax=self.cache_qmax) if layer_past is not None: past_key, past_value = layer_past[0], layer_past[1] if self.use_cache_quantization: # use_cache_quantization: # present=((q_key,key_scale,key_zero_point), # (q_value,value_scale,value_zero_point)) key = ( torch.cat((past_key[0], key[0]), dim=2), torch.cat((past_key[1], key[1]), dim=2), torch.cat((past_key[2], key[2]), dim=2), ) value = ( torch.cat((past_value[0], value[0]), dim=2), torch.cat((past_value[1], value[1]), dim=2), torch.cat((past_value[2], value[2]), dim=2), ) else: # 
not use_cache_quantization: # present=(key,value) key = torch.cat((past_key, key), dim=1) value = torch.cat((past_value, value), dim=1) if use_cache: present = (key, value) else: present = None if self.use_logn_attn and not self.training: if self.use_cache_quantization: seq_start = key[0].size(2) - query.size(1) seq_end = key[0].size(2) else: seq_start = key.size(1) - query.size(1) seq_end = key.size(1) logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query) query = query * logn_tensor.expand_as(query) if self.use_flash_attn and not self.is_fp32 and query.is_cuda: q, k, v = query, key, value attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask) else: registered_causal_mask = torch.tril( torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device) ).view(1, 1, key.size(1), key.size(1)) query = query.permute(0, 2, 1, 3) if not self.use_cache_quantization: key = key.permute(0, 2, 1, 3) value = value.permute(0, 2, 1, 3) if not self.use_cache_quantization and SUPPORT_SDPA: # For performance, using constant tril to generate causal_mask causal_mask = self.bias[:, :, key.size(-2) - query.size(-2) : key.size(-2), : key.size(-2)] if attention_mask is not None: attention_mask = attention_mask.expand(-1, -1, query.size(2), -1).masked_fill( ~causal_mask, torch.finfo(query.dtype).min ) else: attention_mask = causal_mask attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2) attn_weight = None else: attn_output, attn_weight = self._attn(query, key, value, registered_causal_mask, attention_mask, head_mask) context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(context_layer) outputs = (attn_output, present) if output_attentions: if self.use_flash_attn and not self.is_fp32: raise ValueError("Cannot output attentions while using flash-attn") else: outputs += (attn_weight,) return outputs class QwenModelPatcher(DecoderModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any], ): super().__init__(config, model, model_kwargs) self.original_fp16 = model.config.fp16 self.original_bf16 = model.config.bf16 model.config.bf16 = False model.config.fp16 = False if self.original_fp16 or self.original_bf16: # GPTQ models does to support casting to dtype try: model.to(torch.float32) except Exception: pass model.transformer.rotary_emb(2048) def __enter__(self): super().__enter__() max_positions = self._model.config.seq_length for block in self._model.transformer.h: block.attn._orig_forward = block.attn.forward # For performance, using constant tril to generate causal_mask block.attn.register_buffer( "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( 1, 1, max_positions, max_positions ), persistent=False, ) block.attn.forward = types.MethodType(_qwen_attention_forward, block.attn) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) for block in self._model.transformer.h: block.attn.forward = block.attn._orig_forward self._model.config.bf16 = self.original_bf16 self._model.config.fp16 = self.original_fp16 def _baichuan13b_atten_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], 
Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() proj = self.W_pack(hidden_states) proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] if past_key_value is not None: # reuse k, v, self_attention if attention_mask is not None: attention_mask = attention_mask[:, :, -key_states.shape[-2] :, :] key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) if not output_attentions: past_key_value = (key_states, value_states) if use_cache else None attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask) attn_weights = None else: attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: if q_len == 1: # inference with cache if len(attention_mask.size()) == 4: attention_mask = attention_mask[:, :, -1:, :] else: attention_mask = attention_mask[:, -1:, :] attn_weights = attn_weights + attention_mask attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, attn_weights, past_key_value # Adapted from https://huggingface.co/baichuan-inc/Baichuan-7B/blob/262c8cb58b6d3615c208d9230baa869fddee2adb/modeling_baichuan.py#L181 def _baichuan7b_attn_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
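        # Note: rotate_half above splits the head dimension into two halves, so this applies the
        # half-split RoPE convention in which element i is rotated together with element
        # i + head_dim // 2 by the (cos, sin) values gathered at each token's position.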
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed bsz, q_len, _ = hidden_states.size() proj = self.W_pack(hidden_states) proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) # [bsz, nh, t, hd] if past_key_value is not None: # reuse k, v, self_attention key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None if not output_attentions: attn_weights = None attn_output = F.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, scale=1 / math.sqrt(self.head_dim) ) else: attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, attn_weights, past_key_value class BaichuanModelPatcher(DecoderModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any], ): super().__init__(config, model, model_kwargs) # model has first inference buffers initialization if hasattr(self._model.lm_head, "first_flag"): self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64)) def __enter__(self): super().__enter__() # override signature to have position_ids if "position_ids" not in inspect.signature(self._model.forward).parameters: self._model._orig_forward = self._model.forward def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, position_ids: Optional[torch.LongTensor] = None, ): return self._orig_forward( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=past_key_values is not None, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=self.config.return_dict, ) self._model.forward = types.MethodType(forward, self._model) for layer in self._model.model.layers: 
layer.self_attn._orig_forward = layer.self_attn.forward layer.self_attn.forward = types.MethodType(_baichuan13b_atten_forward, layer.self_attn) else: for layer in self._model.model.layers: layer.self_attn._orig_forward = layer.self_attn.forward layer.self_attn.forward = types.MethodType(_baichuan7b_attn_forward, layer.self_attn) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if hasattr(self._model, "_orig_forward"): self._model.forward = self._model._orig_forward for layer in self._model.model.layers: if hasattr(layer.self_attn, "_orig_forward"): layer.self_attn.forward = layer.self_attn._orig_forward def _mpt_sdpa_attention_forward( self, hidden_states: torch.Tensor, position_bias: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, ): batch_size, seq_length = hidden_states.shape[:2] mixed_qkv = self.Wqkv(hidden_states) query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2) query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: if len(past_key_value) != 0: key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) else: past_key_value = (key_states, value_states) key_length = key_states.shape[-2] query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2] attention_mask_sdpa = torch.ones( (query_states.shape[0], query_states.shape[1], query_states.shape[2], key_states.shape[2]), dtype=query_states.dtype, ) if position_bias is not None: position_bias_query_index = max(0, position_bias.size(1) - query_length) position_bias_key_index = max(0, position_bias.size(2) - key_length) position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:] attention_mask_sdpa += position_bias attention_mask_sdpa.masked_fill_(attention_mask, torch.finfo(query_states.dtype).min) context_states = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask_sdpa, dropout_p=self.attn_dropout_p, scale=self.softmax_scale, ) context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) attn_output = self.out_proj(context_states) return attn_output, None, past_key_value def _mpt_block_forward( self, hidden_states: torch.Tensor, position_bias: torch.Tensor, attention_mask: torch.Tensor, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, output_attentions: bool = False, ): # hidden_states: [batch_size, seq_length, hidden_size] # Layer norm at the beginning of the transformer layer. layernorm_output = self.norm_1(hidden_states) residual = hidden_states if not output_attentions: # Self attention. 
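        # Note: in the patched model self.attn.forward is _mpt_sdpa_attention_forward, which
        # returns None in place of attention weights, so the output_attentions branch below
        # calls the original eager implementation via self.attn._orig_forward instead.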
attn_outputs, attn_weights, past_key_value = self.attn( layernorm_output, position_bias=position_bias, attention_mask=attention_mask, past_key_value=layer_past, ) else: attn_outputs, attn_weights, past_key_value = self.attn._orig_forward( layernorm_output, position_bias=position_bias, attention_mask=attention_mask, past_key_value=layer_past, ) hidden_states = self.resid_attn_dropout(attn_outputs) + residual layernorm_output = self.norm_2(hidden_states) # Get residual residual = hidden_states # MLP. output = self.ffn(layernorm_output, residual) outputs = (output,) if use_cache: outputs += (past_key_value,) if output_attentions: outputs += (attn_weights,) return outputs class MPTModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() if is_torch_version(">=", "2.1.0"): for block in self._model.transformer.blocks: block._orig_forward = block.forward block.forward = types.MethodType(_mpt_block_forward, block) block.attn._orig_forward = block.attn.forward block.attn.forward = types.MethodType(_mpt_sdpa_attention_forward, block.attn) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) for block in self._model.transformer.blocks: if hasattr(block, "_orig_forward"): block.forward = block._orig_forward if hasattr(block.attn, "_orig_forward"): block.attn.forward = block.attn._orig_forward def _internlm2_attention_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv from einops import rearrange def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors.""" if position_ids is not None: cos = cos[position_ids] sin = sin[position_ids] cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) bsz, q_len, _ = hidden_states.size() qkv_states = self.wqkv(hidden_states) qkv_states = rearrange( qkv_states, "b q (h gs d) -> b q h gs d", gs=2 + self.num_key_value_groups, d=self.head_dim, ) query_states = qkv_states[..., : self.num_key_value_groups, :] query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d") key_states = qkv_states[..., -2, :] value_states = qkv_states[..., -1, :] query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) kv_seq_len = key_states.shape[-2] is_legacy = not hasattr(self, "layer_idx") if is_legacy: if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None else: cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": kwargs.get("cache_position")} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) if not output_attentions: attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) ) attn_weights = None else: attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.wo(attn_output) return attn_output, attn_weights, past_key_value class InternLM2Patcher(DecoderModelPatcher): def __enter__(self): super().__enter__() if is_torch_version(">=", "2.1.0"): for block in self._model.model.layers: block.attention._orig_forward = block.attention.forward block.attention.forward = types.MethodType(_internlm2_attention_forward, block.attention) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) for block in self._model.model.layers: if hasattr(block.attention, "_orig_forward"): block.attention.forward = block.attention._orig_forward def phi3_442_forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = 
None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") past_key_values_length = 0 if use_cache: use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( "You are attempting to perform batched generation with padding_side='right'" " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" ) if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, sliding_window=self.config.sliding_window, ) hidden_states = inputs_embeds # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = None if use_cache: next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) # Adapted from https://github.com/huggingface/transformers/blob/ccdabc5642bf84849af93f591e207dc625c8e1e1/src/transformers/models/phi3/modeling_phi3.py#L729 def _phi3_self_attn_sdpa_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: return self._orig_forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) if is_transformers_version(">=", "4.41.0"): from transformers.models.phi3.modeling_phi3 import apply_rotary_pos_emb, repeat_kv else: from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv bsz, q_len, _ = hidden_states.size() qkv = self.qkv_proj(hidden_states) query_pos = self.num_heads * self.head_dim query_states = qkv[..., :query_pos] key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, 
position_ids) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) causal_mask = attention_mask if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and attention_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. is_causal=self.is_causal and attention_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value class Phi3ModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() # currently, long RoPE cannot be traced for long-context support; disable it to avoid potential accuracy issues if self._model.config.max_position_embeddings != getattr( self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings ): self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"): self._model.model._orig_forward = self._model.model.forward self._model.model.forward = types.MethodType(phi3_442_forward, self._model.model) # https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113 # init inv_freq for torchscript tracing # fixed for phi3 in transformers 4.48, but the issue is still visible with trust_remote_code=True (remote code has _supports_sdpa = False) for layer in self._model.model.layers: if ( is_torch_version(">=", "2.1.0") and is_transformers_version("<", "4.48.0") or not getattr(self._model, "_supports_sdpa", False) ): orig_self_attn_fwd = layer.self_attn.forward layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn) layer.self_attn._orig_forward = orig_self_attn_fwd if ( hasattr(layer.self_attn, "rotary_emb") and getattr(layer.self_attn.rotary_emb, "inv_freq", None) is None ): rotary_emb = layer.self_attn.rotary_emb layer.self_attn.rotary_emb.inv_freq = 1.0 / ( rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim) ) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if hasattr(self._model.model, "_orig_forward"): self._model.model.forward = self._model.model._orig_forward if hasattr(self._model.model, "_orig_update_causal_mask"): self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask for layer in self._model.model.layers: if hasattr(layer.self_attn, "_orig_forward"): layer.self_attn.forward = layer.self_attn._orig_forward
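# Illustrative sketch (not part of the exporter): every patcher class in this file follows
# the same enter/exit discipline seen in Phi3ModelPatcher above: save the bound method as an
# `_orig_*` attribute, install a tracing-friendly replacement with types.MethodType, and
# restore the original in __exit__. The toy module and patcher below are hypothetical names
# used only to demonstrate that pattern in isolation.
def _example_patch_restore_pattern():
    class ToyBlock(torch.nn.Module):
        def forward(self, x):
            return x + 1

    def patched_forward(self, x):
        # replacement behaviour, active only while the patcher context is open
        return x + 2

    class ToyPatcher:
        def __init__(self, model):
            self._model = model

        def __enter__(self):
            self._model._orig_forward = self._model.forward
            self._model.forward = types.MethodType(patched_forward, self._model)
            return self._model

        def __exit__(self, exc_type, exc_value, traceback):
            self._model.forward = self._model._orig_forward

    block = ToyBlock()
    with ToyPatcher(block):
        assert int(block(torch.zeros(1))) == 2  # patched forward is used while exporting/tracing
    assert int(block(torch.zeros(1))) == 1  # original forward is restored afterwards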
# Modified from https://github.com/huggingface/transformers/blob/v4.50.2/src/transformers/models/phimoe/modeling_phimoe.py#L756 # removed usage of the tracing-unfriendly `continue` operation def _phi_moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: from transformers.models.phimoe.modeling_phimoe import sparsemixer batch_size, sequence_length, hidden_dim = hidden_states.shape if self.training and self.input_jitter_noise > 0: hidden_states *= torch.empty_like(hidden_states).uniform_( 1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise ) hidden_states = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states) routing_weights, selected_experts = sparsemixer( router_logits, jitter_eps=self.router_jitter_noise, training=self.training, ) final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be solicited expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) # if top_x.shape[0] == 0: # continue # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] # However `index_add_` only supports torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits class PhiMoEModelPatcher(Phi3ModelPatcher): def __enter__(self): super().__enter__() for layer in self._model.model.layers: layer.block_sparse_moe._orig_forward = layer.block_sparse_moe.forward layer.block_sparse_moe.forward = types.MethodType( _phi_moe_sparse_moe_block_forward, layer.block_sparse_moe ) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) for layer in self._model.model.layers: layer.block_sparse_moe.forward = layer.block_sparse_moe._orig_forward def _aquila_self_attn_sdpa_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed if output_attentions: return self._orig_forward( hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache ) bsz, q_len, _ = hidden_states.size() if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] query_states = torch.cat(query_states, dim=-1) key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] key_states = torch.cat(key_states, dim=-1) value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] value_states = torch.cat(value_states, dim=-1) else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view( bsz, q_len, getattr(self, "num_key_value_heads", self.num_heads), self.head_dim ).transpose(1, 2) value_states = value_states.view( bsz, q_len, getattr(self, "num_key_value_heads", self.num_heads), self.head_dim ).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # reuse k, v, self_attention key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None if hasattr(self, "num_key_value_groups"): # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) ) attn_weights = None attn_output = attn_output.transpose(1, 2).contiguous() attn_output = 
attn_output.reshape(bsz, q_len, self.hidden_size) if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1: attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) else: attn_output = self.o_proj(attn_output) return attn_output, attn_weights, past_key_value class AquilaModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() for layer in self._model.model.layers: if is_torch_version(">=", "2.1.0"): orig_self_attn_fwd = layer.self_attn.forward layer.self_attn.forward = types.MethodType(_aquila_self_attn_sdpa_forward, layer.self_attn) layer.self_attn._orig_forward = orig_self_attn_fwd def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) for layer in self._model.model.layers: if hasattr(layer.self_attn, "_orig_forward"): layer.self_attn.forward = layer.self_attn._orig_forward def _xverse_self_attn_sdpa_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed if output_attentions: return self._orig_forward( hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) # [bsz, nh, t, hd] if past_key_value is not None: # reuse k, v, self_attention key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) ) attn_weights = None attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, attn_weights, past_key_value
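# Illustrative sketch (not part of the exporter): the SDPA forwards above re-implement the
# same rotary-embedding helpers inline. For a fixed angle `theta`, the update
# q_embed = q * cos + rotate_half(q) * sin is a 2D rotation applied to each pair formed by
# element i of the first half and element i of the second half of the head dimension. The
# helper below is a hypothetical self-check of that identity on random data.
def _example_rotary_embedding_identity(head_dim: int = 8, theta: float = 0.3) -> bool:
    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    q = torch.randn(head_dim)
    cos = torch.full((head_dim,), math.cos(theta))
    sin = torch.full((head_dim,), math.sin(theta))
    q_embed = q * cos + rotate_half(q) * sin

    # the same result written as an explicit rotation of the (x1, x2) pairs
    half = head_dim // 2
    x1, x2 = q[:half], q[half:]
    expected = torch.cat(
        (x1 * math.cos(theta) - x2 * math.sin(theta), x2 * math.cos(theta) + x1 * math.sin(theta))
    )
    return torch.allclose(q_embed, expected)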
def _internlm_self_attn_sdpa_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids): cos = cos[position_ids].unsqueeze(1) sin = sin[position_ids].unsqueeze(1) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed if output_attentions: return self._orig_forward( hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # reuse k, v, self_attention key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) ) attn_weights = None attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, attn_weights, past_key_value class XverseModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() for layer in self._model.model.layers: if is_torch_version(">=", "2.1.0"): orig_self_attn_fwd = layer.self_attn.forward layer.self_attn.forward = types.MethodType(_xverse_self_attn_sdpa_forward, layer.self_attn) layer.self_attn._orig_forward = orig_self_attn_fwd def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) for layer in self._model.model.layers: if hasattr(layer.self_attn, "_orig_forward"): layer.self_attn.forward = layer.self_attn._orig_forward class InternLMModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() for layer in self._model.model.layers: if is_torch_version(">=", "2.1.0"): orig_self_attn_fwd = layer.self_attn.forward layer.self_attn.forward = types.MethodType(_internlm_self_attn_sdpa_forward, layer.self_attn) layer.self_attn._orig_forward = orig_self_attn_fwd def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) for layer in self._model.model.layers: if hasattr(layer.self_attn, "_orig_forward"): layer.self_attn.forward = layer.self_attn._orig_forward # Adapted from https://github.com/huggingface/optimum/blob/3adbe7c75e3c41c1a3b945cf085e74ece7f8e192/optimum/bettertransformer/models/attention.py#L234 def codegen_wrapped_scaled_dot_product( self, query:
torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, ): batch_size = query.shape[0] mask_value = torch.finfo(value.dtype).min mask_value = torch.full([], mask_value, dtype=value.dtype) # in codegen the query and key are always in fp32 regardless of the dtype of the model # https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226 query = query.to(value.dtype) key = key.to(value.dtype) dropout_p = self.dropout_prob_attn if self.training else 0.0 if batch_size == 1 or self.training: if query.shape[2] > 1: # first step of the decoding sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True ) else: # in this case, which is the later decoding steps, the `causal_mask` in # https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195 # is [True, ..., True] so actually not causal sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False ) else: query_length, key_length = query.size(-2), key.size(-2) # causal_mask is always [True, ..., True] otherwise, so executing this # is unnecessary if query_length > 1: if not is_transformers_version(">", "4.44.99"): causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to( torch.bool ) causal_mask = torch.where(causal_mask, 0, mask_value) # torch.Tensor.expand does no memory copy causal_mask = causal_mask.expand(batch_size, -1, -1, -1) # we use torch.min to avoid having tensor(-inf) attention_mask = torch.min(causal_mask, attention_mask) else: attention_mask = attention_mask[:, :, :, : key.shape[-2]] sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False ) return sdpa_result, None # copied from https://github.com/huggingface/optimum/blob/2112e99122d7f23a1da1a9d263fef64301050ea7/optimum/bettertransformer/models/attention.py#L168 # for preserving backward compatibility between outdated codegen remote code and new transformers def _codegen_wrapped_scaled_dot_product_legacy( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, ): if head_mask is not None: raise ValueError("`head_mask` input argument is not supported") batch_size = query.shape[0] mask_value = torch.finfo(value.dtype).min mask_value = torch.full([], mask_value, dtype=value.dtype) if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, -1, -1] < -1: raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.") # in codegen the query and key are always in fp32 regardless of the dtype of the model # https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226 query = query.to(value.dtype) key = key.to(value.dtype) dropout_p = self.dropout_prob_attn if self.training else 0.0 if batch_size == 1 or self.training: if query.shape[2] > 1: # first step of the decoding sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True ) else: # in this case, which is 
the later decoding steps, the `causal_mask`` in # https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195 # is [True, ..., True] so actually not causal sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False ) else: query_length, key_length = query.size(-2), key.size(-2) # causal_mask is always [True, ..., True] otherwise, so executing this is unnecessary if query_length > 1: causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) causal_mask = torch.where(causal_mask, 0, mask_value) # torch.Tensor.expand does no memory copy causal_mask = causal_mask.expand(batch_size, -1, -1, -1) # we use torch.min to avoid having tensor(-inf) attention_mask = torch.min(causal_mask, attention_mask) sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False ) return sdpa_result, None class CodeGenModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() attn_fn = codegen_wrapped_scaled_dot_product if is_torch_version(">=", "2.1.0") and is_transformers_version(">=", "4.45"): # in transformers 4.45 causal_mask const buffer was removed from the model # if it still exists, it means legacy remote code was loaded if hasattr(self._model.transformer.h[0].attn, "causal_mask"): attn_fn = _codegen_wrapped_scaled_dot_product_legacy for layer in self._model.transformer.h: if is_torch_version(">=", "2.1.0") and not self._model.config.output_attentions: orig_self_attn_fwd = layer.attn._attn layer.attn._attn = types.MethodType(attn_fn, layer.attn) layer.attn._orig_attn = orig_self_attn_fwd patch_update_causal_mask(self._model, "4.45.0", "transformer") def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) unpatch_update_causal_mask(self._model, "transformer") for layer in self._model.transformer.h: if hasattr(layer.attn, "_orig_attn"): layer.attn._attn = layer.attn._orig_attn # Adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L763 def _dbrx_experts_forward( self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor ): bsz, q_len, hidden_size = x.shape x = x.view(-1, hidden_size) out = torch.zeros_like(x) expert_mask = torch.nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0) # Chunk experts at once to avoid storing full parameter multiple times in autograd w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk( self.moe_num_experts, dim=0 ) v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk( self.moe_num_experts, dim=0 ) w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk( self.moe_num_experts, dim=0 ) w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked] v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked] w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked] for expert_idx in range(0, self.moe_num_experts): topk_idx, token_idx = torch.where(expert_mask[expert_idx]) # Difference with original: removal # if token_idx.shape[0] == 0: # continue # loop interruption depends on input data and may affect torchscript tracing token_list = token_idx topk_list = 
topk_idx expert_tokens = x[None, token_list].reshape(-1, hidden_size) expert_out = ( self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx]) * top_weights[token_list, topk_list, None] ) out.index_add_(0, token_idx, expert_out) out = out.reshape(bsz, q_len, hidden_size) return out # Adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L1228 def _dbrx_update_causal_mask_legacy( self, attention_mask: Optional[torch.Tensor], input_tensor: torch.Tensor, cache_position: torch.Tensor ) -> Optional[torch.Tensor]: from transformers.modeling_attn_mask_utils import AttentionMaskConverter if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None dtype, device = input_tensor.dtype, input_tensor.device # difference with original modeling # using minimum from dtype with larger bandwith (floa32) may lead to overflow # during execution on platforms with default lower precision (bfloat16, float16) min_dtype = torch.finfo(torch.float16).min sequence_length = input_tensor.shape[1] if hasattr(self.blocks[0].norm_attn_norm.attn, "past_key_value"): # static cache target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 ) # difference with original modeling # removed target_length = int(target_length). # Casting to int leads to constant folding during tracing that makes impossible to use model for sequence of different length causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) elif attention_mask.dim() == 4: # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. if attention_mask.shape[-2] < cache_position[0] + sequence_length: offset = cache_position[0] else: offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype causal_mask[ : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] ] = mask_slice if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" ): # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). is_tracing = ( torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy) or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) ) if not is_tracing and torch.any(attention_mask != 1): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask # adopted from https://github.com/huggingface/transformers/blob/1b3dba9417eebe16b7c206d1dfca6a4c7f11dbec/src/transformers/models/dbrx/modeling_dbrx.py#L1204 def _dbrx_update_causal_mask_latest( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: "Cache", output_attentions: bool, ): from transformers.cache_utils import StaticCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device # difference with original modeling # using minimum from dtype with larger bandwith (floa32) may lead to overflow # during execution on platforms with default lower precision (bfloat16, float16) min_dtype = torch.finfo(torch.float16).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) if attention_mask is not None and attention_mask.dim() == 4: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing if attention_mask.max() != 0: raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") causal_mask = attention_mask else: # difference with original modeling causal_mask = ( torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = 
attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask if is_transformers_version(">", "4.40.2"): _dbrx_update_causal_mask = _dbrx_update_causal_mask_latest else: _dbrx_update_causal_mask = _dbrx_update_causal_mask_legacy class DBRXModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() # dbrx has some accuracy issues with bf16 with transformers >= 4.40 # fill causal mask in slightly different way for avoid overflow on some platforms self._model.transformer._orig_update_causal_mask = self._model.transformer._update_causal_mask self._model.transformer._update_causal_mask = types.MethodType( _dbrx_update_causal_mask, self._model.transformer ) # starting from transformers 4.41 issue also observable for calculation sin/cos for rotary_emb patch_rope_sin_cos = is_transformers_version(">=", "4.41.0") inv_freq = getattr(self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb, "inv_freq") dim, base = None, None if inv_freq is None: dim = self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb.dim base = self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb.base max_positions = self._model.config.max_seq_len if patch_rope_sin_cos: embed_positions = create_sinusoidal_positions(max_positions, dim, base, inv_freq) for block in self._model.transformer.blocks: rotary_emb = block.norm_attn_norm.attn.rotary_emb # initialize inv_freq for torchscript tracing if rotary_emb.inv_freq is None: inv_freq = 1.0 / ( rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim) ) rotary_emb.inv_freq = inv_freq if patch_rope_sin_cos: rotary_emb.register_buffer("embed_positions", embed_positions) rotary_emb._orig_forward = rotary_emb.forward rotary_emb.forward = types.MethodType(llama_gemma_rotary_emb_forward, rotary_emb) # remove continue-operator from iteration loop over experts block.ffn.experts._orig_forward = block.ffn.experts.forward block.ffn.experts.forward = types.MethodType(_dbrx_experts_forward, block.ffn.experts) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.transformer._update_causal_mask = self._model.transformer._orig_update_causal_mask for block in self._model.transformer.blocks: block.ffn.experts.forward = block.ffn.experts._orig_forward if hasattr(block.norm_attn_norm.attn.rotary_emb, "_orig_forward"): block.norm_attn_norm.attn.rotary_emb.forward = block.norm_attn_norm.attn.rotary_emb._orig_forward # Adapted from https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/persimmon/modeling_persimmon.py#L264 def _persimmon_self_attn_sdpa_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional["Cache"] = None, 
output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: from transformers.models.persimmon.modeling_persimmon import apply_rotary_pos_emb if output_attentions: return self._orig_forward( hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache ) bsz, q_len, _ = hidden_states.size() # [batch_size, seq_length, 3 x hidden_size] fused_qkv = self.query_key_value(hidden_states) # 3 x [batch_size, seq_length, num_heads, head_dim] (query_states, key_states, value_states) = self._split_heads(fused_qkv) if self.qk_layernorm: query_states = self.q_layernorm(query_states) key_states = self.k_layernorm(key_states) # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim] query_states = query_states.transpose(1, 2) value_states = value_states.transpose(1, 2) key_states = key_states.transpose(1, 2) if is_transformers_version("<", "4.44.99"): kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) else: if position_embeddings is None: log.warning( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." 
) cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings if is_transformers_version("<", "4.44.99"): rotary_ndims = self.rotary_emb.dim else: rotary_ndims = self.rotary_ndims # Partial rotary embedding query_rot, query_pass = ( query_states[..., :rotary_ndims], query_states[..., rotary_ndims:], ) key_rot, key_pass = ( key_states[..., :rotary_ndims], key_states[..., rotary_ndims:], ) # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) # [batch_size, seq_length, num_heads, head_dim] query_states = torch.cat((query_rot, query_pass), dim=-1) key_states = torch.cat((key_rot, key_pass), dim=-1) if past_key_value is not None: # Specific to RoPE models with partial rotation cache_kwargs = { "sin": sin, "cos": cos, "partial_rotation_size": rotary_ndims, "cache_position": cache_position, } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) causal_mask = attention_mask if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_output = F.scaled_dot_product_attention( query_states, key_states, value_states, causal_mask, scale=1 / math.sqrt(self.head_dim), dropout_p=self.attention_dropout.p, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.dense(attn_output) return attn_output, None, past_key_value class PersimmonModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() patch_update_causal_mask(self._model, "4.42.0") for layer in self._model.model.layers: if is_torch_version(">=", "2.1.0"): orig_self_attn_fwd = layer.self_attn.forward layer.self_attn.forward = types.MethodType(_persimmon_self_attn_sdpa_forward, layer.self_attn) layer.self_attn._orig_forward = orig_self_attn_fwd if is_transformers_version("<", "4.44.99"): _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) unpatch_update_causal_mask(self._model) for layer in self._model.model.layers: if hasattr(layer.self_attn, "_orig_forward"): layer.self_attn.forward = layer.self_attn._orig_forward def _jais_attn_forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], layer_past: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, position_bias: Optional[torch.FloatTensor] = None, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. " "Please make sure to instantiate class with `JAISAttention(..., is_cross_attention=True)`." 
) query = self.q_attn(hidden_states) key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) attention_mask = encoder_attention_mask else: query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) query = self._split_heads(query, self.num_heads, self.head_dim) key = self._split_heads(key, self.num_heads, self.head_dim) value = self._split_heads(value, self.num_heads, self.head_dim) if layer_past is not None: past_key, past_value = layer_past key = torch.cat((past_key, key), dim=-2) value = torch.cat((past_value, value), dim=-2) if use_cache is True: present = (key, value) else: present = None if self.reorder_and_upcast_attn: attn_output, attn_weights = self._upcast_and_reordered_attn( query, key, value, attention_mask, head_mask, position_bias ) else: # Difference with original: override attn realization with sdpa if not output_attentions if not output_attentions: attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask, position_bias) else: attn_output, attn_weights = self._orig_attn(query, key, value, attention_mask, head_mask, position_bias) attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present) if output_attentions: outputs += (attn_weights,) return outputs def _jais_attn(self, query, key, value, attention_mask=None, head_mask=None, position_bias=None): scale = 1.0 if self.scale_attn_weights: scale = 1 / self.head_dim**self.attn_scale_power # Layer-wise attention scaling if self.scale_attn_by_inverse_layer_idx: scale = scale / float(self.layer_idx + 1) query_length = query.size(-2) attention_mask_sdpa = torch.ones( (query.shape[0], query.shape[1], query.shape[2], key.shape[2]), dtype=query.dtype, ) if not self.is_cross_attention: # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] mask_value = torch.finfo(torch.float16).min attention_mask_sdpa.masked_fill_(~causal_mask, mask_value) if attention_mask is not None: # Apply the attention mask attention_mask_sdpa = attention_mask_sdpa + attention_mask if position_bias is not None: attention_mask_sdpa += position_bias.type_as(attention_mask_sdpa).unsqueeze(0) # Mask heads if we want to if head_mask is not None: attention_mask_sdpa = attention_mask_sdpa * head_mask attn_output = F.scaled_dot_product_attention( query, key, value, attention_mask_sdpa, dropout_p=self.attn_dropout.p, scale=scale ) return attn_output, None class JaisModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() for layer in self._model.transformer.h: if is_torch_version(">=", "2.1.0"): orig_self_attn_fwd = layer.attn._attn layer.attn._attn = types.MethodType(_jais_attn, layer.attn) layer.attn._orig_attn = orig_self_attn_fwd layer.attn._orig_forward = layer.attn.forward layer.attn.forward = types.MethodType(_jais_attn_forward, layer.attn) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) for layer in self._model.transformer.h: if hasattr(layer.attn, "_orig_attn"): layer.attn._attn = layer.attn._orig_attn layer.attn.forward = layer.attn._orig_forward class UpdateCausalMaskModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() patch_update_causal_mask(self._model, "4.42.0") if ( hasattr(self._model, "model") and hasattr(self._model.model, 
"layers") and hasattr(self._model.model.layers[0].self_attn, "rotary_emb") and hasattr(self._model.model.layers[0].self_attn.rotary_emb, "_set_cos_sin_cache") ): for layer in self._model.model.layers: _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) unpatch_update_causal_mask(self._model) class RotaryEmbPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() if is_transformers_version("<", "4.44.99"): for layer in self._model.model.layers: _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb) # Adapted from https://github.com/huggingface/transformers/blob/31f9a289a6207be6cae746e009d8e0db523be203/src/transformers/models/falcon/modeling_falcon.py#L1138 def _falcon_prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, ): if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask else: # different from original: allow to provide min_dtype as parameter min_dtype = torch.finfo(dtype).min if "min_dtype" not in kwargs else kwargs["min_dtype"] causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) return causal_mask def _falcon_update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: "Cache", output_attentions: bool, head_mask: torch.Tensor, alibi: torch.Tensor, ): # copied from https://github.com/huggingface/transformers/blob/a30c865f991dfec9452cc64bd9a97bfbb96be036/src/transformers/models/falcon/modeling_falcon.py#L1130 from transformers.cache_utils import StaticCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if ( self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions and head_mask is None and alibi is None ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device # difference from original, replace torch.finfo(dtype).min to float16 for prevent overflow for fp16/bf16 execution min_dtype = torch.finfo(torch.float16).min batch_size, sequence_length, _ = input_tensor.shape if using_static_cache: target_length = past_key_values.get_max_length() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = _falcon_prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) # We take care to integrate alibi bias in the causal_mask here if head_mask is None and alibi is not None: alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) causal_mask = torch.masked_fill( alibi / math.sqrt(self.config.hidden_size // self.num_heads), causal_mask < -1, min_dtype, ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
# Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask class FalconModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() if is_transformers_version("<", "4.44.99"): for layer in self._model.transformer.h: _reinitialize_cos_sin_cached_fp32(layer.self_attention.rotary_emb) else: patch_update_causal_mask(self._model, "4.45.0", "transformer", _falcon_update_causal_mask) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) unpatch_update_causal_mask(self._model, "transformer") class GptNeoxModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() if is_transformers_version("<", "4.44.99"): for layer in self._model.gpt_neox.layers: _reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb) else: patch_update_causal_mask(self._model, "4.45.0", "gpt_neox") def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) unpatch_update_causal_mask(self._model, "gpt_neox") # Adopted from https://github.com/huggingface/optimum/blob/v1.24.0/optimum/bettertransformer/models/attention.py#L96 def _gptj_attn(self, query, key, value, attention_mask=None, head_mask=None): if head_mask is not None: return self._orig_attn(query, key, value, attention_mask, head_mask) batch_size = query.shape[0] mask_value = torch.finfo(value.dtype).min mask_value = torch.full([], mask_value, dtype=value.dtype) # in gpt-neo-x and gpt-j the query and keys are always in fp32 # thus we need to cast them to the value dtype if getattr(self, "downcast_qk", False): query = query.to(value.dtype) key = key.to(value.dtype) if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, -1, -1] < -1: return self._orig_attn(query, key, value, attention_mask, head_mask) dropout_p = self.dropout_prob_attn if self.training else 0.0 if batch_size == 1 or self.training: if query.shape[2] > 1: sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True ) else: sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False ) else: query_length, key_length = query.size(-2), key.size(-2) # causal_mask is always [True, ..., True] otherwise, so executing this # is unnecessary if query_length > 1: if not is_transformers_version(">=", "4.44.99"): causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) causal_mask = torch.where(causal_mask, 0, mask_value) # torch.Tensor.expand does no memory copy causal_mask = causal_mask.expand(batch_size, -1, -1, -1) if attention_mask is not None: attention_mask = causal_mask + attention_mask else: attention_mask = attention_mask[:, :, :, : key.shape[-2]] sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False ) # in gpt-neo-x and gpt-j the query and keys are always in fp32 # thus we need to cast them to the value dtype if getattr(self, "downcast_qk", False): sdpa_result = sdpa_result.to(value.dtype) return sdpa_result, None def gptj_attn_forward( self, hidden_states: torch.FloatTensor, layer_past: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, 
use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, ): if output_attentions: self._attn = self._orig_attn return self._orig_forward( hidden_states, layer_past, attention_mask, position_ids, head_mask, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, ) class GptJModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() patch_update_causal_mask(self._model, "4.45.0", "transformer") if is_transformers_version(">=", "4.49"): self._model.config._orig_attn_implementation = self._model.config._attn_implementation self._model.config._attn_implementation = "sdpa" for block in self._model.transformer.h: block.attn._orig_forward = block.attn.forward block.attn.forward = types.MethodType(gptj_attn_forward, block.attn) block.attn._orig_attn = block.attn._attn block.attn._attn = types.MethodType(_gptj_attn, block.attn) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) unpatch_update_causal_mask(self._model, "transformer") if is_transformers_version(">=", "4.49"): self._model.config._attn_implementation = self._model.config._orig_attn_implementation for block in self._model.transformer.h: block.attn.forward = block.attn._orig_forward block.attn._attn = block.attn._orig_attn class GptNeoxJapaneseModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() if is_transformers_version("<", "4.44.99"): for layer in self._model.gpt_neox_japanese.layers: _reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb) else: patch_update_causal_mask(self._model, "4.45.0", "gpt_neox_japanese") def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) unpatch_update_causal_mask(self._model, "gpt_neox_japanese") # Adopted from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py#L721 def _bloom_attn_forward( self, hidden_states: torch.Tensor, residual: torch.Tensor, alibi: torch.Tensor, attention_mask: torch.Tensor, layer_past=None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, ): from transformers.models.bloom.modeling_bloom import dropout_add if head_mask is not None or output_attentions: return self._orig_forward( hidden_states, residual, alibi, attention_mask, layer_past=layer_past, head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, ) batch_size, q_length, _ = hidden_states.shape # [batch_size, seq_length, 3 x hidden_size] fused_qkv = self.query_key_value(hidden_states) # 3 x [batch_size, num_heads, seq_length, head_dim] query_layer, key_layer, value_layer = self._reshape(fused_qkv) if layer_past is not None: cache_kwargs = {"cache_position": cache_position} key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) if attention_mask is not None: # no matter the length, we just slice it kv_length = cache_position[-1] + 1 # cache position is 0-indexed while length should start from 1 causal_mask = attention_mask[:, :, :, :kv_length] alibi = torch.masked_fill(alibi, causal_mask.bool(), torch.finfo(alibi.dtype).min) context_layer = torch.nn.functional.scaled_dot_product_attention( query_layer, key_layer, value_layer, attn_mask=alibi, dropout_p=self.dropout_prob_attn if self.training 
else 0.0, ) # Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim] context_layer = context_layer.transpose(1, 2) context_layer = context_layer.reshape(batch_size, q_length, self.hidden_size) # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 if self.pretraining_tp > 1 and self.slow_but_exact: slices = self.hidden_size / self.pretraining_tp output_tensor = torch.zeros_like(context_layer) for i in range(self.pretraining_tp): output_tensor = output_tensor + F.linear( context_layer[:, :, int(i * slices) : int((i + 1) * slices)], self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], ) else: output_tensor = self.dense(context_layer) output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) outputs = (output_tensor, layer_past) return outputs class BloomModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() if is_transformers_version(">=", "4.49.0"): self._model.config._orig_attn_implementation = self._model.config._attn_implementation self._model.config._attn_implementation = "sdpa" for block in self._model.transformer.h: block.self_attention._orig_forward = block.self_attention.forward block.self_attention.forward = types.MethodType(_bloom_attn_forward, block.self_attention) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if is_transformers_version(">=", "4.49.0"): self._model.config._attn_implementation = self._model.config._orig_attn_implementation for block in self._model.transformer.h: block.self_attention.forward = block.self_attention._orig_forward def _gpt_neo_attn_forward( self, hidden_states, attention_mask=None, layer_past=None, head_mask=None, use_cache=False, output_attentions=False, cache_position=None, ): if output_attentions: self._attn = self._orig_attn return self._orig_forward( hidden_states, attention_mask=attention_mask, layer_past=layer_past, head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, ) # Adapted from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py#L185 def _gpt_neo_attn_sdpa( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, ): batch_size = query.shape[0] mask_value = torch.finfo(torch.float16).min mask_value = torch.full([], mask_value, dtype=value.dtype) dropout_p = float(self.config.attention_dropout) if self.training else 0.0 if (batch_size == 1 or self.training) and self.attention_type == "global": if query.shape[2] > 1: sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True ) else: sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=1.0 ) else: query_length, key_length = query.size(-2), key.size(-2) causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] causal_mask = torch.where(causal_mask, 0, mask_value) if batch_size > 1: # torch.Tensor.expand does no memory copy causal_mask = causal_mask.expand(batch_size, -1, -1, -1) if attention_mask is not None: attention_mask = causal_mask + attention_mask sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False, scale=1.0 ) return sdpa_result, None
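# Illustrative sketch (not part of the exporter): _gpt_neo_attn_sdpa above turns the model's
# boolean causal `bias` buffer into an additive float mask (0 where attention is allowed, a
# large negative value elsewhere) and feeds it to scaled_dot_product_attention instead of the
# eager soft-masking path. The standalone helper below reproduces that construction on toy
# shapes; the function name and shapes are hypothetical.
def _example_additive_causal_mask_sdpa(batch=2, heads=4, q_len=3, kv_len=5, head_dim=8):
    query = torch.randn(batch, heads, q_len, head_dim)
    key = torch.randn(batch, heads, kv_len, head_dim)
    value = torch.randn(batch, heads, kv_len, head_dim)

    # boolean lower-triangular buffer playing the role of the model's registered `bias`
    bias = torch.tril(torch.ones(kv_len, kv_len, dtype=torch.bool))[None, None]
    causal_mask = bias[:, :, kv_len - q_len : kv_len, :kv_len]

    # bool mask -> additive float mask, using the float16 minimum to stay overflow-safe
    mask_value = torch.full([], torch.finfo(torch.float16).min, dtype=query.dtype)
    attn_mask = torch.where(causal_mask, 0, mask_value).expand(batch, -1, -1, -1)

    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
    )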
            scale=1.0,
        )

    return sdpa_result, None


class GptNeoModelPatcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()
        if is_transformers_version(">=", "4.45.0") and is_torch_version(">=", "2.1.0"):
            self._model.config._orig_attn_implementation = self._model.config._attn_implementation
            self._model.config._attn_implementation = "sdpa"
            for layer in self._model.transformer.h:
                self_attn = layer.attn.attention
                self_attn._orig_attn = self_attn._attn
                self_attn._attn = types.MethodType(_gpt_neo_attn_sdpa, self_attn)
                self_attn._orig_forward = self_attn.forward
                self_attn.forward = types.MethodType(_gpt_neo_attn_forward, self_attn)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        if hasattr(self._model.config, "_orig_attn_implementation"):
            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
            for layer in self._model.transformer.h:
                layer.attn.attention.forward = layer.attn.attention._orig_forward
                layer.attn.attention._attn = layer.attn.attention._orig_attn


class Gemma2ModelPatcher(LlamaModelPatcher):
    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(config, model, model_kwargs)

        @functools.wraps(self.orig_forward)
        def patched_forward(*args, **kwargs):
            from transformers.cache_utils import DynamicCache

            signature = inspect.signature(self.orig_forward)
            args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

            return_legacy_cache = False
            pkv_in_args = False
            legacy_pkv = None
            if "past_key_values" in kwargs:
                legacy_pkv = kwargs.pop("past_key_values", None)
            sign_names = list(signature.parameters.keys())
            pkv_argument_index = sign_names.index("past_key_values")
            cache_position_index = sign_names.index("cache_position") if "cache_position" in sign_names else -1
            input_ids_index = sign_names.index("input_ids" if "input_ids" in sign_names else "inputs_embeds")
            if legacy_pkv is None and len(args) > pkv_argument_index:
                legacy_pkv = args[pkv_argument_index]
                pkv_in_args = True
            if legacy_pkv is not None:
                pkv = DynamicCache.from_legacy_cache(legacy_pkv)
                return_legacy_cache = True
                if not pkv_in_args:
                    kwargs["past_key_values"] = pkv
                else:
                    args[pkv_argument_index] = pkv

            if (
                return_legacy_cache
                and cache_position_index != -1
                and (cache_position_index > len(args) and "cache_position" not in kwargs)
            ):
                past_seen_tokens = legacy_pkv[0][0].shape[-2]
                input_ids = args[input_ids_index] if "input_ids" not in kwargs else kwargs["input_ids"]
                cache_position = torch.arange(
                    past_seen_tokens, past_seen_tokens + input_ids.shape[1], device=input_ids.device
                )
                kwargs["cache_position"] = cache_position

            outputs = self.orig_forward(*args, **kwargs)
            if return_legacy_cache:
                outputs.past_key_values = outputs.past_key_values.to_legacy_cache()

            return outputs

        self.patched_forward = patched_forward


def _decilm_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # decilm contains a bug in the attention calculation for the case when past key values are not None
    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def
apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`): The position indices of the tokens corresponding to the query and key tensors. For example, this can be used to pass offsetted position ids when working with a KV-cache. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos[position_ids].unsqueeze(unsqueeze_dim) sin = sin[position_ids].unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) bsz, q_len, _ = hidden_states.size() if self.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] query_states = torch.cat(query_states, dim=-1) key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] key_states = torch.cat(key_states, dim=-1) value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] value_states = torch.cat(value_states, dim=-1) else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, 
position_ids) if past_key_value is not None: # reuse k, v, self_attention key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_output = F.scaled_dot_product_attention( query_states, key_states, value_states, is_causal=attention_mask is None, attn_mask=attention_mask ) # modified, in original implementation .transpose(1, 2) missed attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size) if self.pretraining_tp > 1: attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) else: attn_output = self.o_proj(attn_output) attn_weights = None return attn_output, attn_weights, past_key_value class DeciLMModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() for layer in self._model.model.layers: layer.self_attn._orig_forward = layer.self_attn.forward layer.self_attn.forward = types.MethodType(_decilm_attn_forward, layer.self_attn) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) for layer in self._model.model.layers: layer.self_attn.forward = layer.self_attn._orig_forward class IBertModelPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any], ): super().__init__(config, model, model_kwargs) if getattr(self._model, "ibert"): embeddings = self._model.ibert.embeddings else: embeddings = self._model.embeddings # model has first inference buffers initialization, it may breaks tracing if getattr(embeddings.LayerNorm, "dim_sqrt") is None: self._model(torch.ones([1, 1], dtype=torch.long)) class InternVLChatImageEmbeddingModelPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any], ): model.__orig_forward = model.forward model.forward = model.extract_feature if model.vision_model.encoder.layers[0].attn.use_flash_attn: for layer in model.vision_model.encoder.layers: layer.attn._orig_use_flash_attn = layer.attn.use_flash_attn layer.attn.use_flash_attn = False super().__init__(config, model, model_kwargs) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward if hasattr(self._model.vision_model.encoder.layers[0].attn, "_orig_use_flash_attn"): for layer in self._model.vision_model.encoder.layers: layer.attn.use_flash_attn = layer.attn._orig_use_flash_attn class InternVL2ChatLangModelPatcher(DecoderModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any] ): model_type = model.config.model_type patcher_for_model_type = { "llama": LlamaModelPatcher, "qwen2": UpdateCausalMaskModelPatcher, "phi3": Phi3ModelPatcher, "internlm2": InternLM2Patcher, } self._internal_patcher = None self._patched_forward = None internal_patcher_cls = patcher_for_model_type.get(model_type) if internal_patcher_cls is not None: self._internal_patcher = 
internal_patcher_cls(config, model, model_kwargs) self._patched_forward = self._internal_patcher.patched_forward super().__init__(config, model, model_kwargs) @property def patched_forward(self): if self._internal_patcher is not None: return self._internal_patcher.patched_forward return self._patched_forward @patched_forward.setter def patched_forward(self, fn): self._patched_forward = fn if self._internal_patcher is not None: self._internal_patcher.patched_forward = fn def __enter__(self): if is_torch_version(">=", "2.1.0"): if ( self._model.config.model_type in ["qwen2", "llama"] and self._model.config._attn_implementation != "sdpa" ): self._model.config._orig_attn_implementation = self._model.config._attn_implementation self._model.config._attn_implementation = "sdpa" if self._model.config.model_type == "qwen2" and is_transformers_version("<", "4.48"): from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"] for layer in self._model.model.layers: layer.self_attn._orig_forward = layer.self_attn.forward layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn) if self._model.config.model_type == "llama" and is_transformers_version("<", "4.47"): from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES sdpa_attn = LLAMA_ATTENTION_CLASSES["sdpa"] for layer in self._model.model.layers: layer.self_attn._orig_forward = layer.self_attn.forward layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn) if self._internal_patcher is not None: return self._internal_patcher.__enter__() return super().__enter__() def __exit__(self, exc_type, exc_value, traceback): if self._internal_patcher: self._internal_patcher.__exit__(exc_type, exc_value, traceback) else: super().__exit__(exc_type, exc_value, traceback) if hasattr(self._model.config, "_orig_attn_implementation"): self._model.config._attn_implementation = self._model.config._orig_attn_implementation for layer in self._model.model.layers: if hasattr(layer.self_attn, "_orig_forward"): layer.self_attn.forward = layer.self_attn._orig_forward def llava_vision_embed_forward(self, pixel_values): # copied from https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llava/modeling_llava.py#L428-L441 # these changes does not bring any difference from original, it only packs model subcomponent inference together # that allow us avoid memory overheads and their inference results handling on code-level image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. 
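    # the feature is taken from an intermediate encoder layer (config.vision_feature_layer) below,
    # which is the reason all hidden states have to be requested here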
selected_image_feature = image_outputs.hidden_states[self.config.vision_feature_layer] if self.config.vision_feature_select_strategy == "default": selected_image_feature = selected_image_feature[:, 1:] elif self.config.vision_feature_select_strategy == "full": selected_image_feature = selected_image_feature else: raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") image_features = self.multi_modal_projector(selected_image_feature) return image_features def llava_next_video_vision_embed_forward(self, pixel_values): # copied from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/llava_next_video/modeling_llava_next_video.py#L519 # these changes does not bring any difference from original, it only packs model subcomponent inference together # that allow us avoid memory overheads and their inference results handling on code-level image_features = self.vision_tower(pixel_values, output_hidden_states=True) vision_feature_layer = self.config.vision_feature_layer if isinstance(vision_feature_layer, int): selected_image_feature = image_features.hidden_states[vision_feature_layer] else: hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer] selected_image_feature = torch.cat(hs_pool, dim=-1) if self.config.vision_feature_select_strategy == "default": selected_image_feature = selected_image_feature[:, 1:] elif self.config.vision_feature_select_strategy == "full": selected_image_feature = selected_image_feature else: raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") return selected_image_feature # Modified from https://huggingface.co/microsoft/maira-2/blob/main/modeling_maira2.py#L68 def maira_vision_embed_forward(self, pixel_values): vision_feature_select_strategy = self.config.vision_feature_select_strategy vision_feature_layer = self.config.vision_feature_layer return self.get_image_features(pixel_values, vision_feature_layer, vision_feature_select_strategy) class LlavaImageEmbeddingModelPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any], ): model.__orig_forward = model.forward model.forward = types.MethodType(llava_vision_embed_forward, model) super().__init__(config, model, model_kwargs) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward class MairaImageEmbeddingModelPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any], ): model.__orig_forward = model.forward model.forward = types.MethodType(maira_vision_embed_forward, model) super().__init__(config, model, model_kwargs) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward class LlavaNextVideoImageEmbeddingModelPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any], ): model.__orig_forward = model.forward model.forward = types.MethodType(llava_next_video_vision_embed_forward, model) super().__init__(config, model, model_kwargs) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward def _embednb_forward(self, ids: 
torch.Tensor) -> torch.Tensor: def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0, "The dimension must be even." scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim omega = 1.0 / (theta**scale) batch_size, seq_length = pos.shape out = pos.unsqueeze(-1) * omega.unsqueeze(0).unsqueeze(0) cos_out = torch.cos(out) sin_out = torch.sin(out) stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) return out.float() n_axes = ids.shape[-1] emb = torch.cat( [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3, ) return emb.unsqueeze(1) class FluxTransfromerModelPatcher(ModelPatcher): def __enter__(self): super().__enter__() if is_diffusers_version("<", "0.31.0"): self._model.pos_embed._orig_forward = self._model.pos_embed.forward self._model.pos_embed.forward = types.MethodType(_embednb_forward, self._model.pos_embed) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if hasattr(self._model.pos_embed, "_orig_forward"): self._model.pos_embed.forward = self._model.pos_embed._orig_forward def _minicpmv_resampler_forward(self, image_feature, pos_embed, key_padding_mask): bs = image_feature.shape[0] image_feature = self.kv_proj(image_feature) # B * L * D image_feature = self.ln_kv(image_feature).permute(1, 0, 2) # L * B * D q = self.ln_q(self.query) # Q * D q_bs = q.unsqueeze(1).repeat(1, bs, 1) out = self.attn(q_bs, image_feature + pos_embed, image_feature, key_padding_mask=key_padding_mask)[ 0 ] # Q * B * D # L * B * D + L * B * D # out: Q * B * D x = out.permute(1, 0, 2) # B * Q * D x = self.ln_post(x) x = x @ self.proj return x def _minicpmv_siglip_vis_embed_forward( self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor, tgt_sizes: Optional[torch.IntTensor] = None, position_ids: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: patch_embeds = self.patch_embedding(pixel_values) embeddings = patch_embeds.flatten(2).transpose(1, 2) if position_ids is None: batch_size = pixel_values.size(0) max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) position_ids = torch.full( size=( batch_size, max_nb_patches_h * max_nb_patches_w, ), fill_value=0, ) for batch_idx, p_attn_mask in enumerate(patch_attention_mask): if tgt_sizes is not None: nb_patches_h = tgt_sizes[batch_idx][0] nb_patches_w = tgt_sizes[batch_idx][1] else: nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids position_ids = position_ids.to(self.position_embedding.weight.device) embeddings = embeddings + self.position_embedding(position_ids) return embeddings def _minicpmv_siglip_attn_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor, 
Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" batch_size, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, is_causal=attention_mask is None ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, None def _minicpmv_siglip_transformer_forward( self, pixel_values, patch_attention_mask: Optional[torch.BoolTensor] = None, tgt_sizes: Optional[torch.IntTensor] = None, position_ids: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size = pixel_values.size(0) if patch_attention_mask is None: patch_attention_mask = torch.ones( size=( batch_size, pixel_values.size(2) // self.config.patch_size, pixel_values.size(3) // self.config.patch_size, ), dtype=torch.bool, device=pixel_values.device, ) hidden_states = self.embeddings( pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes, position_ids=position_ids, ) patch_attention_mask = patch_attention_mask.view(batch_size, -1) attention_mask = ( _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) if not self._use_flash_attention_2 else patch_attention_mask ) encoder_outputs = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] last_hidden_state = self.post_layernorm(last_hidden_state) if not return_dict: return (last_hidden_state, None) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=None, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) class MiniCPMVResamplerModelPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any], ): model.__orig_forward = model.forward model.forward = types.MethodType(_minicpmv_resampler_forward, model) super().__init__(config, model, model_kwargs) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward class MiniCPMVImageEmbeddingsModelPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any], ): 
model.__orig_forward = model.forward model.forward = types.MethodType(_minicpmv_siglip_transformer_forward, model) super().__init__(config, model, model_kwargs) def __enter__(self): super().__enter__() self._model.embeddings._orig_forward = self._model.embeddings.forward self._model.embeddings.forward = types.MethodType(_minicpmv_siglip_vis_embed_forward, self._model.embeddings) if is_torch_version(">=", "2.0.0"): for layer in self._model.encoder.layers: layer.self_attn._orig_forward = layer.self_attn.forward layer.self_attn.forward = types.MethodType(_minicpmv_siglip_attn_forward, layer.self_attn) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward self._model.embeddings.forward = self._model.embeddings._orig_forward if is_torch_version(">=", "2.0.0"): for layer in self._model.encoder.layers: layer.self_attn.forward = layer.self_attn._orig_forward class LlavaQwen2ImageEmbeddingsModelPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any], ): model.__orig_forward = model.forward model.forward = model.encode_images super().__init__(config, model, model_kwargs) if not self._model.get_vision_tower().is_loaded: self._model.get_vision_tower().load_model() def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward class InputEmbeddingPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any], ): model.__orig_forward = model.forward def forward(self, input): return self.__orig_forward(input) model.forward = types.MethodType(forward, model) super().__init__(config, model, model_kwargs) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward def phi3_vision_embeddings_forward(self, pixel_values: torch.FloatTensor): return self.get_img_features(pixel_values) class Phi3VisionImageEmbeddingsPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any], ): model.__orig_forward = model.forward model.forward = types.MethodType(phi3_vision_embeddings_forward, model) super().__init__(config, model, model_kwargs) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward def minicpm3_attn_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value=None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. 
position_ids (`torch.Tensor`): The position indices of the tokens corresponding to the query and key tensors. For example, this can be used to pass offsetted position ids when working with a KV-cache. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ orig_dtype = k.dtype cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] q_fp32 = q.to(dtype=torch.float32, device=q.device) k_fp32 = k.to(dtype=torch.float32, device=k.device) q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin) k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin) return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype) if output_attentions: return self._orig_forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) bsz, q_len, _ = hidden_states.shape q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(hidden_states.shape[0], hidden_states.shape[1], self.num_heads, self.q_head_dim).transpose(1, 2) q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) k_pe = k_pe.view(hidden_states.shape[0], hidden_states.shape[1], 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) .view(hidden_states.shape[0], hidden_states.shape[1], self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .transpose(1, 2) ) k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) kv_seq_len = value_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) # Difference with original code, k_pe.new_empty create constant tensor in torchscript query_states = torch.concat([q_nope, q_pe], dim=-1) # query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1) # key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and attention_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.attention_dropout if self.training else 0.0, # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
        is_causal=self.is_causal and attention_mask is None and q_len > 1,
    )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(hidden_states.shape[0], hidden_states.shape[1], self.hidden_size)
    attn_output = self.o_proj(attn_output)
    return attn_output, None, past_key_value


class MiniCPM3Patcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()
        for block in self._model.model.layers:
            block.self_attn._orig_forward = block.self_attn.forward
            block.self_attn.forward = types.MethodType(minicpm3_attn_forward, block.self_attn)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        for block in self._model.model.layers:
            block.self_attn.forward = block.self_attn._orig_forward


class DeepseekPatcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()
        self_attn = {
            "deepseek_v3": deepseek_v3_attn_forward,
            "deepseek_v2": deepseek_v2_attn_forward,
            "deepseek": minicpm3_attn_forward,
        }

        self_attn_fwd = self_attn.get(self._model.config.model_type)
        for block in self._model.model.layers:
            if self_attn_fwd is not None:
                block.self_attn._orig_forward = block.self_attn.forward
                block.self_attn.forward = types.MethodType(self_attn_fwd, block.self_attn)
            if hasattr(block.mlp, "moe_infer"):
                block.mlp._orig_moe_infer = block.mlp.moe_infer
                block.mlp.moe_infer = types.MethodType(deepseek_moe_infer, block.mlp)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        for block in self._model.model.layers:
            block.self_attn.forward = block.self_attn._orig_forward
            if hasattr(block.mlp, "_orig_moe_infer"):
                block.mlp.moe_infer = block.mlp._orig_moe_infer


def deepseek_v3_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value=None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # modified from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L751
    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
        orig_dtype = k.dtype
        cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
        sin = sin[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
        q_fp32 = q.to(dtype=torch.float32, device=q.device)
        k_fp32 = k.to(dtype=torch.float32, device=k.device)
        q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
        k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
        return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)

    if output_attentions:
        return self._orig_forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

    bsz, q_len, _ = hidden_states.size()

    if self.q_lora_rank is None:
        q = self.q_proj(hidden_states)
    else:
        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
    compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    k_pe = k_pe.view(bsz, q_len, 1,
self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .transpose(1, 2) ) k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) kv_seq_len = value_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) # Difference with original code, k_pe.new_empty create constant tensor in torchscript query_states = torch.concat([q_nope, q_pe], dim=-1) # query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1) # key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and attention_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.attention_dropout if self.training else 0.0, # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
is_causal=self.is_causal and attention_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value def deepseek_v2_attn_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value=None, output_attentions: bool = False, use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # modified from https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py#L806 def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): cos = cos[position_ids].unsqueeze(unsqueeze_dim) sin = sin[position_ids].unsqueeze(unsqueeze_dim) b, h, s, d = q.shape q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) b, h, s, d = k.shape k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed if output_attentions: return self._orig_forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) bsz, q_len, _ = hidden_states.shape if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .transpose(1, 2) ) k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) kv_seq_len = value_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

    # Difference with original code, k_pe.new_empty creates a constant tensor in torchscript
    query_states = torch.concat([q_nope, q_pe], dim=-1)
    # query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
    # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
    # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
    key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
    # key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
    # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
    # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )

    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
    # Reference: https://github.com/pytorch/pytorch/issues/112577.
    if query_states.device.type == "cuda" and attention_mask is not None:
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=attention_mask,
        dropout_p=self.attention_dropout if self.training else 0.0,
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
        # create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value def deepseek_moe_infer(self, x, topk_ids, topk_weight): cnts = torch.zeros((topk_ids.shape[0], len(self.experts))) cnts.scatter_(1, topk_ids, 1) tokens_per_expert = cnts.sum(dim=0).to(torch.long) idxs = torch.argsort(topk_ids.view(-1)) sorted_tokens = x[idxs // topk_ids.shape[1]] outputs = [] start_idx = torch.tensor(0, dtype=torch.long) for i, num_tokens in enumerate(tokens_per_expert): end_idx = start_idx + num_tokens # difference with original: removed skiping expert if empty num_tokens expert_id = i + self.ep_rank * self.experts_per_rank expert = self.experts[expert_id] tokens_for_this_expert = sorted_tokens[start_idx:end_idx] expert_out = expert(tokens_for_this_expert) outputs.append(expert_out) start_idx = end_idx # difference with original: removed usage torch.new_empty if outputs empty outs = torch.cat(outputs, dim=0) new_x = torch.zeros_like(outs) new_x[idxs] = outs final_out = ( new_x.view(*topk_ids.shape, -1) .to(topk_weight.dtype) .mul_(topk_weight.unsqueeze(dim=-1)) .sum(dim=1) .to(new_x.dtype) ) return final_out class Qwen2VLLanguageModelPatcher(DecoderModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any] = None, ): model.__orig_forward = model.forward def forward_wrap( self, attention_mask, position_ids=None, past_key_values=None, inputs_embeds=None, input_ids=None, use_cache=True, ): from transformers.cache_utils import DynamicCache new_past_key_values = DynamicCache.from_legacy_cache(past_key_values) result = self.__orig_forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=new_past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, ) if past_key_values is not None: result["past_key_values"] = result["past_key_values"].to_legacy_cache() return result model.forward = types.MethodType(forward_wrap, model) super().__init__(config, model, model_kwargs) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward def patch_qwen2vl_vision_blocks(model, force_new_behaviour=False): if not force_new_behaviour and is_transformers_version("<=", "4.48.99"): # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L390 # added attention_mask input instead of internal calculation (unsupported by tracing due to cycle with dynamic len) def sdpa_attn_forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_rotary_pos_emb_vision seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) if is_transformers_version(">=", "4.49"): if position_embeddings is None: emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) cos = emb.cos().float() sin = emb.sin().float() else: cos, sin = position_embeddings q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) else: q = apply_rotary_pos_emb_vision(q.unsqueeze(0), 
rotary_pos_emb).squeeze(0) k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) attn_output = attn_output.transpose(0, 1) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L430 # added attention_mask input propagation to self.attn def block_forward(self, hidden_states, attention_mask, rotary_pos_emb) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states else: # Modified from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L391 # added attention_mask input instead of internal calculation (unsupported by tracing due to cycle with dynamic len) def sdpa_attn_forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb_vision( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: orig_q_dtype = q.dtype orig_k_dtype = k.dtype q, k = q.float(), k.float() cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) q_embed = q_embed.to(orig_q_dtype) k_embed = k_embed.to(orig_k_dtype) return q_embed, k_embed seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) if position_embeddings is None: emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) cos = emb.cos().float() sin = emb.sin().float() else: cos, sin = position_embeddings q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) attn_output = attn_output.transpose(0, 1) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output # Modified from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L446 # added attention_mask input propagation to self.attn def block_forward( self, hidden_states, attention_mask, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states for block in model.blocks: block._orig_forward = block.forward block.forward = types.MethodType(block_forward, block) block.attn._orig_forward = block.attn.forward block.attn.forward = types.MethodType(sdpa_attn_forward, block.attn) 
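# Note: the patched vision blocks above expect the block-diagonal attention mask to be precomputed
# outside the traced graph (building it from cu_seqlens inside the model relies on a loop with a
# dynamic length, which tracing cannot handle). A minimal sketch of how such a mask could be built
# on the caller side is given below; it is illustrative only, is not used by the patchers, and
# assumes `cu_seqlens` in the cumulative form produced by the Qwen2-VL vision tower.
def _example_build_vision_attention_mask(cu_seqlens, seq_length, dtype=torch.float32):
    # start from a fully masked additive mask and unmask each image/video segment individually
    attention_mask = torch.full((1, seq_length, seq_length), torch.finfo(dtype).min, dtype=dtype)
    for i in range(1, len(cu_seqlens)):
        start, end = cu_seqlens[i - 1], cu_seqlens[i]
        attention_mask[..., start:end, start:end] = 0.0
    return attention_mask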
class Qwen2VLVisionEmbMergerPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any] = None, ): model.__orig_forward = model.forward # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1118 # added attention_mask input instead cu_lens for its internal calculation model (unsupported by tracing due to cycle with dynamic len) # separated patch_embed and rot_pos_emb calls for performing as part of another model def image_embed_forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor ) -> torch.Tensor: for blk in self.blocks: hidden_states = blk(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb) return self.merger(hidden_states) model.forward = types.MethodType(image_embed_forward, model) super().__init__(config, model, model_kwargs) def __enter__(self): patch_qwen2vl_vision_blocks(self._model) super().__enter__() def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward for block in self._model.blocks: block.forward = block._orig_forward block.attn.forward = block.attn._orig_forward class Qwen2_5_VLVisionEmbMergerPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any] = None, ): super().__init__(config, model, model_kwargs) model.__orig_forward = model.forward # Modified from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L405 # added attention_mask and window_attention_mask inputs instead cu_lens and window_cu_lens processing for its internal calculation model # (unsupported by tracing due to cycle with dynamic len) # separated patch_embed and rot_pos_emb calls for performing as part of another model def image_embed_forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, window_attention_mask: torch.Tensor, window_index: torch.Tensor, rotary_pos_emb: torch.Tensor, ) -> torch.Tensor: seq_len = hidden_states.shape[0] hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: attention_mask_now = attention_mask else: attention_mask_now = window_attention_mask hidden_states = blk( hidden_states, attention_mask=attention_mask_now, position_embeddings=position_embeddings ) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) hidden_states = hidden_states[reverse_indices, :] return hidden_states model.forward = types.MethodType(image_embed_forward, model) super().__init__(config, model, model_kwargs) def __enter__(self): patch_qwen2vl_vision_blocks(self._model, force_new_behaviour=True) super().__enter__() def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = 
self._model.__orig_forward for block in self._model.blocks: block.forward = block._orig_forward block.attn.forward = block.attn._orig_forward # copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L321 def _granite_moe_topk_gating_forward(self, hidden_states): # compute the top_k routing decision logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts] top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k] top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k] # compute number of input given to each expert zeros = torch.zeros( [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device ) # [num_tokens, num_experts] gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts] expert_size = gates.long().sum(0) # [num_experts,] # difference with original, removed expert_size = expert_size.tolist() due to incorrect tracing # sort and group input tokens according to expert assignment top_k_experts = top_k_indices.flatten() # [num_tokens * top_k] _, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k] batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k] # gather the gate values for grouped input tokens top_k_gates = top_k_gates.flatten() # [num_tokens * top_k] batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k] return index_sorted_experts, batch_index, batch_gates, expert_size, logits # copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L281 def _granite_moe_parallel_experts_forward(self, inputs, expert_size): output_list = [] # difference with original # 1) expert_size is tensor instead of list of ints after gating patching, that does not allow use original inputs.split(expert_size) # 2) use index_start:next_index for obtaining expert inputs splits one by one instead of precomputed splits once before cycle index_start = torch.tensor(0, dtype=torch.int64) for i in range(self.num_experts): next_index = index_start + expert_size[i] output_list.append(F.linear(inputs[index_start:next_index], self.weight[i])) index_start = next_index results = torch.cat(output_list, dim=0) return results class GraniteMoEModelPatcher(LlamaModelPatcher): def __enter__(self): super().__enter__() for layer in self._model.model.layers: block_sparse_moe = layer.block_sparse_moe block_sparse_moe.router._orig_forward = block_sparse_moe.router.forward block_sparse_moe.router.forward = types.MethodType( _granite_moe_topk_gating_forward, block_sparse_moe.router ) block_sparse_moe.input_linear._orig_forward = block_sparse_moe.input_linear.forward block_sparse_moe.input_linear.forward = types.MethodType( _granite_moe_parallel_experts_forward, block_sparse_moe.input_linear ) block_sparse_moe.output_linear._orig_forward = block_sparse_moe.output_linear.forward block_sparse_moe.output_linear.forward = types.MethodType( _granite_moe_parallel_experts_forward, block_sparse_moe.output_linear ) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) for layer in self._model.model.layers: block_sparse_moe = layer.block_sparse_moe block_sparse_moe.router.forward = block_sparse_moe.router._orig_forward block_sparse_moe.input_linear.forward = block_sparse_moe.input_linear._orig_forward block_sparse_moe.output_linear.forward = 
block_sparse_moe.output_linear._orig_forward # copied from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L401 def gpt_bigcode_attn(self, query, key, value, attention_mask=None, head_mask=None): if head_mask is not None: # The super dispatch is done in the forward. raise ValueError("PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository.") scale = None if not self.scale_attn_weights: scale = 1 # MQA models: (batch_size, query_length, num_heads * head_dim) # MHA models: (batch_size, num_heads, query_length, head_dim) query_shape = query.shape batch_size = query_shape[0] key.shape[-2] if self.multi_query: query_length = query_shape[1] # SDPA requires the dimension [..., sequence_length, head_dim]. query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2) # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions. key = key.unsqueeze(1) value = value.unsqueeze(1) # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend # and flash attention backend (No available kernel. Aborting execution.) from the shapes # query = [batch_size, num_heads, query_length, head_dim] # key = [batch_size, 1, past_length, head_dim] # value = [batch_size, 1, past_length, head_dim] # # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check. if is_torch_version(">=", "2.2.0"): key = key.expand(-1, self.num_heads, -1, -1) value = value.expand(-1, self.num_heads, -1, -1) else: query_length = query_shape[-1] # See the comment above. if query.device.type == "cuda" and attention_mask is not None: query = query.contiguous() key = key.contiguous() value = value.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not # create a causal mask in case query_length == 1. 
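# For reference, a self-contained shape check of the multi-query expand above, using toy
# sizes (batch=2, num_heads=4, q_len=3, kv_len=5, head_dim=8); illustration only, not part
# of the patched code:
#
#     import torch
#     q = torch.randn(2, 4, 3, 8)            # [batch, num_heads, q_len, head_dim]
#     k = torch.randn(2, 5, 8).unsqueeze(1)  # [batch, 1, kv_len, head_dim]
#     v = torch.randn(2, 5, 8).unsqueeze(1)
#     k, v = k.expand(-1, 4, -1, -1), v.expand(-1, 4, -1, -1)  # views, no data copy
#     out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
#     assert out.shape == (2, 4, 3, 8)
#
# Concretely for the dispatch below: prefill with no explicit mask runs with is_causal=True,
# while single-token decode steps (query_length == 1) and any call with a mask use
# is_causal=False, since the mask (or the single query position) already encodes causality.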
is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False # different from original, due to loading model weights in original format transformer.wte dtype may be different from query dtype if attention_mask is not None: attention_mask = attention_mask.to(query.dtype) sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=self.attn_pdrop if self.training else 0.0, is_causal=is_causal, scale=scale, ) if self.multi_query: # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim) sdpa_result = sdpa_result.transpose(1, 2) # Reshape is kind of expensive here, as it does a memory copy, # but I did not manage to make away without it (logits do not match when using view) # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim) sdpa_result = sdpa_result.reshape(query_shape) return sdpa_result, None class GptBigCodeModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa": for layer in self._model.transformer.h: layer.attn._orig_attn = layer.attn._attn layer.attn._attn = types.MethodType(gpt_bigcode_attn, layer.attn) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa": for layer in self._model.transformer.h: layer.attn._attn = layer.attn._orig_attn class StatefulSeq2SeqDecoderPatcher(Seq2SeqModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): model.__orig_forward = model.forward @functools.wraps(model.__orig_forward) def patched_forward(*args, **kwargs): from transformers.cache_utils import EncoderDecoderCache signature = inspect.signature(self.orig_forward) args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs) return_legacy_cache = False pkv_in_args = False legacy_pkv = None if "past_key_values" in kwargs: legacy_pkv = kwargs.pop("past_key_values", None) sign_names = list(signature.parameters.keys()) pkv_argument_index = sign_names.index("past_key_values") if legacy_pkv is None and len(args) > pkv_argument_index: legacy_pkv = args[pkv_argument_index] pkv_in_args = True if legacy_pkv is not None: if isinstance(legacy_pkv, EncoderDecoderCache): legacy_pkv = legacy_pkv.to_legacy_cache() only_self_cache = [cache_item[:2] for cache_item in legacy_pkv] pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache) return_legacy_cache = True if not pkv_in_args: kwargs["past_key_values"] = pkv else: args[pkv_argument_index] = pkv outputs = model.__orig_forward(*args, **kwargs) if return_legacy_cache: outputs.past_key_values = outputs.past_key_values.to_legacy_cache() return outputs model.forward = patched_forward super().__init__(config, model, model_kwargs) class SanaTextEncoderModelPatcher(ModelPatcher): def __enter__(self): super().__enter__() patch_update_causal_mask(self._model, "4.39.0", None, patch_extrnal_model=True) if self._model.config._attn_implementation != "sdpa": self._model.config._orig_attn_implementation = self._model.config._attn_implementation self._model.config._attn_implementation = "sdpa" if is_transformers_version("<", "4.47.0"): from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"] for layer in 
self._model.layers: layer.self_attn._orig_forward = layer.self_attn.forward layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) unpatch_update_causal_mask(self._model, None, True) if hasattr(self._model.config, "_orig_attn_implementation"): self._model.config._attn_implementation = self._model.config._orig_attn_implementation for layer in self._model.layers: if hasattr(layer.self_attn, "_orig_forward"): layer.self_attn.forward = layer.self_attn._orig_forward class MiniCPMModelPatcher(DecoderModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): for layer in model.model.layers: if hasattr(layer, "scale_depth"): layer.self_attn.o_proj.to(torch.float32) layer.mlp.down_proj.to(torch.float32) super().__init__(config, model, model_kwargs) class CommonImageEmbeddingsModelPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any], ): model.__orig_forward = model.forward # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/got_ocr2/modeling_got_ocr2.py#L835 # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1321 if hasattr(model, "model") and hasattr(model.model, "get_image_features"): model.forward = model.model.get_image_features else: model.forward = model.get_image_features super().__init__(config, model, model_kwargs) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1147 def _gemma3_mm_update_causal_mask( self, attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training: bool = False ): if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted # form and requires no inversion or slicing. return attention_mask min_dtype = torch.finfo(torch.float16).min inputs_lead_dim, sequence_length = input_tensor.shape[:2] target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[0] + sequence_length + 1 ) causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device ) # Causal diagonal mask only if training, otherwise attend to the whole prefix. 
Training-specific attn for prefix is handled below if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) # Apply bidirectional mask on images if token type ids are provided if token_type_ids is not None and sequence_length != 1: token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2) token_type_mask[token_type_ids == 0] = False # if text token do not change anything token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool) causal_mask = causal_mask.clone() causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill( token_type_mask, 0.0 ) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] # Then apply padding mask (will mask pad tokens) padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype) return causal_mask class Gemma3LMModelPatcher(DecoderModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): model.__orig_forward = model.forward if is_transformers_version("<", "4.52"): model._update_causal_mask_mm = types.MethodType(_gemma3_mm_update_causal_mask, model) else: model.model._orig_update_causual_mask = model.model._update_causal_mask model.model._update_causal_mask = types.MethodType(_gemma3_mm_update_causal_mask, model.model) # Difference from original: # uses Dynamic cache from legacy cache instead of HybridCache # calculate causal mask from multimodal def forward( self, attention_mask, position_ids, past_key_values, token_type_ids, inputs_embeds, use_cache=True ): from transformers.cache_utils import DynamicCache pkv = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values[0][0].shape[-2] cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) forward_kwargs = {} if is_transformers_version("<", "4.52"): attention_mask = self._update_causal_mask_mm( attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds ) else: forward_kwargs["token_type_ids"] = token_type_ids result = self.__orig_forward( input_ids=None, attention_mask=attention_mask, position_ids=position_ids, cache_position=cache_position, past_key_values=pkv, inputs_embeds=inputs_embeds, use_cache=use_cache, **forward_kwargs, ) upd_pkv = result["past_key_values"] result["past_key_values"] = upd_pkv.to_legacy_cache() return result model.forward = types.MethodType(forward, model) super().__init__(config, model, model_kwargs) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward if hasattr(self._model, "model") and hasattr(self._model.model, "_orig_update_causual_mask"): self._model.model._update_causal_mask = self._model.model._orig_update_causual_mask class Idefics3ImageEmbeddingsModelPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): 
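# Overview: this patcher rebinds `forward` to a thin wrapper over the vision tower and the
# connector so that `patch_position_ids` becomes a model input instead of being derived
# inside the graph; the per-sample loop that computes position ids from
# `patch_attention_mask` (see `embeddings_forward` below) is not tracing-friendly, so that
# computation is expected to happen outside the exported model.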
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-SmolVLM-2/src/transformers/models/idefics3/modeling_idefics3.py#L999-L1005 def get_image_features(self, pixel_values, patch_attention_mask, patch_position_ids): image_hidden_states = self.vision_model( pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, patch_position_ids=patch_position_ids, ).last_hidden_state # Modality projection & resampling image_hidden_states = self.connector(image_hidden_states) return image_hidden_states model.__orig_forward = model.forward model.forward = types.MethodType(get_image_features, model) super().__init__(config, model, model_kwargs) def __enter__(self): # The difference from original code is only in getting patch_position_ids as input and propogation it into embeddings instead of calculation inside based on patch_attention_mask # method for calculation position_ids is not pytorch tracing friendly due to cycle over batch size. def transformer_forward( self, pixel_values, patch_attention_mask: Optional[torch.BoolTensor] = None, patch_position_ids: Optional[torch.IntTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.modeling_outputs import BaseModelOutput output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size = pixel_values.size(0) if patch_attention_mask is None: patch_size = self.patch_size patch_attention_mask = torch.ones( ( batch_size, pixel_values.size(2) // patch_size, pixel_values.size(3) // patch_size, ) ) patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device) hidden_states = self.embeddings( pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, patch_position_ids=patch_position_ids, ) patch_attention_mask = patch_attention_mask.view(batch_size, -1) # The call to `_upad_input` in `_flash_attention_forward` is expensive # So when the `patch_attention_mask` is full of 1s (i.e. 
attending to the whole sequence), # avoiding passing the attention_mask, which is equivalent to attending to the full sequence if not torch.any(~patch_attention_mask): patch_attention_mask = None elif not self._use_flash_attention_2: patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, attention_mask=patch_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] last_hidden_state = self.post_layernorm(last_hidden_state) if not return_dict: return (last_hidden_state,) + encoder_outputs[1:] return BaseModelOutput( last_hidden_state=last_hidden_state, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) def embeddings_forward( self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor, patch_position_ids: Optional[torch.IntTensor] = None, ) -> torch.Tensor: batch_size, _, max_im_h, max_im_w = pixel_values.shape patch_embeds = self.patch_embedding(pixel_values) embeddings = patch_embeds.flatten(2).transpose(1, 2) if patch_position_ids is None: max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) for batch_idx, p_attn_mask in enumerate(patch_attention_mask): nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids else: position_ids = patch_position_ids position_ids = position_ids.to(self.position_embedding.weight.device) embeddings = embeddings + self.position_embedding(position_ids) return embeddings def attn_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: return super().forward( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, ) batch_size, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. 
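# The contiguous() calls just below are only that workaround; no causal structure is
# involved here: this is a bidirectional vision encoder, so is_causal stays False further
# down and the only masking comes from the padded-patch attention mask prepared in
# transformer_forward above.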
if query_states.device.type == "cuda" and attention_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = False attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, scale=self.scale, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(batch_size, q_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, None self._model.vision_model._orig_forward = self._model.vision_model.forward self._model.vision_model.forward = types.MethodType(transformer_forward, self._model.vision_model) self._model.vision_model.embeddings._orig_forward = self._model.vision_model.embeddings.forward self._model.vision_model.embeddings.forward = types.MethodType( embeddings_forward, self._model.vision_model.embeddings ) for layer in self._model.vision_model.encoder.layers: layer.self_attn._orig_forward = layer.self_attn.forward layer.self_attn.forward = types.MethodType(attn_forward, layer.self_attn) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward self._model.vision_model.forward = self._model.vision_model._orig_forward self._model.vision_model.embeddings.forward = self._model.vision_model.embeddings._orig_forward for layer in self._model.vision_model.encoder.layers: layer.self_attn.forward = layer.self_attn._orig_forward # Adopted from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/decoder_models.py#L367 def _blenderbot_attn_forward_legacy( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions or layer_head_mask is not None: return self._orig_forward( hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions ) """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning if is_cross_attention and past_key_value is not None and past_key_value[0].shape[2] == key_value_states.shape[1]: # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: # cross_attentions key_states = self._shape(self.k_proj(key_value_states), -1, bsz) value_states = 
self._shape(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention # key/value_states (first "if" case) # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # all previous decoder key/value_states. Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) query_states = self._shape(query_states, tgt_len, bsz) attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=False, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned aross GPUs when using tensor-parallelism. 
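# The transpose(1, 2) above and the reshape below are the usual multi-head merge. A toy
# check with made-up sizes (bsz=2, num_heads=4, tgt_len=3, head_dim=8, embed_dim=32);
# illustration only:
#
#     import torch
#     x = torch.randn(2, 4, 3, 8)                   # [bsz, num_heads, tgt_len, head_dim]
#     merged = x.transpose(1, 2).reshape(2, 3, 32)  # [bsz, tgt_len, embed_dim]
#     assert merged.shape == (2, 3, 32)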
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, None, past_key_value # Adopted from https://github.com/huggingface/transformers/blob/v4.52.3/src/transformers/models/blenderbot/modeling_blenderbot.py#L156 def _blenderbot_attn_forward_new( self, hidden_states: torch.Tensor, key_value_states=None, past_key_value=None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: from transformers.cache_utils import EncoderDecoderCache """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder if output_attentions or layer_head_mask is not None: return self._orig_forward( hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions, cache_position, ) is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) query_states = query_states if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache curr_past_key_value = past_key_value.cross_attention_cache else: curr_past_key_value = past_key_value.self_attention_cache else: curr_past_key_value = past_key_value current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.key_cache[self.layer_idx] value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz, self.num_heads, -1, self.head_dim) # difference with original, removed query_states = query_states.reshape(*proj_shape) * self.scale as scale is part of SDPA query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) # Difference with original, use SDPA instead of eager attention attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=False, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) 
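# On the dropped explicit query scaling noted above: SDPA applies 1/sqrt(head_dim) itself
# when no `scale` argument is passed, which matches the head_dim ** -0.5 factor these
# attention modules would otherwise multiply in. A standalone sanity check with toy
# tensors; illustration only:
#
#     import math, torch
#     q = torch.randn(1, 2, 3, 4); k = torch.randn(1, 2, 5, 4); v = torch.randn(1, 2, 5, 4)
#     ref = torch.softmax((q @ k.transpose(-1, -2)) / math.sqrt(4), dim=-1) @ v
#     out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
#     torch.testing.assert_close(out, ref)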
attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned aross GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, None, past_key_value if is_transformers_version(">=", "4.52"): _blenderbot_attn_forward = _blenderbot_attn_forward_new else: _blenderbot_attn_forward = _blenderbot_attn_forward_legacy def modulewise_patch(model, module_cls, patch_forward): for _, module in model.named_children(): if isinstance(module, module_cls): module._orig_forward = module.forward module.forward = types.MethodType(patch_forward, module) return else: if len(list(module.children())) > 0: modulewise_patch(module, module_cls, patch_forward) def modulewise_unpatch(model, module_cls): for _, module in model.named_children(): if isinstance(module, module_cls): if hasattr(module, "_orig_forward"): module.forward = module._orig_forward else: if len(list(module.children())) > 0: modulewise_unpatch(module, module_cls) class BlenderbotModelPatcher(Seq2SeqModelPatcher): def __enter__(self): super().__enter__() if is_transformers_version(">=", "4.49.0"): from transformers.models.blenderbot.modeling_blenderbot import BlenderbotAttention modulewise_patch(self._model, BlenderbotAttention, _blenderbot_attn_forward) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if is_transformers_version(">=", "4.49.0"): from transformers.models.blenderbot.modeling_blenderbot import BlenderbotAttention modulewise_unpatch(self._model, BlenderbotAttention) class BlenderbotSmallModelPatcher(Seq2SeqModelPatcher): def __enter__(self): super().__enter__() if is_transformers_version(">=", "4.49.0"): from transformers.models.blenderbot_small.modeling_blenderbot_small import BlenderbotSmallAttention modulewise_patch(self._model, BlenderbotSmallAttention, _blenderbot_attn_forward) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if is_transformers_version(">=", "4.49.0"): from transformers.models.blenderbot_small.modeling_blenderbot_small import BlenderbotSmallAttention modulewise_unpatch(self._model, BlenderbotSmallAttention) class BlenderbotStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, BlenderbotModelPatcher): pass class BlenderbotSmallStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, BlenderbotSmallModelPatcher): pass class PegasusModelPatcher(Seq2SeqModelPatcher): def __enter__(self): super().__enter__() if is_transformers_version(">=", "4.49.0"): from transformers.models.pegasus.modeling_pegasus import PegasusAttention modulewise_patch(self._model, PegasusAttention, _blenderbot_attn_forward) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if is_transformers_version(">=", "4.49.0"): from transformers.models.pegasus.modeling_pegasus import PegasusAttention modulewise_unpatch(self._model, PegasusAttention) # Copied from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py#L596 # No modifications, transformers>=4.52.0 this method realization breaks tracing def _qwen2moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * 
sequence_length, n_experts) router_logits = self.gate(hidden_states) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) if self.norm_topk_prob: routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) shared_expert_output = self.shared_expert(hidden_states) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output final_hidden_states = final_hidden_states + shared_expert_output final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits class Qwen2MoEPatcher(UpdateCausalMaskModelPatcher): def __enter__(self): super().__enter__() if is_transformers_version(">=", "4.52.0"): from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock modulewise_patch(self._model, Qwen2MoeSparseMoeBlock, _qwen2moe_sparse_block_forward) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if is_transformers_version(">=", "4.52.0"): from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock modulewise_unpatch(self._model, Qwen2MoeSparseMoeBlock) class PegasusStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, PegasusModelPatcher): pass class MarianModelPatcher(Seq2SeqModelPatcher): def __enter__(self): super().__enter__() if is_transformers_version(">=", "4.49.0"): from transformers.models.marian.modeling_marian import MarianAttention modulewise_patch(self._model, MarianAttention, _blenderbot_attn_forward) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if is_transformers_version(">=", "4.49.0"): from transformers.models.marian.modeling_marian import MarianAttention modulewise_unpatch(self._model, MarianAttention) class MarianStatefulSeq2SeqDecoderPatcher(StatefulSeq2SeqDecoderPatcher, MarianModelPatcher): pass # Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/speecht5/modeling_speecht5.py#L698 # this is a patch to avoid PyTorch FE issue # with the same tensor names on input and intermediate tensor for speaker_embeddings def 
speecht5_decoder_prenet_forward( self, input_values: torch.Tensor, speaker_embeddings: Optional[torch.Tensor] = None, ): inputs_embeds = input_values for layer in self.layers: inputs_embeds = torch.nn.functional.relu(layer(inputs_embeds)) inputs_embeds = self._consistent_dropout(inputs_embeds, self.config.speech_decoder_prenet_dropout) inputs_embeds = self.final_layer(inputs_embeds) inputs_embeds = self.encode_positions(inputs_embeds) if speaker_embeddings is not None: # this is a patch to avoid for PyTorch FE issue!!! # with the same tensor names on input and intermediate tensor in a model speaker_embeddings_norm = torch.nn.functional.normalize(speaker_embeddings) speaker_embeddings_unsqueeze = speaker_embeddings_norm.unsqueeze(1).expand(-1, inputs_embeds.size(1), -1) inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings_unsqueeze], dim=-1) inputs_embeds = torch.nn.functional.relu(self.speaker_embeds_layer(inputs_embeds)) return inputs_embeds # Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/speecht5/modeling_speecht5.py#L993 # this is a patch to avoid CPU plugin issue that is happened on 16-th iteration of token generation # values computed by self-attention attn_output = torch.bmm(attn_probs, value_states) in a decoder gets incorrect def speecht5_attention_forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, position_bias: Optional[torch.Tensor] = None, output_attentions: bool = False, serialize: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) * self.scaling # get key, value proj if is_cross_attention and past_key_value is not None: # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: # cross_attentions key_states = self._shape(self.k_proj(key_value_states), -1, bsz) value_states = self._shape(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention # key/value_states (first "if" case) # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # all previous decoder key/value_states. 
Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" f" {attn_weights.size()}" ) # relative attention bias if position_bias is not None: reshape_q = query_states.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0, 1) rel_pos_bias = torch.matmul(reshape_q, position_bias.transpose(-2, -1)) rel_pos_bias = rel_pos_bias.transpose(0, 1).view( bsz * self.num_heads, position_bias.size(0), position_bias.size(1) ) attn_weights += rel_pos_bias if attention_mask is not None: if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: if layer_head_mask.size() != (self.num_heads,): raise ValueError( f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" f" {layer_head_mask.size()}" ) attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if output_attentions: # this operation is a bit awkward, but it's required to # make sure that attn_weights keeps its gradient. # In order to do so, attn_weights have to be reshaped # twice and have to be reused in the following attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) else: attn_weights_reshaped = None attn_probs = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # this is a patch to avoid CPU plugin issue!!! # issue is happened on 16-th iteration of token generation # since 16-th iteration of token generation, values computed by self-attention in a decoder gets incorrect eps = 1e-30 attn_output = torch.bmm(attn_probs + eps, value_states) if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. 
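# On the `attn_probs + eps` workaround above: eps = 1e-30 is far below float32 resolution
# for softmax outputs, so the attention result is unchanged to within floating-point
# precision; the extra addition only alters the traced graph enough to sidestep the CPU
# plugin issue described in the comment (observed from the 16th generated token onward).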
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, attn_weights_reshaped, past_key_value # Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/speecht5/modeling_speecht5.py#L1175 # this is a patch for a model to avoid incorrect tracing # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple are computed using encoder_hidden_states def speecht5_decoder_layer_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, serialize: bool = False, ): residual = hidden_states # Self Attention # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, serialize=serialize, ) hidden_states = self.dropout(hidden_states) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states # this is a patch for a model to avoid incorrect tracing!!! 
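# With the legacy layout each decoder layer caches a 4-tuple
# (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value): the first two feed
# self-attention via `past_key_value[:2]` above, the last two feed cross-attention below.
# When the export patcher passes only the self-attention pair (see OVSpeechT5ModelPatcher),
# the `len(past_key_value) > 3` check below falls back to None and cross keys/values are
# recomputed from `encoder_hidden_states`.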
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple # are computed using encoder_hidden_states if past_key_value is not None and len(past_key_value) > 3: cross_attn_past_key_value = past_key_value[-2:] else: cross_attn_past_key_value = None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) hidden_states = self.dropout(hidden_states) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) # add cross-attn to positions 3,4 of present_key_value tuple present_key_value = present_key_value + cross_attn_present_key_value # Fully Connected hidden_states = hidden_states + self.feed_forward(hidden_states) hidden_states = self.final_layer_norm(hidden_states) outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) if use_cache: outputs += (present_key_value,) return outputs class OVSpeechT5ModelPatcher(ModelPatcher): def __enter__(self): if self.real_config._behavior != "vocoder": setattr(self._model, self.orig_forward_name, self.patched_forward) if self.real_config._behavior == "decoder": self._model.speecht5.decoder.prenet.__orig_forward = self._model.speecht5.decoder.prenet.forward self._model.speecht5.decoder.prenet.forward = types.MethodType( speecht5_decoder_prenet_forward, self._model.speecht5.decoder.prenet ) for layer in self._model.speecht5.decoder.wrapped_decoder.layers: layer.__orig_forward = layer.forward layer.forward = types.MethodType(speecht5_decoder_layer_forward, layer) layer.self_attn.__orig_forward = layer.self_attn.forward layer.self_attn.forward = types.MethodType(speecht5_attention_forward, layer.self_attn) def __exit__(self, exc_type, exc_value, traceback): if self.real_config._behavior != "vocoder": setattr(self._model, self.orig_forward_name, self.orig_forward) if self.real_config._behavior == "decoder": self._model.speecht5.decoder.prenet.forward = types.MethodType( self._model.speecht5.decoder.prenet.__orig_forward, self._model.speecht5.decoder.prenet ) for layer in self._model.speecht5.decoder.wrapped_decoder.layers: layer.forward = types.MethodType(layer.__orig_forward, layer) layer.self_attn.forward = types.MethodType(layer.self_attn.__orig_forward, layer.self_attn) def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any], ): super().__init__(config, model, model_kwargs) def patched_encoder_forward( input_ids: torch.FloatTensor = None, ): encoder_attention_mask = torch.ones_like(input_ids) hidden_states = self._model.prenet(input_ids) encoder_out = self._model.wrapped_encoder( hidden_states=hidden_states, attention_mask=encoder_attention_mask, return_dict=True, ) # downsample encoder attention mask if isinstance(model, SpeechT5EncoderWithSpeechPrenet): encoder_attention_mask = model.prenet._get_feature_vector_attention_mask( encoder_out[0].shape[1], encoder_attention_mask ) result = { "encoder_outputs": encoder_out.last_hidden_state, "encoder_attention_mask": encoder_attention_mask, } return result def patched_decoder_forward( inputs_embeds=None, speaker_embeddings=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, ): return_legacy_cache = False if past_key_values is not None: 
only_self_cache = [cache_item[:2] for cache_item in past_key_values] past_key_values = only_self_cache return_legacy_cache = True output_sequence = inputs_embeds output_cross_attentions = False bsz = output_sequence.size(0) # Run the decoder prenet on the entire output sequence. decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings) # Run the decoder layers on the last element of the prenet output. decoder_out = model.speecht5.decoder.wrapped_decoder( hidden_states=decoder_hidden_states[:, -1:], attention_mask=None, encoder_hidden_states=encoder_hidden_states[0], encoder_attention_mask=encoder_attention_mask, past_key_values=past_key_values, use_cache=True, output_attentions=output_cross_attentions, return_dict=True, ) # if output_cross_attentions: # cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0)) last_decoder_output = decoder_out.last_hidden_state.squeeze(1) # Predict the new mel spectrum for this step in the sequence. spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output) spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins) # Extend the output sequence with the new mel spectrum. new_spectrogram = spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins) output_sequence_out = torch.cat((output_sequence, new_spectrogram), dim=1) # Predict the probability that this is the stop token. prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output)) if return_legacy_cache: only_self_cache = [cache_item[:2] for cache_item in decoder_out.past_key_values] past_key_values = only_self_cache result = { "output_sequence_out": output_sequence_out, "spectrum": spectrum, "prob": prob, "past_key_values": past_key_values, } return result def patched_postnet_forward(raw_spectrogram: torch.FloatTensor): raw_spectrogram = raw_spectrogram.transpose(0, 1).flatten(1, 2) spectrogram = model.speech_decoder_postnet.postnet(raw_spectrogram) result = {"postnet_spectrogram": spectrogram} return result def patched_vocoder_forward(spectrogram: torch.FloatTensor): waveform = model(spectrogram) result = {"waveform": waveform} return result if self.real_config._behavior == "encoder": self.patched_forward = patched_encoder_forward elif self.real_config._behavior == "decoder": self.patched_forward = patched_decoder_forward elif self.real_config._behavior == "postnet": self.patched_forward = patched_postnet_forward elif self.real_config._behavior == "vocoder": self.patched_forward = patched_vocoder_forward else: raise ValueError("Unknown ") self.orig_forward = self.patched_forward class Phi4MMLanguageModelPatcher(DecoderModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): if hasattr(model.config, "vision_lora") and model.config.vision_lora is not None: model.set_lora_adapter("vision") if hasattr(model.config, "speech_lora") and model.config.speech_lora is not None: model.set_lora_adapter("speech") # Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L2156-L2178 # moved audio and vision features processing outside model def lm_forward(self, inputs_embeds, attention_mask, position_ids, past_key_values, use_cache=True): from transformers.cache_utils import DynamicCache pkv = DynamicCache.from_legacy_cache(past_key_values) outputs = self.model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, 
position_ids=position_ids, use_cache=use_cache, past_key_values=pkv, ) hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss logits = self.lm_head(hidden_states) return (logits, outputs.past_key_values.to_legacy_cache()) model.__orig_forward = model.forward model.forward = types.MethodType(lm_forward, model) super().__init__(config, model, model_kwargs) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward class Phi4MMAudioForwardEmbeddingsPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): # Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L1121 def forward(self, audio_input): if hasattr(self, "_forward_embeddings_code"): audio_input, masks = self._forward_embeddings_core(audio_input, None) else: audio_input, masks = self.embed(audio_input, None) return audio_input model.__orig_forward = model.forward model.forward = types.MethodType(forward, model) super().__init__(config, model, model_kwargs) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward class Phi4MMAudioEncoderPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): # Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L1201-L1212 def forward(self, audio_feature, audio_mask): if hasattr(self, "init_relative_attention_bias"): relative_attention_bias = self.init_relative_attention_bias(audio_feature) _simplified_path = self.extra_layer_output_idx == -1 and relative_attention_bias is None if _simplified_path: audio_feature, *_ = self.encoders(audio_feature, None, None, audio_mask) else: for layer in self.encoders: audio_feature, _, _, _ = layer( audio_feature, None, None, audio_mask, relative_attention_bias=relative_attention_bias, ) else: relative_attention_bias = self.relative_attention_bias_layer(audio_feature) attention_mask = audio_mask.unsqueeze(1) + relative_attention_bias for layer in self.encoders: audio_feature = layer(audio_feature, attention_mask) return audio_feature model.__orig_forward = model.forward model.forward = types.MethodType(forward, model) super().__init__(config, model, model_kwargs) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward class Phi4MMVisionEmbeddingsPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): def get_img_features_legacy( self, pixel_values: torch.FloatTensor, patch_attention_mask=None, patch_position_ids=None ) -> torch.FloatTensor: LAYER_IDX = self.layer_idx TYPE_FEATURE = self.type_feature if self.freeze_img_processor: with torch.no_grad(): if patch_attention_mask is not None: img_processor_output = self.img_processor( pixel_values, output_hidden_states=True, patch_attention_mask=patch_attention_mask, position_ids=patch_position_ids, ) else: img_processor_output = self.img_processor( pixel_values, 
output_hidden_states=True, position_ids=patch_position_ids ) img_feature = img_processor_output.hidden_states[LAYER_IDX] else: if patch_attention_mask is not None: img_processor_output = self.img_processor( pixel_values, output_hidden_states=True, patch_attention_mask=patch_attention_mask, position_ids=patch_position_ids, ) else: img_processor_output = self.img_processor( pixel_values, output_hidden_states=True, position_ids=patch_position_ids ) img_feature = img_processor_output.hidden_states[LAYER_IDX] if TYPE_FEATURE == "patch": patch_feature = img_feature if self.image_token_compression is not None: # reshape to 2D tensor width = int(math.sqrt(patch_feature.size(1))) patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) # convert to NCHW patch_feature = patch_feature.permute(0, 3, 1, 2) if getattr(self, "img_processor_padding", None) is not None: patch_feature = self.img_processor_padding(patch_feature) patch_feature = self.image_token_compression(patch_feature) # convert to NHWC patch_feature = patch_feature.permute(0, 2, 3, 1) patch_feature = patch_feature.view( -1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1) ) elif getattr(self, "img_processor_padding", None) is not None: width = int(math.sqrt(patch_feature.size(1))) patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) # convert to NCHW patch_feature = patch_feature.permute(0, 3, 1, 2) patch_feature = self.img_processor_padding(patch_feature) # convert to NHWC patch_feature = patch_feature.permute(0, 2, 3, 1) patch_feature = patch_feature.view( -1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1) ) return patch_feature if TYPE_FEATURE == "cls_patch": if self.image_token_compression is not None: # reshape to 2D tensor patch_feature = img_feature[:, 1:] cls_feature = img_feature[:, 0] width = math.sqrt(patch_feature.size(1)) patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) patch_feature = self.image_token_compression(patch_feature) patch_feature = patch_feature.view(-1, patch_feature.size(-2) * patch_feature.size(-1)) img_feature = torch.cat([cls_feature, patch_feature], dim=1) return img_feature # Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L649 # added possibility to provide patch_position_ids def get_img_features( self, pixel_values: torch.FloatTensor, patch_attention_mask=None, patch_position_ids=None ): img_processor_output = self.img_processor( pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True, position_ids=patch_position_ids, ) img_feature = img_processor_output.hidden_states[self.layer_idx] patch_feature = img_feature # reshape to 2D tensor width = int(math.sqrt(patch_feature.size(1))) patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) # convert to NCHW patch_feature = patch_feature.permute(0, 3, 1, 2) if getattr(self, "img_processor_padding", None) is not None: patch_feature = self.img_processor_padding(patch_feature) patch_feature = self.image_token_compression(patch_feature) # convert to NHWC patch_feature = patch_feature.permute(0, 2, 3, 1) patch_feature = patch_feature.view( -1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1) ) return patch_feature model.__orig_forward = model.forward model.forward = types.MethodType( get_img_features_legacy if hasattr(model, "type_feature") else get_img_features, model ) super().__init__(config, 
model, model_kwargs) def __enter__(self): super().__enter__() # Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L563 # added possibility calculate position_ids outside def transformer_fwd( self, pixel_values, patch_attention_mask: Optional[torch.BoolTensor] = None, position_ids: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size = pixel_values.size(0) if patch_attention_mask is None: patch_attention_mask = torch.ones( size=( batch_size, pixel_values.size(2) // self.config.patch_size, pixel_values.size(3) // self.config.patch_size, ), dtype=torch.bool, device=pixel_values.device, ) hidden_states = self.embeddings( pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, position_ids=position_ids ) patch_attention_mask = patch_attention_mask.view(batch_size, -1) # The call to `_upad_input` in `_flash_attention_forward` is expensive # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), # avoiding passing the attention_mask, which is equivalent to attending to the full sequence if not torch.any(~patch_attention_mask): attention_mask = None else: attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] last_hidden_state = self.post_layernorm(last_hidden_state) pooled_output = self.head( hidden_state=last_hidden_state, attention_mask=patch_attention_mask, ) if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) # Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L76 # used SDPA instead of eager attention def attn_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: return super().forward( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, ) batch_size, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) # SDPA with memory-efficient 

        # Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L76
        # used SDPA instead of eager attention
        def attn_forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            output_attentions: Optional[bool] = False,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
            if output_attentions:
                # fall back to the original (eager) implementation saved before patching
                return self._orig_forward(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                )

            batch_size, q_len, _ = hidden_states.size()

            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

            query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)

            # SDPA with the memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous
            # inputs and a custom attn_mask.
            # Reference: https://github.com/pytorch/pytorch/issues/112577
            if query_states.device.type == "cuda" and attention_mask is not None:
                query_states = query_states.contiguous()
                key_states = key_states.contiguous()
                value_states = value_states.contiguous()

            # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead
            # of an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full
            # graph options. An inline conditional prevents dynamic shapes from compiling.
            is_causal = False

            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=is_causal,
            )

            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.view(batch_size, q_len, self.embed_dim)

            attn_output = self.out_proj(attn_output)

            return attn_output, None

        # Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L488
        # moved the position_ids calculation outside of the model
        def embd_forward(
            self,
            pixel_values: torch.FloatTensor,
            patch_attention_mask: torch.BoolTensor,
            position_ids: torch.FloatTensor = None,
        ) -> torch.Tensor:
            batch_size = pixel_values.size(0)

            patch_embeds = self.patch_embedding(pixel_values)
            embeddings = patch_embeds.flatten(2).transpose(1, 2)
            if position_ids is None:
                max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
                max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
                boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
                position_ids = torch.full(
                    size=(
                        batch_size,
                        max_nb_patches_h * max_nb_patches_w,
                    ),
                    fill_value=0,
                )

                for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
                    nb_patches_h = p_attn_mask[:, 0].sum()
                    nb_patches_w = p_attn_mask[0].sum()

                    fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
                    fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

                    bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
                    bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

                    pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
                    position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids

                position_ids = position_ids.to(self.position_embedding.weight.device)

            embeddings = embeddings + self.position_embedding(position_ids)
            return embeddings

        if (
            getattr(self._model.img_processor.encoder.layers[0].self_attn.config, "_attn_implementation", "eager")
            != "sdpa"
        ):
            for layer in self._model.img_processor.encoder.layers:
                layer.self_attn._orig_forward = layer.self_attn.forward
                layer.self_attn.forward = types.MethodType(attn_forward, layer.self_attn)

        self._model.img_processor._orig_forward = self._model.img_processor.forward
        self._model.img_processor.forward = types.MethodType(transformer_fwd, self._model.img_processor)
        self._model.img_processor.embeddings._orig_forward = self._model.img_processor.embeddings.forward
        self._model.img_processor.embeddings.forward = types.MethodType(
            embd_forward, self._model.img_processor.embeddings
        )

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        self._model.forward = self._model.__orig_forward
        for layer in self._model.img_processor.encoder.layers:
            if hasattr(layer.self_attn, "_orig_forward"):
                layer.self_attn.forward = layer.self_attn._orig_forward
        self._model.img_processor.forward = self._model.img_processor._orig_forward
        self._model.img_processor.embeddings.forward = self._model.img_processor.embeddings._orig_forward


class Llama4ImageEmbeddingsModelPatcher(ModelPatcher):
    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Dict[str, Any],
    ):
        model.__orig_forward = model.forward

        # Adopted from https://github.com/huggingface/transformers/blob/v4.51.0/src/transformers/models/llama4/modeling_llama4.py#L1732-L1741
        def get_image_embeddings(self, pixel_values):
            image_features = self.get_image_features(
                pixel_values=pixel_values,
                vision_feature_layer=self.config.vision_config.vision_feature_layer,
                vision_feature_select_strategy=self.config.vision_config.vision_feature_select_strategy,
            )
            vision_flat = image_features.view(-1, image_features.size(-1))
            projected_vision_flat = self.multi_modal_projector(vision_flat)
            return projected_vision_flat

        model.forward = types.MethodType(get_image_embeddings, model)
        super().__init__(config, model, model_kwargs)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        self._model.forward = self._model.__orig_forward


# Modified from https://github.com/huggingface/transformers/blob/v4.51.0/src/transformers/models/llama4/modeling_llama4.py#L229
# use real cos / sin instead of complex rotation
def llama4_rope_forward(self, x, position_ids):
    if "dynamic" in self.rope_type:
        self._dynamic_frequency_update(position_ids, device=x.device)

    # Core RoPE block
    inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
    device_type = x.device.type
    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
    with torch.autocast(device_type=device_type, enabled=False):
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos() * self.attention_scaling
        sin = emb.sin() * self.attention_scaling
    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# use real cos / sin instead of complex rotation
# Modified from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/llama4/modeling_llama4.py#L247
# Based on https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py#L292
# The native DeepSeek apply_rotary_emb works in the same way as the llama4 apply_rotary_emb
def llama4_apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    from transformers.models.llama.modeling_llama import rotate_half

    xq_ = xq.float()
    xk_ = xk.float()
    cos = cos.unsqueeze(2)
    sin = sin.unsqueeze(2)
    b, h, s, d = xq_.shape
    xq_ = xq_.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
    b, h, s, d = xk_.shape
    xk_ = xk_.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
    q_embed = (xq_ * cos) + (rotate_half(xq_) * sin)
    k_embed = (xk_ * cos) + (rotate_half(xk_) * sin)
    return q_embed.type_as(xq), k_embed.type_as(xk)
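

# Note (not from the upstream sources): the two helpers above are the real-valued equivalent of the upstream
# complex-number RoPE. For an interleaved pair (a, b), the complex product (a + ib) * (cos + i*sin) equals
# (a*cos - b*sin, a*sin + b*cos), which is exactly q * cos + rotate_half(q) * sin once the head dimension is
# reshuffled from interleaved pairs [a0, b0, a1, b1, ...] into the half-split layout [a0, a1, ..., b0, b1, ...]
# by the view/transpose/reshape above. Staying in real dtypes avoids complex tensors in the traced graph.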


# Modified from https://github.com/huggingface/transformers/blob/v4.51.0/src/transformers/models/llama4/modeling_llama4.py#L329
# use real cos / sin instead of complex rotation
def llama4_attn_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value=None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    from transformers.models.llama4.modeling_llama4 import ALL_ATTENTION_FUNCTIONS, eager_attention_forward

    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape)
    key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    if self.use_rope:  # the 16E model skips rope for long context on certain layers
        cos, sin = position_embeddings[0], position_embeddings[1]
        query_states, key_states = llama4_apply_rotary_emb(
            query_states, key_states, cos.to(query_states.device), sin.to(query_states.device)
        )

    if hasattr(self, "qk_norm"):  # the 128E model does not use qk_norm
        query_states = self.qk_norm(query_states)
        key_states = self.qk_norm(key_states)

    # Use temperature tuning from https://arxiv.org/abs/2501.19399 for NoPE layers
    if self.attn_temperature_tuning and not self.use_rope:
        attn_scales = (
            torch.log(torch.floor((cache_position.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
        )
        attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand((*input_shape, 1, 1))  # batch size > 1
        query_states = (query_states * attn_scales).to(query_states.dtype)

    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position is needed for the static cache
        cache_kwargs = {"cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    attention_interface = eager_attention_forward
    if self.config._attn_implementation != "eager":
        if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
            attention_interface = eager_attention_forward
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        **kwargs,
    )

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights
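

# Note (not from the upstream sources): the MoE forward below keeps the upstream "all tokens through all routed
# experts" formulation via gather / scatter_add_ instead of a per-expert loop. Experts that are not selected for a
# token receive a -inf router logit, so after the sigmoid their input rows are scaled by ~0 and contribute nothing.
# This trades extra compute for fully static tensor shapes, which suits the OpenVINO export.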


# Modified from https://github.com/huggingface/transformers/blob/v4.51.0/src/transformers/models/llama4/modeling_llama4.py#L157
# due to an OpenVINO transformations issue, removed routed_out.view(-1, hidden_dim) in scatter_add_
def llama4_moe_forward(self, hidden_states):
    batch, seq_len, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, self.hidden_dim)
    router_logits = self.router(hidden_states).transpose(0, 1)
    tokens_per_expert = batch * seq_len

    router_top_value, router_indices = torch.topk(router_logits.transpose(0, 1), self.top_k, dim=1)
    router_scores = (
        torch.full_like(router_logits.transpose(0, 1), float("-inf"))
        .scatter_(1, router_indices, router_top_value)
        .transpose(0, 1)
    )
    # We do this to make sure we have -inf for non-topK tokens before going through the sigmoid
    # Here we are just creating a tensor to index each and every single one of the hidden states.
    # Let's maybe register a buffer for this!
    router_indices = (
        torch.arange(tokens_per_expert, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1)
    )
    router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

    router_indices = router_indices.reshape(-1, 1).expand(-1, hidden_dim)
    routed_in = torch.gather(
        input=hidden_states,
        dim=0,
        index=router_indices,
    ).to(hidden_states.device)
    # we gather inputs corresponding to each expert based on the router indices
    routed_in = routed_in * router_scores.reshape(-1, 1)
    routed_out = self.experts(routed_in)
    out = self.shared_expert(hidden_states)
    # now that the expert computation is finished, we scatter-add because we gathered previously:
    # we used all experts on all tokens. This is faster than a for loop, though compute bound,
    # and it scales much better with EP (expert parallelism)!
    out.scatter_add_(dim=0, index=router_indices, src=routed_out)
    return out, router_scores


class Llama4TextModelPatcher(ModelPatcher):
    def __enter__(self):
        super().__enter__()
        self._model.model.rotary_emb._orig_forward = self._model.model.rotary_emb.forward
        self._model.model.rotary_emb.forward = types.MethodType(llama4_rope_forward, self._model.model.rotary_emb)

        for layer in self._model.model.layers[: self._model.model.config.num_hidden_layers]:
            if layer.is_moe_layer:
                layer.feed_forward._orig_forward = layer.feed_forward.forward
                layer.feed_forward.forward = types.MethodType(llama4_moe_forward, layer.feed_forward)
            layer.self_attn._orig_forward = layer.self_attn.forward
            layer.self_attn.forward = types.MethodType(llama4_attn_forward, layer.self_attn)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        self._model.model.rotary_emb.forward = self._model.model.rotary_emb._orig_forward
        for layer in self._model.model.layers[: self._model.model.config.num_hidden_layers]:
            if layer.is_moe_layer:
                layer.feed_forward.forward = layer.feed_forward._orig_forward
            layer.self_attn.forward = layer.self_attn._orig_forward
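

# Illustrative only (not part of this module): patchers such as Llama4TextModelPatcher are meant to be used as
# context managers around the tracing/export call, roughly along the lines of
#
#     with Llama4TextModelPatcher(onnx_config, model, model_kwargs={}):   # argument names here are assumptions
#         ov_model = openvino.convert_model(model, example_input=dummy_inputs)
#
# so the forward overrides above are active only while the model is traced and are restored on exit.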