optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py

# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Copyright (C) 2022-2024 Habana Labs, Ltd. an Intel Company
###############################################################################

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
    GPTBigCodeAttention,
    GPTBigCodeForCausalLM,
    upcast_masked_softmax,
    upcast_softmax,
)

from ...modeling_attn_mask_utils import GaudiAttentionMaskConverter


try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
    print("Not using HPU fused scaled dot-product attention kernel.")
    FusedSDPA = None

import habana_frameworks.torch.core as htcore


# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
    def __init__(self, fusedSDPA):
        super().__init__()
        self._hpu_kernel_fsdpa = fusedSDPA

    def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, enable_recompute):
        return self._hpu_kernel_fsdpa.apply(
            query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, enable_recompute
        )
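
# --------------------------------------------------------------------------------------------------------------
# Editorial note (illustrative only, not part of the original model code): the sketch below shows, with plain
# PyTorch, the semantics that ModuleFusedSDPA wraps on HPU. The extra FusedSDPA arguments `softmax_mode` and
# `enable_recompute` are Habana-specific and have no torch equivalent, so this reference ignores them; the
# `scale` keyword of F.scaled_dot_product_attention requires PyTorch >= 2.1. The helper name
# `_sdpa_reference_sketch` is hypothetical and is never called by the model.
def _sdpa_reference_sketch(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    # Plain scaled dot-product attention: the operation the HPU fused kernel accelerates.
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale
    )
# --------------------------------------------------------------------------------------------------------------
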

class GaudiGPTBigCodeAttention(GPTBigCodeAttention):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__(config, is_cross_attention, layer_idx)
        self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA is not None else None
        self.block_size = 4096

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """
        This method should be deleted when https://github.com/huggingface/transformers/pull/34508 is merged.

        Copied from GPTBigCodeAttention._attn: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
        The only differences are:
        - in self._attn, use torch.matmul instead of torch.baddbmm when the device used for query is not cpu
        """
        dtype = query.dtype
        softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
        upcast = dtype != softmax_dtype

        unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
        scale_factor = unscale**-1
        if self.scale_attn_weights:
            scale_factor /= self.head_dim**0.5

        # MQA models: (batch_size, query_length, num_heads * head_dim)
        # MHA models: (batch_size, num_heads, query_length, head_dim)
        query_shape = query.shape
        batch_size = query_shape[0]
        key_length = key.size(-1)
        if self.multi_query:
            # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
            # -> (batch_size, query_length, num_heads, key_length)
            query_length = query_shape[1]
            attn_shape = (batch_size, query_length, self.num_heads, key_length)
            attn_view = (batch_size, query_length * self.num_heads, key_length)
            # No copy needed for MQA 2, or when layer_past is provided.
            query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
        else:
            # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length)
            # -> (batch_size, num_heads, query_length, key_length)
            query_length = query_shape[2]
            attn_shape = (batch_size, self.num_heads, query_length, key_length)
            attn_view = (batch_size * self.num_heads, query_length, key_length)
            # Always copies
            query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
            # No copy when layer_past is provided.
            key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)

        attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
        if query.device.type == "cpu":
            # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588.
            # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086,
            # but the fix has not been released as of pytorch version 2.0.0.
            attn_weights = torch.zeros_like(attn_weights)
            attn_weights = torch.baddbmm(attn_weights, query, key, beta=1, alpha=scale_factor).view(attn_shape)
        else:
            # Formula for torch.baddbmm: out = beta * attn_weights + scale_factor * (query ⋅ key)
            # for beta = 0, it simplifies to: out = scale_factor * (query ⋅ key)
            attn_weights = (torch.matmul(query, key) * scale_factor).view(attn_shape)

        if upcast:
            # Use a fused kernel to prevent a large overhead from casting and scaling.
            # Sub-optimal when the key length is not a multiple of 8.
            if attention_mask is None:
                attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
            else:
                mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
                attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
        else:
            if attention_mask is not None:
                mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)

                # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
                attn_weights = torch.where(attention_mask, attn_weights, mask_value)

            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            if self.multi_query:
                head_mask = head_mask.transpose(1, 2)
            attn_weights = attn_weights * head_mask

        if self.multi_query:
            attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
        else:
            attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def gaudi_flash_attn_v1(
        self,
        query_layer,
        key_layer,
        value_layer,
        attention_mask,
        dropout_rate,
        is_causal,
        scale,
        softmax_mode,
        enable_recompute,
        q_block_size,
    ):
        """
        Gaudi version of Flash Attention V1 to support long sequences at the prompt phase.
        Causal mask is not supported in this optimization.
        """
        if is_causal:
            raise ValueError("Causal mask is not supported for long input sequences")

        q_len = query_layer.size(-2)
        q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
        q_padding = q_tiles * q_block_size - q_len
        query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0)
        if attention_mask is not None:
            attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0)

        row_o_list = []
        for i in range(q_tiles):
            s, e = i * q_block_size, (i + 1) * q_block_size
            row_q = query_layer[:, :, s:e, :]
            row_mask = attention_mask[:, :, s:e, :]
            attn_output_partial = self.fused_scaled_dot_product_attention(
                row_q, key_layer, value_layer, row_mask, dropout_rate, is_causal, scale, softmax_mode, enable_recompute
            )
            row_o_list.append(attn_output_partial)
        attn_output = torch.cat(row_o_list, dim=-2)

        if q_padding != 0:
            attn_output = attn_output[:, :, :-q_padding, :]

        return attn_output

    def apply_FusedSDPA(
        self,
        query,
        key,
        value,
        attention_mask=None,
        flash_attention_recompute=False,
        flash_attention_fast_softmax=False,
        flash_attention_causal_mask=False,
    ):
        """
        Copied from GPTBigCodeSdpaAttention._attn: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
        The only differences are:
        - replaced torch.nn.functional.scaled_dot_product_attention with Habana's FusedSDPA
        - removed the workaround that expanded the key and value tensors over the heads dimension; it also works but dramatically drops throughput
        - added args use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask to control parameters of FusedSDPA
        - added special-case handling for inputs longer than 8192 tokens with the function gaudi_flash_attn_v1
        """
        scale = None
        if not self.scale_attn_weights:
            scale = 1

        # MQA models: (batch_size, query_length, num_heads * head_dim)
        # MHA models: (batch_size, num_heads, query_length, head_dim)
        query_shape = query.shape
        batch_size = query_shape[0]

        if self.multi_query:
            query_length = query_shape[1]

            # SDPA requires the dimension [..., sequence_length, head_dim].
            query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)

            # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
            key = key.unsqueeze(1)
            value = value.unsqueeze(1)
        else:
            query_length = query_shape[-1]

        if attention_mask is not None:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()

        sdpa_result = None
        enable_recompute = flash_attention_recompute and query_length > 1

        if query_length > 1 and flash_attention_causal_mask:
            attention_mask = None
            use_causal_mask = True
        else:
            use_causal_mask = self.is_causal and attention_mask is None and query_length > 1

        if query_length > 8192:
            sdpa_result = self.gaudi_flash_attn_v1(
                query,
                key,
                value,
                attention_mask,
                self.attn_pdrop if self.training else 0.0,
                use_causal_mask,
                scale,
                "fast" if flash_attention_fast_softmax else "None",
                enable_recompute,
                self.block_size,
            )
            htcore.mark_step()
        else:
            sdpa_result = self.fused_scaled_dot_product_attention(
                query,
                key,
                value,
                attention_mask,
                self.attn_pdrop if self.training else 0.0,
                use_causal_mask,
                scale,
                "fast" if flash_attention_fast_softmax else "None",
                enable_recompute,
            )

        if self.multi_query:
            # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
            sdpa_result = sdpa_result.transpose(1, 2)

            # Reshape is kind of expensive here, as it does a memory copy,
            # but I did not manage to do without it (logits do not match when using view).
            # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
            sdpa_result = sdpa_result.reshape(query_shape)

        return sdpa_result, None

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        token_idx: Optional[torch.Tensor] = None,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
        flash_attention_fast_softmax: Optional[bool] = False,
        flash_attention_causal_mask: Optional[bool] = False,
        cache_idx: Optional[int] = None,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
    ]:
        """
        Copied from GPTBigCodeAttention.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
        The only differences are:
        - add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
        - optimize KV cache
        """
        if use_flash_attention:
            assert self.fused_scaled_dot_product_attention is not None, (
                "Can't load HPU fused scaled dot-product attention kernel. Please retry without flash attention"
            )
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn") or not self.is_cross_attention:
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key_value = self.c_attn(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        elif self.multi_query:
            query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
        else:
            # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
            # i.e., the memory layout is not the same as GPT2.
            # This makes the concatenation with past_key_value more efficient.
            query, key_value = (
                self.c_attn(hidden_states)
                .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
                .transpose(1, 2)
                .split((self.head_dim, 2 * self.head_dim), dim=3)
            )

        key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

        _, q_len, _ = hidden_states.size()
        bucket_internal_decode_stage = cache_idx is not None and q_len == 1
        if not bucket_internal_decode_stage:
            if layer_past is not None:
                past_key, past_value = layer_past.split((self.head_dim, self.head_dim), dim=-1)
                if token_idx is not None:
                    # Using the out-of-place version of index_add_() to ensure the intermediate tensors are not lost when HPU graphs are enabled.
                    key = past_key.index_add(1, token_idx - 1, key - torch.index_select(past_key, 1, token_idx - 1))
                    value = past_value.index_add(
                        1, token_idx - 1, value - torch.index_select(past_value, 1, token_idx - 1)
                    )
                else:
                    key = torch.cat((past_key, key), dim=-2)
                    value = torch.cat((past_value, value), dim=-2)
            present = torch.cat((key, value), dim=-1) if use_cache else None
        else:
            assert token_idx is not None, "Invalid parameters: token_idx is None at decode stage with bucket_internal"
            assert layer_past is not None, (
                "Invalid parameters: layer_past is None at decode stage with bucket_internal"
            )
            past_key, past_value = layer_past.split((self.head_dim, self.head_dim), dim=-1)
            key = past_key.index_copy_(1, token_idx - 1, key)
            value = past_value.index_copy_(1, token_idx - 1, value)
            present = layer_past

        if bucket_internal_decode_stage:
            key = key[:, :cache_idx, :]
            value = value[:, :cache_idx, :]
            attention_mask = attention_mask[:, :, :, :cache_idx]

        if not output_attentions and head_mask is None and use_flash_attention:
            # Difference with the original implementation: there is no need to transpose the key here,
            # as SDPA expects seq_length to be at index -2 for the key as well.
            attn_output, attn_weights = self.apply_FusedSDPA(
                query,
                key,
                value,
                attention_mask,
                flash_attention_recompute,
                flash_attention_fast_softmax,
                flash_attention_causal_mask,
            )
        else:
            attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)

        if not self.multi_query:
            attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        if bucket_internal_decode_stage:
            # Return only the past key/value shapes, and not the tensors, during the decode phase (q_len is 1)
            # to avoid making past key values persistent output tensors of HPU graphs.
            present = present.shape

        outputs = (attn_output, present)
        if output_attentions:
            if self.multi_query:
                # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
                attn_weights = attn_weights.transpose(1, 2)
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)
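
# --------------------------------------------------------------------------------------------------------------
# Editorial note (illustrative only, not part of the original model code): the helper below reproduces the
# query-tiling arithmetic used by GaudiGPTBigCodeAttention.gaudi_flash_attn_v1 above, i.e. how a long prompt is
# padded and split into q_block_size chunks before the fused kernel is called chunk by chunk. The name
# `_query_tiling_sketch` is hypothetical and the function is never called by the model.
def _query_tiling_sketch(q_len: int, q_block_size: int = 4096):
    # Number of tiles: exact division when possible, otherwise round up and pad the last tile.
    q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
    q_padding = q_tiles * q_block_size - q_len
    # E.g. q_len=10000, q_block_size=4096 -> q_tiles=3, q_padding=2288; the padded rows are cut away again
    # from the concatenated output, exactly as done at the end of gaudi_flash_attn_v1.
    return q_tiles, q_padding
# --------------------------------------------------------------------------------------------------------------
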

def gaudi_gpt_bigcode_block_forward(
    self,
    hidden_states: Optional[Tuple[torch.Tensor]],
    layer_past: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
    token_idx: Optional[torch.Tensor] = None,
    use_flash_attention: Optional[bool] = False,
    flash_attention_recompute: Optional[bool] = False,
    flash_attention_fast_softmax: Optional[bool] = False,
    flash_attention_causal_mask: Optional[bool] = False,
    cache_idx: Optional[int] = None,
    **kwargs,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Copied from GPTBigCodeBlock.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
    The only differences are:
    - add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
    """
    residual = hidden_states
    hidden_states = self.ln_1(hidden_states)
    attn_outputs = self.attn(
        hidden_states,
        layer_past=layer_past,
        attention_mask=attention_mask,
        head_mask=head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        token_idx=token_idx,
        use_flash_attention=use_flash_attention,
        flash_attention_recompute=flash_attention_recompute,
        flash_attention_fast_softmax=flash_attention_fast_softmax,
        flash_attention_causal_mask=flash_attention_causal_mask,
        cache_idx=cache_idx,
    )
    attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)

    outputs = attn_outputs[1:]
    # residual connection
    hidden_states = attn_output + residual

    if encoder_hidden_states is not None:
        # add one self-attention block for cross-attention
        if not hasattr(self, "crossattention"):
            raise ValueError(
                f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                "cross-attention layers by setting `config.add_cross_attention=True`"
            )
        residual = hidden_states
        hidden_states = self.ln_cross_attn(hidden_states)
        cross_attn_outputs = self.crossattention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
        )
        attn_output = cross_attn_outputs[0]
        # residual connection
        hidden_states = residual + attn_output
        outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

    residual = hidden_states
    hidden_states = self.ln_2(hidden_states)
    feed_forward_hidden_states = self.mlp(hidden_states)
    # residual connection
    hidden_states = residual + feed_forward_hidden_states

    if use_cache:
        outputs = (hidden_states,) + outputs
    else:
        outputs = (hidden_states,) + outputs[1:]

    return outputs  # hidden_states, present, (attentions, cross_attentions)
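
# --------------------------------------------------------------------------------------------------------------
# Editorial note (illustrative only, not part of the original model code): a minimal CPU sketch of the
# static-shape KV-cache update performed in GaudiGPTBigCodeAttention.forward when `token_idx` is given. The
# cache keeps its full, padded length and the new key/value row is written at position token_idx - 1 with an
# out-of-place index_add, which is what keeps intermediate tensors alive when HPU graphs are enabled. The
# helper name `_token_idx_cache_update_sketch` is hypothetical and is never called by the model.
def _token_idx_cache_update_sketch():
    batch, max_len, head_dim = 1, 8, 4
    past_key = torch.zeros(batch, max_len, head_dim)  # padded cache with a static shape
    new_key = torch.ones(batch, 1, head_dim)  # key of the token currently being decoded
    token_idx = torch.tensor([3])  # 1-based position of that token
    # Result: past_key with row (token_idx - 1) replaced by new_key, shape unchanged.
    updated = past_key.index_add(1, token_idx - 1, new_key - torch.index_select(past_key, 1, token_idx - 1))
    return updated
# --------------------------------------------------------------------------------------------------------------
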

def gaudi_gpt_bigcode_model_forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.Tensor]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    token_idx: Optional[torch.Tensor] = None,
    use_flash_attention: Optional[bool] = False,
    flash_attention_recompute: Optional[bool] = False,
    flash_attention_fast_softmax: Optional[bool] = False,
    flash_attention_causal_mask: Optional[bool] = False,
    cache_idx: Optional[int] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
    """
    Copied from GPTBigCodeModel.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
    The only differences are:
    - add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
    - if token_idx and past_key_values are passed, set self_attention_mask based on the static shape of past_key_values
    """
    # This flag is used to reshape tensors correctly for the attention kernel.
    self._use_sdpa = use_flash_attention

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        batch_size = input_ids.shape[0]
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
        batch_size = inputs_embeds.shape[0]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if batch_size <= 0:
        raise ValueError("batch_size has to be defined and > 0")

    device = input_ids.device if input_ids is not None else inputs_embeds.device

    if token_type_ids is not None:
        token_type_ids = token_type_ids.view(-1, input_shape[-1])

    if past_key_values is None:
        past_length = 0
        past_key_values = tuple([None] * len(self.h))
    else:
        past_length = past_key_values[0].size(-2)

    if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_length > 0:
            position_ids = position_ids[:, past_length : input_shape[-1] + past_length :]
    elif position_ids is None:
        position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0)

    # Self-attention mask.
    query_length = input_shape[-1]
    key_length = past_length + query_length
    if past_length > 0 and token_idx is not None:
        self_attention_mask = self.bias[None, past_length - 1 : past_length, :past_length]
    else:
        self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]

    if attention_mask is not None:
        self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to(
            dtype=torch.bool, device=self_attention_mask.device
        )

    # MQA models: (batch_size, query_length, n_heads, key_length)
    # MHA models: (batch_size, n_heads, query_length, key_length)
    self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)

    if self._use_sdpa and head_mask is None and not output_attentions:
        # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer.
        dtype = self.wte.weight.dtype
        min_dtype = torch.finfo(dtype).min
        self_attention_mask = torch.where(
            self_attention_mask,
            torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device),
            torch.full([], min_dtype, dtype=dtype, device=self_attention_mask.device),
        )

        # output_attentions=True can not be supported when using SDPA, and we fall back on
        # the manual implementation that requires a 4D causal mask in all cases.
        if self.multi_query:
            # gpt_bigcode using MQA has the bad taste to use a causal mask with shape
            # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose.
            self_attention_mask = self_attention_mask.transpose(1, 2)

        if query_length > 1 and attention_mask is not None:
            # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
            # produces nans if sequences are completely unattended in the attention mask.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            self_attention_mask = GaudiAttentionMaskConverter._unmask_unattended(
                self_attention_mask, min_dtype=min_dtype
            )

    attention_mask = self_attention_mask

    # If a 2D or 3D attention mask is provided for the cross-attention
    # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
    if self.config.add_cross_attention and encoder_hidden_states is not None and encoder_attention_mask is not None:
        if encoder_attention_mask.dim() == 2:
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
        assert encoder_attention_mask.dim() == 3
        encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1)
    else:
        encoder_attention_mask = None

    # Prepare head mask if needed
    # 1.0 in head_mask indicates we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # head_mask has shape n_layer x batch x n_heads x N x N
    head_mask = self.get_head_mask(head_mask, self.config.n_layer)

    if inputs_embeds is None:
        inputs_embeds = self.wte(input_ids)
    position_embeds = self.wpe(position_ids)
    hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)

    if token_type_ids is not None:
        token_type_embeds = self.wte(token_type_ids)
        hidden_states = hidden_states + token_type_embeds

    hidden_states = self.drop(hidden_states)

    output_shape = input_shape + (hidden_states.size(-1),)

    presents = [] if use_cache else None
    all_self_attentions = () if output_attentions else None
    all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
    all_hidden_states = () if output_hidden_states else None
    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if self.gradient_checkpointing and self.training:
            outputs = self._gradient_checkpointing_func(
                block.__call__,
                hidden_states,
                None,
                attention_mask,
                head_mask[i],
                encoder_hidden_states,
                encoder_attention_mask,
                use_cache,
                output_attentions,
                None,
            )
        else:
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
                token_idx=token_idx,
                use_flash_attention=use_flash_attention,
                flash_attention_recompute=flash_attention_recompute,
                flash_attention_fast_softmax=flash_attention_fast_softmax,
                flash_attention_causal_mask=flash_attention_causal_mask,
                cache_idx=cache_idx,
            )

        hidden_states = outputs[0]
        if use_cache:
            presents.append(outputs[1])

        if output_attentions:
            all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
            if self.config.add_cross_attention:
                all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

    hidden_states = self.ln_f(hidden_states)
    hidden_states = hidden_states.view(output_shape)

    # Add last hidden state
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
            if v is not None
        )

    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=presents,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
        cross_attentions=all_cross_attentions,
    )
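
# --------------------------------------------------------------------------------------------------------------
# Editorial note (illustrative only, not part of the original model code): the snippet below isolates the
# "create position_ids on the fly" trick used both in gaudi_gpt_bigcode_model_forward above and in
# prepare_inputs_for_generation below. A cumulative sum over the attention mask yields 0-based positions for
# real tokens, and padded positions are filled with 1 so they stay in range. The helper name
# `_position_ids_from_mask_sketch` is hypothetical and is never called by the model.
def _position_ids_from_mask_sketch(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: (batch_size, seq_len) with 1 for real tokens and 0 for left padding
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    # e.g. mask [[0, 0, 1, 1, 1]] -> position_ids [[1, 1, 0, 1, 2]]
    return position_ids
# --------------------------------------------------------------------------------------------------------------
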

class GaudiGPTBigCodeForCausalLM(GPTBigCodeForCausalLM):
    """
    Inherits from GPTBigCodeForCausalLM: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
    The only differences are:
    - add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
    - add token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask into model_inputs
    - when KV cache is enabled, slice next_input_ids from input_ids based on the token_idx
    - when KV cache is enabled, slice next_position_ids from position_ids based on the token_idx
    - support for internal bucketing
    """

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, token_idx=None, **kwargs
    ):
        token_type_ids = kwargs.get("token_type_ids", None)
        bucket_internal = kwargs.get("bucket_internal", False)
        # Omit tokens covered by past_key_values
        if past_key_values:
            if token_idx is not None:
                idx = token_idx + kwargs.get("inputs_embeds_offset", 0) - 1
                input_ids = torch.index_select(input_ids, 1, idx)
                if token_type_ids is not None:
                    token_type_ids = torch.index_select(token_type_ids, 1, token_idx - 1)
            else:
                if self.config.multi_query:
                    past_length = past_key_values[0].shape[1]
                else:
                    past_length = past_key_values[0].shape[2]

                # Some generation methods already pass only the last input ID
                if input_ids.shape[1] > past_length:
                    remove_prefix_length = past_length
                else:
                    # Default to old behavior: keep only final ID
                    remove_prefix_length = input_ids.shape[1] - 1

                input_ids = input_ids[:, remove_prefix_length:]
                if token_type_ids is not None:
                    token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if past_key_values is None and bucket_internal and token_idx is not None:
            # With internal bucketing, the KV cache will be padded, so for the first token we can slice the inputs
            # up to token_idx for the forward pass.
            input_ids = input_ids[:, :token_idx]
            attention_mask = attention_mask[:, :token_idx]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                if token_idx is not None:
                    position_ids = torch.index_select(position_ids, 1, token_idx - 1)
                else:
                    position_ids = position_ids[:, -input_ids.shape[1] :]
        else:
            position_ids = None

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
                "token_idx": token_idx,
                "use_flash_attention": kwargs.get("use_flash_attention", False),
                "flash_attention_recompute": kwargs.get("flash_attention_recompute", False),
                "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax", False),
                "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask", False),
                "cache_idx": kwargs.get("cache_idx", None),
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        token_idx: Optional[torch.Tensor] = None,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
        flash_attention_fast_softmax: Optional[bool] = False,
        flash_attention_causal_mask: Optional[bool] = False,
        cache_idx: Optional[int] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e.
            you can set `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All
            labels set to `-100` are ignored (masked); the loss is only computed for labels in
            `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            token_idx=token_idx,
            use_flash_attention=use_flash_attention,
            flash_attention_recompute=flash_attention_recompute,
            flash_attention_fast_softmax=flash_attention_fast_softmax,
            flash_attention_causal_mask=flash_attention_causal_mask,
            cache_idx=cache_idx,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                lm_logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )
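

# --------------------------------------------------------------------------------------------------------------
# Editorial note (illustrative only, not part of the original model code): a minimal sketch of how the
# Gaudi-specific keyword arguments defined in this file could be passed to a single forward call. It assumes a
# GaudiGPTBigCodeForCausalLM instance already placed on an HPU device with the Habana FusedSDPA kernel
# available; model/tokenizer loading and device setup are out of scope here, and the helper name
# `_forward_call_sketch` is hypothetical.
def _forward_call_sketch(model: GaudiGPTBigCodeForCausalLM, input_ids: torch.Tensor):
    attention_mask = torch.ones_like(input_ids)
    return model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        use_cache=True,
        use_flash_attention=True,  # route attention through Habana's FusedSDPA kernel
        flash_attention_recompute=False,
        flash_attention_fast_softmax=False,
        flash_attention_causal_mask=False,
    )
# --------------------------------------------------------------------------------------------------------------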