# optimum/exporters/ipex/modeling_utils.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
)
from optimum.intel.utils.import_utils import is_ipex_version, is_torch_version
from optimum.intel.utils.modeling_utils import _setattr_from_module
from .cache_utils import IPEXPagedCache
logger = logging.getLogger(__name__)
_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.6.0"
_accelerate_added_attributes = ["to", "xpu"]
if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
logger.warning(
f"Please upgrade the IPEX version to at least {_IPEX_MINIMUM_VERSION_FOR_PATCHING} if you want to patch the model."
)
else:
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.llm.functional import varlen_attention
from intel_extension_for_pytorch.llm.modules import (
Linear2SiluMul,
LinearAdd,
LinearAddAdd,
LinearGelu,
LinearNewGelu,
PagedAttention,
RMSNorm,
RotaryEmbedding,
)
device_type = "xpu" if ipex._C._has_xpu() else "cpu"
    # Assign the device type up front to avoid recompilation in IPEX.
PagedAttention.runtime_ops.device_type = device_type
RMSNorm.runtime_ops.device_type = device_type
RotaryEmbedding.runtime_ops.device_type = device_type
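# An illustrative usage sketch (assumption: these patches are normally applied indirectly through
# `optimum.intel.IPEXModelForCausalLM`, not by importing this module directly; the checkpoint name is hypothetical):
#
#     from transformers import AutoTokenizer
#     from optimum.intel import IPEXModelForCausalLM
#
#     model_id = "meta-llama/Llama-2-7b-hf"  # hypothetical checkpoint choice
#     model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
#     tokenizer = AutoTokenizer.from_pretrained(model_id)
#     inputs = tokenizer("Hello, my name is", return_tensors="pt")
#     outputs = model.generate(**inputs, max_new_tokens=16)
#     print(tokenizer.decode(outputs[0], skip_special_tokens=True))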
# Adapted from https://github.com/huggingface/accelerate/blob/v1.2.1/src/accelerate/hooks.py#L183
def _remove_hooks_for_ipex(module, recurse):
if hasattr(module, "_hf_hook"):
module._hf_hook.detach_hook(module)
delattr(module, "_hf_hook")
if hasattr(module, "_old_forward"):
# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
if "GraphModuleImpl" in str(type(module)):
module.__class__.forward = module.__class__.forward.__get__(module)
else:
module.forward = module.__class__.forward.__get__(module)
delattr(module, "_old_forward")
    # Remove the warning wrappers that accelerate's dispatch_model adds for these attributes
for attr in _accelerate_added_attributes:
module.__dict__.pop(attr, None)
if recurse:
for child in module.children():
_remove_hooks_for_ipex(child, recurse)
return module
# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
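# Used to replace the RMSNorm forward of the patched models (e.g. LlamaRMSNorm) with the fused IPEX
# RMSNorm kernel, which performs the variance normalization and weight scaling in a single op.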
def _ipex_rms_layer_norm_forward(self, hidden_states):
return RMSNorm.apply_function(hidden_states, self.weight, self.variance_epsilon)
# Adapted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/falcon/modeling_falcon.py#L1161
# Needed to forward kwargs; can be removed once the Falcon model supports passing kwargs to self.transformer.
def _falcon_for_causal_lm_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only the last token's logits are needed for generation, and computing them
            only for that token can save a significant amount of memory for long sequences or large vocabulary sizes.
            If a `torch.Tensor`, it must be 1D, corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = transformer_outputs[0]
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
lm_logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(
lm_logits,
labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
# Adapted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/gpt2/modeling_gpt2.py#L1036
# Needed to forward kwargs; can be removed once the GPT-2 model supports passing kwargs to self.transformer.
def _gpt2_lm_head_model_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
hidden_states = transformer_outputs[0]
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.transformer.first_device)
hidden_states = hidden_states.to(self.lm_head.weight.device)
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Flatten the tokens
loss = self.loss_function(
lm_logits,
labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
# Adapted from https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llama/modeling_llama.py#L918
def _llama_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is not None and not isinstance(past_key_values, IPEXPagedCache):
raise ValueError("only support IPEXPagedCache input now")
max_input_lens = self.config.max_input_lens
past_key_values_length = max_input_lens - seq_length
device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
position_ids = torch.arange(past_key_values_length, max_input_lens, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
position_embeddings = self.rotary_emb(hidden_states, position_ids)
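    # Paged attention works on a flattened (total_tokens, hidden) layout: `index` selects only the real
    # (non-padded) token positions, and the cos/sin tables are gathered to match that layout. The original
    # padded tensor is kept in `hidden_states_copy` so the output can be scattered back afterwards.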
index = kwargs.pop("index", None)
cos = position_embeddings[0]
sin = position_embeddings[1]
hidden_states_copy = hidden_states
hidden_states = (hidden_states.view(-1, hidden_states.shape[-1])).index_select(0, index)
cos = (cos.reshape(-1, cos.shape[-1])).index_select(0, index)
sin = (sin.reshape(-1, sin.shape[-1])).index_select(0, index)
position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
if past_key_values is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(input_ids.shape[0], input_ids.shape[-1]),
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
position_embeddings=position_embeddings,
past_key_values_length=past_key_values_length,
max_input_lens=self.config.max_input_lens,
query_max_len=seq_length,
**kwargs,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
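    # Scatter the depadded hidden states back into the original padded (batch, seq, hidden) layout so the
    # output shape matches what downstream heads expect.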
if hidden_states.shape[0] != batch_size * seq_length:
(hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states
hidden_states = hidden_states_copy
hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1])
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
# Adapted from https://github.com/huggingface/transformers/blob/v4.46.1/src/transformers/models/falcon/modeling_falcon.py#L945
def _falcon_model_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
max_input_lens = self.config.max_input_lens
batch_size, seq_length, _ = inputs_embeds.shape
past_key_values_length = max_input_lens - seq_length
device = input_ids.device if input_ids is not None else inputs_embeds.device
if cache_position is None:
cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)
if position_ids is None:
position_ids = cache_position.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
# Prepare head mask if needed
    # 1.0 in head_mask indicates we keep the head
# attention_probs has shape batch_size x num_attention_heads x N x N
# head_mask has shape n_layer x batch x num_attention_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
index = kwargs.pop("index", None)
cos = position_embeddings[0]
sin = position_embeddings[1]
hidden_states_copy = hidden_states
hidden_states = (hidden_states.view(-1, hidden_states.shape[-1])).index_select(0, index)
cos = (cos.reshape(-1, cos.shape[-1])).index_select(0, index)
sin = (sin.reshape(-1, sin.shape[-1])).index_select(0, index)
position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
if past_key_values is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(input_ids.shape[0], input_ids.shape[-1]),
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
next_decoder_cache = None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, block in enumerate(self.h):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = block(
hidden_states,
layer_past=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=None,
cache_position=cache_position,
position_embeddings=position_embeddings,
past_key_values_length=past_key_values_length,
max_input_lens=self.config.max_input_lens,
query_max_len=seq_length,
**kwargs,
)
hidden_states = outputs[0]
if use_cache is True:
next_decoder_cache = outputs[1]
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if hidden_states.shape[0] != batch_size * seq_length:
(hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states
hidden_states = hidden_states_copy
hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1])
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
def _gpt2_model_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
max_input_lens = self.config.max_input_lens
seq_length = input_ids.shape[-1]
past_key_values_length = max_input_lens - seq_length
if position_ids is None:
position_ids = torch.arange(
past_key_values_length, input_shape[-1] + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
batch_size, seq_length, _ = inputs_embeds.shape
position_embeddings = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeddings
encoder_attention_mask = None
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
index = kwargs.pop("index", None)
hidden_states_copy = hidden_states
hidden_states = (hidden_states.view(-1, hidden_states.shape[-1])).index_select(0, index)
if past_key_values is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(input_ids.shape[0], input_ids.shape[-1]),
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
presents = None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
for i, block in enumerate(self.h):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = block(
hidden_states,
layer_past=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
past_key_values_length=past_key_values_length,
max_input_lens=self.config.max_input_lens,
query_max_len=seq_length,
**kwargs,
)
hidden_states = outputs[0]
if use_cache is True:
presents = outputs[1]
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
hidden_states = self.ln_f(hidden_states)
if hidden_states.shape[0] != batch_size * seq_length:
(hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states
hidden_states = hidden_states_copy
hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1])
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
# Adapted from https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/qwen2/modeling_qwen2.py#L499
def _qwen2_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = inputs_embeds.shape[:2]
device = input_ids.device if input_ids is not None else inputs_embeds.device
    # Pop max_input_lens from kwargs so it is not passed twice to the decoder layers below.
kwargs.pop("max_input_lens", None)
max_input_lens = self.config.max_input_lens
past_key_values_length = max_input_lens - seq_length
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
index = kwargs.pop("index", None)
cos = position_embeddings[0]
sin = position_embeddings[1]
hidden_states_copy = hidden_states
hidden_states = (hidden_states.view(-1, hidden_states.shape[-1])).index_select(0, index)
cos = (cos.reshape(-1, cos.shape[-1])).index_select(0, index)
sin = (sin.reshape(-1, sin.shape[-1])).index_select(0, index)
position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
if past_key_values is None:
attention_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
past_key_values_length=past_key_values_length,
max_input_lens=max_input_lens,
query_max_len=seq_length,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if hidden_states.shape[0] != batch_size * seq_length:
(hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states
hidden_states = hidden_states_copy
hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1])
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
output = BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return output if return_dict else output.to_tuple()
# Adapted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral/modeling_mistral.py#L459
def _mistral_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = inputs_embeds.shape[:2]
device = input_ids.device if input_ids is not None else inputs_embeds.device
    # Pop max_input_lens from kwargs so it is not passed twice to the decoder layers below.
kwargs.pop("max_input_lens", None)
max_input_lens = self.config.max_input_lens
past_key_values_length = max_input_lens - seq_length
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=device
)
if position_ids is None:
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
index = kwargs.pop("index", None)
cos = position_embeddings[0]
sin = position_embeddings[1]
hidden_states_copy = hidden_states
hidden_states = (hidden_states.view(-1, hidden_states.shape[-1])).index_select(0, index)
cos = (cos.reshape(-1, cos.shape[-1])).index_select(0, index)
sin = (sin.reshape(-1, sin.shape[-1])).index_select(0, index)
position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
    # TODO: remove this workaround after IPEX 2.7
if device.type == "xpu":
cos = cos.reshape(-1, cos.shape[-1])
sin = sin.reshape(-1, sin.shape[-1])
position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
if past_key_values is None:
attention_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
past_key_values_length=past_key_values_length,
max_input_lens=max_input_lens,
query_max_len=seq_length,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if hidden_states.shape[0] != batch_size * seq_length:
(hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states
hidden_states = hidden_states_copy
hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1])
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
output = BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return output if return_dict else output.to_tuple()
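# Base attention module shared by the patched decoder models. It computes QKV, applies RoPE, writes the new
# key/value states into the IPEXPagedCache, and dispatches to SDPA, varlen attention, or the IPEX
# paged-attention kernels depending on the cache state and the installed torch/IPEX versions.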
class _IPEXAttention(nn.Module):
def __init__(self, module, device, config) -> None:
super().__init__()
_setattr_from_module(self, module)
self.config = config
self.module_device = device
self.num_key_value_heads = config.num_key_value_heads
self.num_attention_heads = config.num_attention_heads
self.num_groups = self.num_attention_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
).repeat_interleave(self.num_groups)
self.use_sdpa = False
def qkv_gemm(self, hidden_states):
raise NotImplementedError("Need to implement in specific model class")
def rope(self, query, key, **kwargs):
position_embeddings = kwargs.pop("position_embeddings", None)
RotaryEmbedding.apply_function(
query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True
)
return query, key
def postprocess_attention_output(self, attn_output):
if self.use_sdpa:
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
return attn_output
    # May be removed once torch 2.6 is the minimum supported version
def has_flash_attn(self):
if self.module_device.type == "cpu":
return is_torch_version(">", "2.4.99")
elif self.module_device.type == "xpu":
return is_torch_version(">", "2.5.99")
def attention_interface(
self,
query,
key_cache,
value_cache,
key,
value,
past_key_value,
attention_mask,
input_lens,
past_key_values_length,
seq_len_tensor,
query_len_tensor,
max_input_lens,
query_max_len,
):
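        # Dispatch order: (1) no cache -> plain SDPA over the padded batch, (2) the flash-attn varlen kernel
        # when the installed torch version provides it, (3) varlen_attention for the first (prefill) pass,
        # (4) paged single-query attention for decode steps.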
if past_key_value is None:
n_rep = query.shape[1] // key.shape[1]
attn_output = torch.nn.functional.scaled_dot_product_attention(
query.reshape(input_lens.shape[0], max_input_lens, -1, query.shape[-1]).transpose(1, 2),
key.reshape(input_lens.shape[0], max_input_lens, -1, key.shape[-1])
.transpose(1, 2)
.repeat_interleave(n_rep, 1),
value.reshape(input_lens.shape[0], max_input_lens, -1, value.shape[-1])
.transpose(1, 2)
.repeat_interleave(n_rep, 1),
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=True,
)
self.use_sdpa = True
elif self.has_flash_attn():
attn_output = torch.empty_like(query)
PagedAttention.flash_attn_varlen_func(
attn_output,
query.contiguous() if query.device.type == "xpu" else query,
key_cache,
value_cache,
query_len_tensor,
seq_len_tensor,
query_max_len,
max_input_lens,
1.0 / math.sqrt(self.head_dim),
True,
past_key_value.block_tables,
None,
)
elif past_key_values_length == 0:
# prefill, remove padding
attn_output = torch.empty_like(query)
varlen_attention(
query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key,
value.contiguous() if value.device.type == "xpu" else value,
attn_output,
seq_len_tensor,
seq_len_tensor,
max_input_lens,
max_input_lens,
0.0,
1.0 / math.sqrt(self.head_dim),
False,
True,
False,
None,
)
else:
# decode
attn_output = torch.empty_like(query)
PagedAttention.single_query_cached_kv_attention(
attn_output,
query,
key_cache,
value_cache,
self.kv_head_mapping,
1.0 / math.sqrt(self.head_dim),
past_key_value.block_tables,
input_lens,
past_key_value.block_size,
max_input_lens,
None,
)
return attn_output
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[IPEXPagedCache] = None,
output_attentions: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if past_key_value is None and kwargs.get("layer_past", None) is not None:
past_key_value = kwargs.pop("layer_past", None)
input_lens = kwargs.pop("input_lens", None)
seq_len_tensor = kwargs.pop("seq_len_tensor", None)
query_len_tensor = kwargs.pop("query_len_tensor", None)
max_input_lens = kwargs.pop("max_input_lens", 0)
query_max_len = kwargs.pop("query_max_len", 0)
past_key_values_length = kwargs.pop("past_key_values_length", 0)
query, key, value = self.qkv_gemm(hidden_states)
query, key = self.rope(query, key, **kwargs)
key_cache, value_cache = None, None
if past_key_value is not None:
key_cache, value_cache = past_key_value.update(key, value, self.layer_idx)
attn_output = self.attention_interface(
query,
key_cache,
value_cache,
key,
value,
past_key_value,
attention_mask,
input_lens,
past_key_values_length,
seq_len_tensor,
query_len_tensor,
max_input_lens,
query_max_len,
)
attn_output = self.postprocess_attention_output(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, past_key_value, attn_weights
class _IPEXLlamaAttention(_IPEXAttention):
def __init__(self, module, device, config) -> None:
super().__init__(module, device, config)
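        # For non-quantized models, fold the q/k/v projections into a single linear so QKV is computed with
        # one GEMM; the slice offsets below split the fused output back into query/key/value.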
if getattr(config, "quantization_config", None) is None:
concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias is not None]
use_bias = bias_list != []
self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
self.concat_qkv.weight = nn.Parameter(concat_weight)
if use_bias:
concat_bias = torch.concat(bias_list, 0).contiguous()
self.concat_qkv.bias = nn.Parameter(concat_bias)
self.q_slice = self.q_proj.weight.shape[0]
self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
if not config.compile and module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mha_linear_add = LinearAdd(module.o_proj)
def qkv_gemm(self, hidden_states):
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
if hasattr(self, "concat_qkv"):
qkv_out = self.concat_qkv(hidden_states)
query = qkv_out[:, : self.q_slice].view(hidden_shape)
key = qkv_out[:, self.q_slice : self.k_slice].view(hidden_shape)
value = qkv_out[:, self.k_slice :].view(hidden_shape)
else:
query = self.q_proj(hidden_states).view(hidden_shape)
key = self.k_proj(hidden_states).view(hidden_shape)
value = self.v_proj(hidden_states).view(hidden_shape)
return query, key, value
class _IPEXFalconAttention(_IPEXAttention):
def __init__(self, module, device, config):
self.num_key_value_heads = config.num_key_value_heads
super().__init__(module, device, config)
self.q_slice = self.head_dim * config.num_kv_heads
self.k_slice = self.q_slice + self.head_dim
self.v_slice = self.k_slice + self.head_dim
def qkv_gemm(self, hidden_states):
qkv_out = self.query_key_value(hidden_states)
if self.new_decoder_architecture:
qkv_out = qkv_out.view(
qkv_out.shape[0], -1, self.num_attention_heads // self.num_kv_heads + 2, self.head_dim
)
query = qkv_out[:, :, :-2, :].flatten(1, 2)
key = qkv_out[:, :, [-2], :].flatten(1, 2)
value = qkv_out[:, :, [-1], :].flatten(1, 2)
else:
query = qkv_out[:, : self.q_slice].view(-1, self.num_attention_heads, self.head_dim)
key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
return query, key, value
class _IPEXGPT2Attention(_IPEXAttention):
def __init__(self, module, device, config) -> None:
super().__init__(module, device, config)
_setattr_from_module(self, module)
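        # GPT-2's c_attn/c_proj are Conv1D modules with transposed weights; when the model is neither compiled
        # nor quantized they are rebuilt as nn.Linear below so the IPEX fused ops (e.g. LinearAdd) can wrap them.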
if not config.compile and getattr(config, "quantization_config", None) is None:
self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
self.c_attn_linear.bias = self.c_attn.bias
self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
self.c_proj_linear.bias = self.c_proj.bias
            if self.c_proj_linear.__class__.__name__ not in ["LinearAllreduce"]:
self.linear_add = LinearAdd(self.c_proj_linear)
def qkv_gemm(self, hidden_states):
if hasattr(self, "c_attn_linear"):
query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
query = query.view(-1, self.num_attention_heads, self.head_dim)
key = key.view(-1, self.num_attention_heads, self.head_dim)
value = value.view(-1, self.num_attention_heads, self.head_dim)
return query, key, value
def rope(self, query, key, *args, **kwargs):
return query, key
def postprocess_attention_output(self, attn_output):
if self.use_sdpa:
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
if not hasattr(self, "linear_add"):
attn_output = self.c_proj(attn_output)
return attn_output
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L186
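# When the model is neither compiled nor quantized, the gate/up projections plus the SiLU activation are fused
# into a single Linear2SiluMul op, and the down projection is fused with the residual add via LinearAdd
# (unless tensor-parallel LinearAllreduce is used).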
class _IPEXLlamaMLP(nn.Module):
def __init__(self, module, device, config) -> None:
super().__init__()
_setattr_from_module(self, module)
self.config = config
self.module_device = device
if not config.compile and getattr(config, "quantization_config", None) is None:
# LinearAllreduce cannot use fused op LinearAdd
if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mlp_linear_add = LinearAdd(module.down_proj)
if isinstance(self.act_fn, nn.SiLU):
self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
if hasattr(self, "linear_silu_mul"):
mlp_gate = self.linear_silu_mul(hidden_states)
if hasattr(self, "mlp_linear_add"):
hidden_states = self.mlp_linear_add(mlp_gate, residual)
else:
hidden_states = self.down_proj(mlp_gate)
hidden_states = residual + hidden_states
else:
hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
hidden_states = residual + hidden_states
return hidden_states
class _IPEXFalconMLP(nn.Module):
def __init__(self, module, device, config) -> None:
super().__init__()
_setattr_from_module(self, module)
self.config = config
self.module_device = device
if not config.compile and getattr(config, "quantization_config", None) is None:
# LinearAllreduce cannot use fused op LinearAdd
self.linear_gelu = LinearGelu(module.dense_h_to_4h)
if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:
self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
def forward(
self,
hidden_states: torch.Tensor,
attention_output: torch.Tensor = None,
residual: torch.Tensor = None,
**kwargs,
):
if hasattr(self, "linear_gelu"):
mlp_hidden_states = self.linear_gelu(hidden_states)
else:
mlp_hidden_states = self.act(self.dense_h_to_4h(hidden_states))
if hasattr(self, "linear_add_add"):
output = self.linear_add_add(mlp_hidden_states, attention_output, residual)
else:
mlp_output = self.dense_4h_to_h(mlp_hidden_states)
output = mlp_output + attention_output + residual
return output
class _IPEXGPT2MLP(nn.Module):
def __init__(self, module, device, config) -> None:
super().__init__()
_setattr_from_module(self, module)
self.config = config
self.module_device = device
if not config.compile and getattr(config, "quantization_config", None) is None:
self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
self.c_fc_linear.bias = self.c_fc.bias
self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
self.c_proj_linear.bias = self.c_proj.bias
if self.module_device.type == "cpu":
self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
            if self.c_proj_linear.__class__.__name__ not in ["LinearAllreduce"]:
self.linear_add = LinearAdd(self.c_proj_linear)
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
if hasattr(self, "linear_new_gelu"):
hidden_states = self.linear_new_gelu(hidden_states)
else:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
if not hasattr(self, "linear_add"):
hidden_states = self.c_proj(hidden_states)
return hidden_states
# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
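# Note that the attention output projection (o_proj, or the fused mha_linear_add) is applied here in the
# decoder layer rather than inside the attention module, so the residual add can be fused with the GEMM.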
class _IPEXLlamaDecoderLayer(nn.Module):
def __init__(self, module, device, config):
super().__init__()
_setattr_from_module(self, module)
self.self_attn = _IPEXLlamaAttention(module.self_attn, device, config)
self.mlp = _IPEXLlamaMLP(module.mlp, device, config)
if getattr(config, "quantization_config", None):
_remove_hooks_for_ipex(self, True)
def forward(self, hidden_states: torch.Tensor, **kwargs):
        # See the original model's forward for the full parameter documentation.
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, present, attn_weights = self.self_attn(hidden_states=hidden_states, **kwargs)
if hasattr(self.self_attn, "mha_linear_add"):
hidden_states = self.self_attn.mha_linear_add(hidden_states, residual)
else:
hidden_states = self.self_attn.o_proj(hidden_states)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states, residual, **kwargs)
outputs = (hidden_states,)
if kwargs.get("output_attentions", False):
outputs += (attn_weights,)
if kwargs.get("use_cache", False):
outputs += (present,)
return outputs
class _IPEXFalconDecoderLayer(nn.Module):
def __init__(self, module, device, config):
super().__init__()
_setattr_from_module(self, module)
self.self_attention = _IPEXFalconAttention(module.self_attention, device, config)
self.mlp = _IPEXFalconMLP(module.mlp, device, config)
if getattr(config, "quantization_config", None):
_remove_hooks_for_ipex(self, True)
def forward(self, hidden_states: torch.Tensor, **kwargs):
        # See the original model's forward for the full parameter documentation.
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
attn_output, present, attn_weights = self.self_attention(hidden_states, **kwargs)
attn_output = self.self_attention.dense(attn_output)
hidden_states = self.mlp(hidden_states, attn_output, residual)
outputs = (hidden_states,)
if kwargs.get("output_attentions", False):
outputs += (attn_weights,)
if kwargs.get("use_cache", False):
outputs += (present,)
return outputs
class _IPEXGPT2Block(nn.Module):
def __init__(self, module, device, config):
super().__init__()
_setattr_from_module(self, module)
self.attn = _IPEXGPT2Attention(module.attn, device, config)
self.mlp = _IPEXGPT2MLP(module.mlp, device, config)
if getattr(config, "quantization_config", None):
_remove_hooks_for_ipex(self, True)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
**kwargs,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
# residual connection
if hasattr(self.attn, "linear_add"):
hidden_states = self.attn.linear_add(attn_output, residual)
else:
hidden_states = attn_output + residual
if encoder_hidden_states is not None:
# add one self-attention block for cross-attention
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
"cross-attention layers by setting `config.add_cross_attention=True`"
)
residual = hidden_states
hidden_states = self.ln_cross_attn(hidden_states)
cross_attn_outputs = self.crossattention(
hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
**kwargs,
)
attn_output = cross_attn_outputs[0]
# residual connection
hidden_states = residual + attn_output
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
if hasattr(self.mlp, "linear_add"):
hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual)
else:
hidden_states = residual + feed_forward_hidden_states
if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]
return outputs # hidden_states, present, (attentions, cross_attentions)
# Currently the Llama decoder layer can be reused as-is for these models.
class _IPEXQwen2DecoderLayer(_IPEXLlamaDecoderLayer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class _IPEXMistralDecoderLayer(_IPEXLlamaDecoderLayer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524
class _IPEXIntermediate(nn.Module):
def __init__(self, module, device, config):
super().__init__()
_setattr_from_module(self, module)
self.module_device = device
if not config.compile and getattr(config, "quantization_config", None) is None:
self.linear_gelu = LinearGelu(module.dense)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if hasattr(self, "linear_gelu"):
hidden_states = self.linear_gelu(hidden_states)
else:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states