optimum/exporters/ipex/modeling_utils.py (1,107 lines of code) (raw):

# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import math from typing import List, Optional, Tuple, Union import torch from torch import nn from transformers.cache_utils import Cache from transformers.modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask_for_sdpa, ) from transformers.modeling_outputs import ( BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, ) from optimum.intel.utils.import_utils import is_ipex_version, is_torch_version from optimum.intel.utils.modeling_utils import _setattr_from_module from .cache_utils import IPEXPagedCache logger = logging.getLogger(__name__) _IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.6.0" _accelerate_added_attributes = ["to", "xpu"] if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): logger.warning( f"Please upgrade the IPEX version to at least {_IPEX_MINIMUM_VERSION_FOR_PATCHING} if you want to patch the model." ) else: import intel_extension_for_pytorch as ipex from intel_extension_for_pytorch.llm.functional import varlen_attention from intel_extension_for_pytorch.llm.modules import ( Linear2SiluMul, LinearAdd, LinearAddAdd, LinearGelu, LinearNewGelu, PagedAttention, RMSNorm, RotaryEmbedding, ) device_type = "xpu" if ipex._C._has_xpu() else "cpu" # Assign device type earlier to void recompile in ipex. PagedAttention.runtime_ops.device_type = device_type RMSNorm.runtime_ops.device_type = device_type RotaryEmbedding.runtime_ops.device_type = device_type # Adapted from https://github.com/huggingface/accelerate/blob/v1.2.1/src/accelerate/hooks.py#L183 def _remove_hooks_for_ipex(module, recurse): if hasattr(module, "_hf_hook"): module._hf_hook.detach_hook(module) delattr(module, "_hf_hook") if hasattr(module, "_old_forward"): # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 if "GraphModuleImpl" in str(type(module)): module.__class__.forward = module.__class__.forward.__get__(module) else: module.forward = module.__class__.forward.__get__(module) delattr(module, "_old_forward") # Remove accelerate added warning hooks from dispatch_model for attr in _accelerate_added_attributes: module.__dict__.pop(attr, None) if recurse: for child in module.children(): _remove_hooks_for_ipex(child, recurse) return module # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83 def _ipex_rms_layer_norm_forward(self, hidden_states): return RMSNorm.apply_function(hidden_states, self.weight, self.variance_epsilon) # Adapted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/falcon/modeling_falcon.py#L1161 # For passing kwargs, we can remove it when falcon model support passing kwargs to self.transformer. def _falcon_for_causal_lm_forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` logits_to_keep (`int` or `torch.Tensor`, *optional*): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. This is useful when using packed tensor format (single dimension for batch and sequence length). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, **kwargs, ) hidden_states = transformer_outputs[0] slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep lm_logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function( lm_logits, labels, vocab_size=self.config.vocab_size, **kwargs, ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return CausalLMOutputWithCrossAttentions( loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) # Adapted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/gpt2/modeling_gpt2.py#L1036 # For passing kwargs, we can remove it when gpt2 model support passing kwargs to self.transformer. def _gpt2_lm_head_model_forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs, ) hidden_states = transformer_outputs[0] # Set device for model parallelism if self.model_parallel: torch.cuda.set_device(self.transformer.first_device) hidden_states = hidden_states.to(self.lm_head.weight.device) lm_logits = self.lm_head(hidden_states) loss = None if labels is not None: # Flatten the tokens loss = self.loss_function( lm_logits, labels, vocab_size=self.config.vocab_size, **kwargs, ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return CausalLMOutputWithCrossAttentions( loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, cross_attentions=transformer_outputs.cross_attentions, ) # Adapted from https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llama/modeling_llama.py#L918 def _llama_model_forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if past_key_values is not None and not isinstance(past_key_values, IPEXPagedCache): raise ValueError("only support IPEXPagedCache input now") max_input_lens = self.config.max_input_lens past_key_values_length = max_input_lens - seq_length device = input_ids.device if input_ids is not None else inputs_embeds.device if position_ids is None: position_ids = torch.arange(past_key_values_length, max_input_lens, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) # embed positions hidden_states = inputs_embeds # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None position_embeddings = self.rotary_emb(hidden_states, position_ids) index = kwargs.pop("index", None) cos = position_embeddings[0] sin = position_embeddings[1] hidden_states_copy = hidden_states hidden_states = (hidden_states.view(-1, hidden_states.shape[-1])).index_select(0, index) cos = (cos.reshape(-1, cos.shape[-1])).index_select(0, index) sin = (sin.reshape(-1, sin.shape[-1])).index_select(0, index) position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1)) if past_key_values is None: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask=attention_mask, input_shape=(input_ids.shape[0], input_ids.shape[-1]), inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, past_key_values_length=past_key_values_length, max_input_lens=self.config.max_input_lens, query_max_len=seq_length, **kwargs, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None if hidden_states.shape[0] != batch_size * seq_length: (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states hidden_states = hidden_states_copy hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1]) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) # Adapted from https://github.com/huggingface/transformers/blob/v4.46.1/src/transformers/models/falcon/modeling_falcon.py#L945 def _falcon_model_forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) max_input_lens = self.config.max_input_lens batch_size, seq_length, _ = inputs_embeds.shape past_key_values_length = max_input_lens - seq_length device = input_ids.device if input_ids is not None else inputs_embeds.device if cache_position is None: cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device) if position_ids is None: position_ids = cache_position.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_attention_heads x N x N # head_mask has shape n_layer x batch x num_attention_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) index = kwargs.pop("index", None) cos = position_embeddings[0] sin = position_embeddings[1] hidden_states_copy = hidden_states hidden_states = (hidden_states.view(-1, hidden_states.shape[-1])).index_select(0, index) cos = (cos.reshape(-1, cos.shape[-1])).index_select(0, index) sin = (sin.reshape(-1, sin.shape[-1])).index_select(0, index) position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1)) if past_key_values is None: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask=attention_mask, input_shape=(input_ids.shape[0], input_ids.shape[-1]), inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) next_decoder_cache = None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for i, block in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) outputs = block( hidden_states, layer_past=past_key_values, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=None, cache_position=cache_position, position_embeddings=position_embeddings, past_key_values_length=past_key_values_length, max_input_lens=self.config.max_input_lens, query_max_len=seq_length, **kwargs, ) hidden_states = outputs[0] if use_cache is True: next_decoder_cache = outputs[1] if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) # Add last hidden state hidden_states = self.ln_f(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) next_cache = next_decoder_cache if use_cache else None if hidden_states.shape[0] != batch_size * seq_length: (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states hidden_states = hidden_states_copy hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1]) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, ) def _gpt2_model_forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") device = input_ids.device if input_ids is not None else inputs_embeds.device if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, input_shape[-1]) max_input_lens = self.config.max_input_lens seq_length = input_ids.shape[-1] past_key_values_length = max_input_lens - seq_length if position_ids is None: position_ids = torch.arange( past_key_values_length, input_shape[-1] + past_key_values_length, dtype=torch.long, device=device ) position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0) if inputs_embeds is None: inputs_embeds = self.wte(input_ids) batch_size, seq_length, _ = inputs_embeds.shape position_embeddings = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeddings encoder_attention_mask = None head_mask = self.get_head_mask(head_mask, self.config.n_layer) if token_type_ids is not None: token_type_embeds = self.wte(token_type_ids) hidden_states = hidden_states + token_type_embeds hidden_states = self.drop(hidden_states) index = kwargs.pop("index", None) hidden_states_copy = hidden_states hidden_states = (hidden_states.view(-1, hidden_states.shape[-1])).index_select(0, index) if past_key_values is None: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask=attention_mask, input_shape=(input_ids.shape[0], input_ids.shape[-1]), inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) presents = None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None for i, block in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) outputs = block( hidden_states, layer_past=past_key_values, attention_mask=attention_mask, head_mask=head_mask[i], encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, past_key_values_length=past_key_values_length, max_input_lens=self.config.max_input_lens, query_max_len=seq_length, **kwargs, ) hidden_states = outputs[0] if use_cache is True: presents = outputs[1] if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) if self.config.add_cross_attention: all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) hidden_states = self.ln_f(hidden_states) if hidden_states.shape[0] != batch_size * seq_length: (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states hidden_states = hidden_states_copy hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1]) # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple( v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, ) # Adapted from https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/qwen2/modeling_qwen2.py#L499 def _qwen2_model_forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.gradient_checkpointing and self.training and use_cache: logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.") use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_length = inputs_embeds.shape[:2] device = input_ids.device if input_ids is not None else inputs_embeds.device # avoid multi inputs kwargs.pop("max_input_lens", None) max_input_lens = self.config.max_input_lens past_key_values_length = max_input_lens - seq_length if cache_position is None: cache_position = torch.arange( past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) index = kwargs.pop("index", None) cos = position_embeddings[0] sin = position_embeddings[1] hidden_states_copy = hidden_states hidden_states = (hidden_states.view(-1, hidden_states.shape[-1])).index_select(0, index) cos = (cos.reshape(-1, cos.shape[-1])).index_select(0, index) sin = (sin.reshape(-1, sin.shape[-1])).index_select(0, index) position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1)) if past_key_values is None: attention_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, past_key_values_length=past_key_values_length, max_input_lens=max_input_lens, query_max_len=seq_length, **kwargs, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) if hidden_states.shape[0] != batch_size * seq_length: (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states hidden_states = hidden_states_copy hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1]) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) output = BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) return output if return_dict else output.to_tuple() # Adapted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral/modeling_mistral.py#L459 def _mistral_model_forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_length = inputs_embeds.shape[:2] device = input_ids.device if input_ids is not None else inputs_embeds.device # avoid multi inputs kwargs.pop("max_input_lens", None) max_input_lens = self.config.max_input_lens past_key_values_length = max_input_lens - seq_length if cache_position is None: cache_position = torch.arange( past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=device ) if position_ids is None: position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) index = kwargs.pop("index", None) cos = position_embeddings[0] sin = position_embeddings[1] hidden_states_copy = hidden_states hidden_states = (hidden_states.view(-1, hidden_states.shape[-1])).index_select(0, index) cos = (cos.reshape(-1, cos.shape[-1])).index_select(0, index) sin = (sin.reshape(-1, sin.shape[-1])).index_select(0, index) position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1)) # TODO: remove this WA after IPEX 2.7 if device.type == "xpu": cos = cos.reshape(-1, cos.shape[-1]) sin = sin.reshape(-1, sin.shape[-1]) position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1)) if past_key_values is None: attention_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, past_key_values_length=past_key_values_length, max_input_lens=max_input_lens, query_max_len=seq_length, **kwargs, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) if hidden_states.shape[0] != batch_size * seq_length: (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states hidden_states = hidden_states_copy hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1]) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) output = BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) return output if return_dict else output.to_tuple() class _IPEXAttention(nn.Module): def __init__(self, module, device, config) -> None: super().__init__() _setattr_from_module(self, module) self.config = config self.module_device = device self.num_key_value_heads = config.num_key_value_heads self.num_attention_heads = config.num_attention_heads self.num_groups = self.num_attention_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device ).repeat_interleave(self.num_groups) self.use_sdpa = False def qkv_gemm(self, hidden_states): raise NotImplementedError("Need to implement in specific model class") def rope(self, query, key, **kwargs): position_embeddings = kwargs.pop("position_embeddings", None) RotaryEmbedding.apply_function( query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True ) return query, key def postprocess_attention_output(self, attn_output): if self.use_sdpa: attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) return attn_output # Maybe removed after torch 2.6 released def has_flash_attn(self): if self.module_device.type == "cpu": return is_torch_version(">", "2.4.99") elif self.module_device.type == "xpu": return is_torch_version(">", "2.5.99") def attention_interface( self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_key_values_length, seq_len_tensor, query_len_tensor, max_input_lens, query_max_len, ): if past_key_value is None: n_rep = query.shape[1] // key.shape[1] attn_output = torch.nn.functional.scaled_dot_product_attention( query.reshape(input_lens.shape[0], max_input_lens, -1, query.shape[-1]).transpose(1, 2), key.reshape(input_lens.shape[0], max_input_lens, -1, key.shape[-1]) .transpose(1, 2) .repeat_interleave(n_rep, 1), value.reshape(input_lens.shape[0], max_input_lens, -1, value.shape[-1]) .transpose(1, 2) .repeat_interleave(n_rep, 1), attn_mask=attention_mask, dropout_p=0.0, is_causal=True, ) self.use_sdpa = True elif self.has_flash_attn(): attn_output = torch.empty_like(query) PagedAttention.flash_attn_varlen_func( attn_output, query.contiguous() if query.device.type == "xpu" else query, key_cache, value_cache, query_len_tensor, seq_len_tensor, query_max_len, max_input_lens, 1.0 / math.sqrt(self.head_dim), True, past_key_value.block_tables, None, ) elif past_key_values_length == 0: # prefill, remove padding attn_output = torch.empty_like(query) varlen_attention( query.contiguous() if query.device.type == "xpu" else query, key.contiguous() if key.device.type == "xpu" else key, value.contiguous() if value.device.type == "xpu" else value, attn_output, seq_len_tensor, seq_len_tensor, max_input_lens, max_input_lens, 0.0, 1.0 / math.sqrt(self.head_dim), False, True, False, None, ) else: # decode attn_output = torch.empty_like(query) PagedAttention.single_query_cached_kv_attention( attn_output, query, key_cache, value_cache, self.kv_head_mapping, 1.0 / math.sqrt(self.head_dim), past_key_value.block_tables, input_lens, past_key_value.block_size, max_input_lens, None, ) return attn_output def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[IPEXPagedCache] = None, output_attentions: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if past_key_value is None and kwargs.get("layer_past", None) is not None: past_key_value = kwargs.pop("layer_past", None) input_lens = kwargs.pop("input_lens", None) seq_len_tensor = kwargs.pop("seq_len_tensor", None) query_len_tensor = kwargs.pop("query_len_tensor", None) max_input_lens = kwargs.pop("max_input_lens", 0) query_max_len = kwargs.pop("query_max_len", 0) past_key_values_length = kwargs.pop("past_key_values_length", 0) query, key, value = self.qkv_gemm(hidden_states) query, key = self.rope(query, key, **kwargs) key_cache, value_cache = None, None if past_key_value is not None: key_cache, value_cache = past_key_value.update(key, value, self.layer_idx) attn_output = self.attention_interface( query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_key_values_length, seq_len_tensor, query_len_tensor, max_input_lens, query_max_len, ) attn_output = self.postprocess_attention_output(attn_output) if not output_attentions: attn_weights = None return attn_output, past_key_value, attn_weights class _IPEXLlamaAttention(_IPEXAttention): def __init__(self, module, device, config) -> None: super().__init__(module, device, config) if getattr(config, "quantization_config", None) is None: concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous() bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias is not None] use_bias = bias_list != [] self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias) self.concat_qkv.weight = nn.Parameter(concat_weight) if use_bias: concat_bias = torch.concat(bias_list, 0).contiguous() self.concat_qkv.bias = nn.Parameter(concat_bias) self.q_slice = self.q_proj.weight.shape[0] self.k_slice = self.q_slice + self.k_proj.weight.shape[0] self.v_slice = self.k_slice + self.v_proj.weight.shape[0] if not config.compile and module.o_proj.__class__.__name__ not in ["LinearAllreduce"]: self.mha_linear_add = LinearAdd(module.o_proj) def qkv_gemm(self, hidden_states): input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) if hasattr(self, "concat_qkv"): qkv_out = self.concat_qkv(hidden_states) query = qkv_out[:, : self.q_slice].view(hidden_shape) key = qkv_out[:, self.q_slice : self.k_slice].view(hidden_shape) value = qkv_out[:, self.k_slice :].view(hidden_shape) else: query = self.q_proj(hidden_states).view(hidden_shape) key = self.k_proj(hidden_states).view(hidden_shape) value = self.v_proj(hidden_states).view(hidden_shape) return query, key, value class _IPEXFalconAttention(_IPEXAttention): def __init__(self, module, device, config): self.num_key_value_heads = config.num_key_value_heads super().__init__(module, device, config) self.q_slice = self.head_dim * config.num_kv_heads self.k_slice = self.q_slice + self.head_dim self.v_slice = self.k_slice + self.head_dim def qkv_gemm(self, hidden_states): qkv_out = self.query_key_value(hidden_states) if self.new_decoder_architecture: qkv_out = qkv_out.view( qkv_out.shape[0], -1, self.num_attention_heads // self.num_kv_heads + 2, self.head_dim ) query = qkv_out[:, :, :-2, :].flatten(1, 2) key = qkv_out[:, :, [-2], :].flatten(1, 2) value = qkv_out[:, :, [-1], :].flatten(1, 2) else: query = qkv_out[:, : self.q_slice].view(-1, self.num_attention_heads, self.head_dim) key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim) value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim) return query, key, value class _IPEXGPT2Attention(_IPEXAttention): def __init__(self, module, device, config) -> None: super().__init__(module, device, config) _setattr_from_module(self, module) if not config.compile and getattr(config, "quantization_config", None) is None: self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1]) self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t()) self.c_attn_linear.bias = self.c_attn.bias self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1]) self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t()) self.c_proj_linear.bias = self.c_proj.bias if self.c_proj_linear not in ["LinearAllreduce"]: self.linear_add = LinearAdd(self.c_proj_linear) def qkv_gemm(self, hidden_states): if hasattr(self, "c_attn_linear"): query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1) else: query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1) query = query.view(-1, self.num_attention_heads, self.head_dim) key = key.view(-1, self.num_attention_heads, self.head_dim) value = value.view(-1, self.num_attention_heads, self.head_dim) return query, key, value def rope(self, query, key, *args, **kwargs): return query, key def postprocess_attention_output(self, attn_output): if self.use_sdpa: attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) if not hasattr(self, "linear_add"): attn_output = self.c_proj(attn_output) return attn_output # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L186 class _IPEXLlamaMLP(nn.Module): def __init__(self, module, device, config) -> None: super().__init__() _setattr_from_module(self, module) self.config = config self.module_device = device if not config.compile and getattr(config, "quantization_config", None) is None: # LinearAllreduce cannot use fused op LinearAdd if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]: self.mlp_linear_add = LinearAdd(module.down_proj) if isinstance(self.act_fn, nn.SiLU): self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj) def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs): if hasattr(self, "linear_silu_mul"): mlp_gate = self.linear_silu_mul(hidden_states) if hasattr(self, "mlp_linear_add"): hidden_states = self.mlp_linear_add(mlp_gate, residual) else: hidden_states = self.down_proj(mlp_gate) hidden_states = residual + hidden_states else: hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) hidden_states = residual + hidden_states return hidden_states class _IPEXFalconMLP(nn.Module): def __init__(self, module, device, config) -> None: super().__init__() _setattr_from_module(self, module) self.config = config self.module_device = device if not config.compile and getattr(config, "quantization_config", None) is None: # LinearAllreduce cannot use fused op LinearAdd self.linear_gelu = LinearGelu(module.dense_h_to_4h) if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]: self.linear_add_add = LinearAddAdd(module.dense_4h_to_h) def forward( self, hidden_states: torch.Tensor, attention_output: torch.Tensor = None, residual: torch.Tensor = None, **kwargs, ): if hasattr(self, "linear_gelu"): mlp_hidden_states = self.linear_gelu(hidden_states) else: mlp_hidden_states = self.act(self.dense_h_to_4h(hidden_states)) if hasattr(self, "linear_add_add"): output = self.linear_add_add(mlp_hidden_states, attention_output, residual) else: mlp_output = self.dense_4h_to_h(mlp_hidden_states) output = mlp_output + attention_output + residual return output class _IPEXGPT2MLP(nn.Module): def __init__(self, module, device, config) -> None: super().__init__() _setattr_from_module(self, module) self.config = config self.module_device = device if not config.compile and getattr(config, "quantization_config", None) is None: self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1]) self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t()) self.c_fc_linear.bias = self.c_fc.bias self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1]) self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t()) self.c_proj_linear.bias = self.c_proj.bias if self.module_device.type == "cpu": self.linear_new_gelu = LinearNewGelu(self.c_fc_linear) if self.c_proj_linear not in ["LinearAllreduce"]: self.linear_add = LinearAdd(self.c_proj_linear) def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: if hasattr(self, "linear_new_gelu"): hidden_states = self.linear_new_gelu(hidden_states) else: hidden_states = self.c_fc(hidden_states) hidden_states = self.act(hidden_states) if not hasattr(self, "linear_add"): hidden_states = self.c_proj(hidden_states) return hidden_states # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 class _IPEXLlamaDecoderLayer(nn.Module): def __init__(self, module, device, config): super().__init__() _setattr_from_module(self, module) self.self_attn = _IPEXLlamaAttention(module.self_attn, device, config) self.mlp = _IPEXLlamaMLP(module.mlp, device, config) if getattr(config, "quantization_config", None): _remove_hooks_for_ipex(self, True) def forward(self, hidden_states: torch.Tensor, **kwargs): # Please see the original model's forward to check the parameter residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, present, attn_weights = self.self_attn(hidden_states=hidden_states, **kwargs) if hasattr(self.self_attn, "mha_linear_add"): hidden_states = self.self_attn.mha_linear_add(hidden_states, residual) else: hidden_states = self.self_attn.o_proj(hidden_states) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states, residual, **kwargs) outputs = (hidden_states,) if kwargs.get("output_attentions", False): outputs += (attn_weights,) if kwargs.get("use_cache", False): outputs += (present,) return outputs class _IPEXFalconDecoderLayer(nn.Module): def __init__(self, module, device, config): super().__init__() _setattr_from_module(self, module) self.self_attention = _IPEXFalconAttention(module.self_attention, device, config) self.mlp = _IPEXFalconMLP(module.mlp, device, config) if getattr(config, "quantization_config", None): _remove_hooks_for_ipex(self, True) def forward(self, hidden_states: torch.Tensor, **kwargs): # Please see the original model's forward to check the parameter residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention attn_output, present, attn_weights = self.self_attention(hidden_states, **kwargs) attn_output = self.self_attention.dense(attn_output) hidden_states = self.mlp(hidden_states, attn_output, residual) outputs = (hidden_states,) if kwargs.get("output_attentions", False): outputs += (attn_weights,) if kwargs.get("use_cache", False): outputs += (present,) return outputs class _IPEXGPT2Block(nn.Module): def __init__(self, module, device, config): super().__init__() _setattr_from_module(self, module) self.attn = _IPEXGPT2Attention(module.attn, device, config) self.mlp = _IPEXGPT2MLP(module.mlp, device, config) if getattr(config, "quantization_config", None): _remove_hooks_for_ipex(self, True) def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], layer_past: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, **kwargs, ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_outputs = self.attn( hidden_states, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, **kwargs, ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] # residual connection if hasattr(self.attn, "linear_add"): hidden_states = self.attn.linear_add(attn_output, residual) else: hidden_states = attn_output + residual if encoder_hidden_states is not None: # add one self-attention block for cross-attention if not hasattr(self, "crossattention"): raise ValueError( f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " "cross-attention layers by setting `config.add_cross_attention=True`" ) residual = hidden_states hidden_states = self.ln_cross_attn(hidden_states) cross_attn_outputs = self.crossattention( hidden_states, attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, **kwargs, ) attn_output = cross_attn_outputs[0] # residual connection hidden_states = residual + attn_output outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights residual = hidden_states hidden_states = self.ln_2(hidden_states) feed_forward_hidden_states = self.mlp(hidden_states) # residual connection if hasattr(self.mlp, "linear_add"): hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual) else: hidden_states = residual + feed_forward_hidden_states if use_cache: outputs = (hidden_states,) + outputs else: outputs = (hidden_states,) + outputs[1:] return outputs # hidden_states, present, (attentions, cross_attentions) # Currently can just apply llama decoder layer. class _IPEXQwen2DecoderLayer(_IPEXLlamaDecoderLayer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) class _IPEXMistralDecoderLayer(_IPEXLlamaDecoderLayer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524 class _IPEXIntermediate(nn.Module): def __init__(self, module, device, config): super().__init__() _setattr_from_module(self, module) self.module_device = device if not config.compile and getattr(config, "quantization_config", None) is None: self.linear_gelu = LinearGelu(module.dense) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if hasattr(self, "linear_gelu"): hidden_states = self.linear_gelu(hidden_states) else: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states