# optimum/habana/transformers/models/falcon/modeling_falcon.py

import contextlib
import math
import os
from typing import Optional, Tuple, Union

import torch


try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
    print("Not using HPU fused kernel for scaled_dot_product_attention")
    FusedSDPA = None

try:
    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
except ImportError:
    print("Not using HPU fused kernel for apply_rotary_pos_emb")
    FusedRoPE = None

try:
    from habana_frameworks.torch.hpu import sdp_kernel

    SDPContext = True
except ImportError:
    SDPContext = False

import habana_frameworks.torch.core as htcore
from torch import nn
from torch.nn import functional as F
from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
)
from transformers.models.falcon.configuration_falcon import FalconConfig
from transformers.models.falcon.modeling_falcon import (
    FalconAttention,
    FalconDecoderLayer,
    FalconForCausalLM,
    FalconMLP,
    FalconModel,
    apply_rotary_pos_emb,
    build_alibi_tensor,
)
from transformers.utils import logging

from ...modeling_attn_mask_utils import (
    GaudiAttentionMaskConverter,
    _gaudi_prepare_4d_causal_attention_mask,
)
from ...modeling_rope_utils import GaudiRotaryEmbedding
from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module


logger = logging.get_logger(__name__)


def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    """
    Copied from transformers.models.falcon.modeling_falcon/dropout_add
    https://github.com/huggingface/transformers/blob/b338a6c3b8eda29610d4d472cad8cd87cbfdaaed/src/transformers/models/falcon/modeling_falcon.py#L248
    """
    out = F.dropout(x, p=prob, training=training)
    if training:
        out = residual + out
        return out
    else:
        residual.add_(out)
        return residual


def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
    if q.device.type == "hpu" and FusedRoPE:
        return apply_customized_rope_module(q, k, cos, sin, position_ids, training)
    else:
        return apply_rotary_pos_emb(q, k, cos[position_ids], sin[position_ids])


def gaudi_falcon_linear_forward(self, input: torch.Tensor) -> torch.Tensor:
    hidden_states = F.linear(input, self.weight, bias=self.bias)
    return hidden_states


def repeat_kv(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    n_rep: int,
):
    """
    Copied from repeat_kv: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
    The only differences are:
    - Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them.
    - Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion.
    The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim)
    The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim)
    """
    batch, num_key_value_heads, kv_len, head_dim = key_states.shape
    if n_rep == 1 or num_key_value_heads == 1:
        return query_states, key_states, value_states, attention_mask

    new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim)
    key_states = key_states.reshape(new_kv_shape)
    value_states = value_states.reshape(new_kv_shape)

    batch, _, q_len, head_dim = query_states.shape
    new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim)
    query_states = query_states.reshape(new_q_shape)

    if attention_mask is not None:
        # Add groups dim and set to 1
        attention_mask = attention_mask.unsqueeze(1)

    return query_states, key_states, value_states, attention_mask

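
# --------------------------------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: demonstrates the shape contract of
# `repeat_kv` above. The batch/head/length sizes below are arbitrary example values chosen only to
# make the grouped-KV reshape visible; nothing here is called by the model code.
# --------------------------------------------------------------------------------------------------
def _repeat_kv_shape_example():
    batch, num_heads, num_kv_heads, q_len, kv_len, head_dim = 2, 8, 2, 4, 6, 16
    n_rep = num_heads // num_kv_heads  # 4 query heads share each KV head

    query = torch.randn(batch, num_heads, q_len, head_dim)
    key = torch.randn(batch, num_kv_heads, kv_len, head_dim)
    value = torch.randn(batch, num_kv_heads, kv_len, head_dim)
    mask = torch.zeros(batch, 1, q_len, kv_len)

    q, k, v, m = repeat_kv(query, key, value, mask, n_rep)
    # Query gains an explicit group dimension, KV keeps a singleton group dimension that
    # broadcasts inside the attention matmuls, and the mask gains a matching groups dim.
    assert q.shape == (batch, num_kv_heads, n_rep, q_len, head_dim)
    assert k.shape == (batch, num_kv_heads, 1, kv_len, head_dim)
    assert v.shape == (batch, num_kv_heads, 1, kv_len, head_dim)
    assert m.shape == (batch, 1, 1, q_len, kv_len)
    return q, k, v, m
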
# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
    def __init__(self, fusedSDPA):
        super().__init__()
        self._hpu_kernel_fsdpa = fusedSDPA

    def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode):
        return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode)


class Softmax(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, dim=None, invAttnHead=None):
        return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead)


# ScaledDotProductAttention is based on torch.nn.functional.scaled_dot_product_attention
class ScaledDotProductAttention(nn.Module):
    def __init__(self, config: FalconConfig):
        super().__init__()
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.bmm1 = Matmul()
        self.bmm2 = Matmul()
        self.softmax = Softmax()
        self.num_key_value_groups = config.num_attention_heads // config.num_kv_heads

    def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(self.head_dim)
        invAttnHead = torch.tensor(scale_factor, dtype=torch.float32).to("hpu")

        if is_causal:
            assert attn_mask is None
            attn_bias = torch.zeros(L, S, dtype=query.dtype)
            temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias.to(query.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))

        query, key, value, attn_mask = repeat_kv(query, key, value, attn_mask, self.num_key_value_groups)

        attn_weight = self.bmm1(query, key.transpose(-2, -1))
        attn_weight += attn_mask
        attn_weight = self.softmax(attn_weight, dim=-1, invAttnHead=invAttnHead)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        attn_output = self.bmm2(attn_weight, value)
        return attn_output

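
# --------------------------------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a plain-PyTorch reference for the math that
# `ScaledDotProductAttention` above implements (Q.K^T, additive mask, scaled softmax, weighted sum
# of V). The real class routes the matmuls through `Matmul` and the softmax through the HPU fp8
# softmax kernel (folding the 1/sqrt(head_dim) scale into `invAttnHead`); this CPU sketch only
# checks the numerics against torch.nn.functional.scaled_dot_product_attention for the
# non-causal, no-dropout case.
# --------------------------------------------------------------------------------------------------
def _reference_sdpa_example():
    batch, heads, q_len, kv_len, head_dim = 1, 4, 5, 5, 16
    query = torch.randn(batch, heads, q_len, head_dim)
    key = torch.randn(batch, heads, kv_len, head_dim)
    value = torch.randn(batch, heads, kv_len, head_dim)
    attn_mask = torch.zeros(batch, 1, q_len, kv_len)  # additive mask, all positions visible

    scale = 1.0 / math.sqrt(head_dim)
    attn_weight = query @ key.transpose(-2, -1) * scale + attn_mask
    attn_weight = F.softmax(attn_weight, dim=-1)
    manual_out = attn_weight @ value

    fused_out = F.scaled_dot_product_attention(query, key, value, attn_mask, 0.0)
    assert torch.allclose(manual_out, fused_out, atol=1e-5)
    return manual_out
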
class GaudiFalconAttention(FalconAttention):
    """
    Inherits from FalconAttention: https://github.com/huggingface/transformers/blob/838b87abe231fd70be5132088d0dee72a7bb8d62/src/transformers/models/falcon/modeling_falcon.py#L267
    The only differences are:
    - add new args token_idx and position_ids
    - replace F.scaled_dot_product_attention with Habana torch's version for BF16
    - use ScaledDotProductAttention for FP8 quantization
    - add new arg reuse_cache
    - add new args use_flash_attention
    - add new arg flash_attention_recompute
    - add new arg flash_attention_causal_mask

    Choice of SDPA:
    There are these variables: use_flash_attention and datatype (bf16/fp8).
    The datatype is determined by the presence of the QUANT_CONFIG env var, which indicates fp8.
    1. use_flash_attention, fp8: use ModuleFusedSDPA. Most optimal
    2. use_flash_attention, bf16: use FusedSDPA
    3. not use_flash_attention, fp8: use ScaledDotProductAttention, along with QUANT_CONFIG. This is the case before this PR
    4. not use_flash_attention, bf16: F.scaled_dot_product_attention. Slowest option
    """

    def __init__(self, config: FalconConfig, layer_idx=None):
        super().__init__(config, layer_idx)

        self.is_fp8 = os.getenv("QUANT_CONFIG", "") != ""
        # In the constructor we do not know which one we will need later in the forward, so creating both
        # TODO, Does this affect memory usage?
        if self.is_fp8:
            self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA)
            self.unfused_scaled_dot_product_attention = ScaledDotProductAttention(config)

        self.k_cache = KVCache()
        self.v_cache = KVCache()
        self.inp_seq_len = -1
        self.max_position_embeddings = config.max_position_embeddings
        self.rotary_emb = GaudiRotaryEmbedding(config=self.config)

    def _split_heads(
        self, fused_qkv: torch.Tensor, broadcast: Optional[bool] = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if self.new_decoder_architecture:
            batch, seq_len, _ = fused_qkv.shape

            if self.config.num_attention_heads != self.num_heads:  # When DS divides heads for TP
                num_heads = self.config.num_attention_heads
                num_kv_heads = self.config.num_kv_heads
            else:  # When DS not in use
                num_heads = self.num_heads
                num_kv_heads = self.num_kv_heads

            qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, self.head_dim)
            # query = qkv[:, :, :, :-2]
            # key = qkv[:, :, :, [-2]]
            # value = qkv[:, :, :, [-1]]
            d3 = qkv.shape[3] - 2
            query = torch.index_select(qkv, 3, index=torch.arange(d3, device=qkv.device))
            key = torch.index_select(qkv, 3, index=torch.tensor([d3], device=qkv.device))
            value = torch.index_select(qkv, 3, index=torch.tensor([d3 + 1], device=qkv.device))

            if broadcast:
                key = torch.broadcast_to(key, query.shape)
                value = torch.broadcast_to(value, query.shape)

            query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
            return query, key, value
        elif not self.multi_query:
            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
            # TODO : Need to be fixed to use index_select()
            return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
        else:
            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
            # return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
            d2 = fused_qkv.shape[2] - 2
            query = torch.index_select(fused_qkv, 2, index=torch.arange(d2, device=fused_qkv.device))
            key = torch.index_select(fused_qkv, 2, index=torch.tensor([d2], device=fused_qkv.device))
            value = torch.index_select(fused_qkv, 2, index=torch.tensor([d2 + 1], device=fused_qkv.device))
            return query, key, value

    def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
        if self.config.new_decoder_architecture:
            cache_shape = (batch_size, self.num_kv_heads, max_seq_len, self.head_dim)
        else:
            cache_shape = (batch_size, 1, max_seq_len, self.head_dim)
        device = self.query_key_value.weight.device
        dtype = self.config.torch_dtype
        self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape)
        self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape)

    def update_sincos_cache(self, seq_len):
        # Call rotary emb forward() to update cos/sin cache when inferring more than self.max_position_embeddings
        # This helps in avoiding creation of these caches during actual model forward pass and
        # reduce memory consumption and improve performance.
        if seq_len > self.max_position_embeddings:
            self.max_position_embeddings = seq_len
            self.rotary_emb._set_cos_sin_cache(
                seq_len, self.query_key_value.weight.device, self.query_key_value.weight.dtype
            )

    def pre_attn_forward(
        self,
        hidden_states: torch.Tensor,
        alibi: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        layer_past: Optional[Cache] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        token_idx: Optional[torch.Tensor] = None,
        reuse_cache: Optional[bool] = False,
        cache_idx: int = None,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
        flash_attention_causal_mask: Optional[bool] = False,
    ):
        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
        # 3 x [batch_size, seq_length, num_heads, head_dim]
        train_with_flash_attention = self.training and self._use_sdpa and not output_attentions and head_mask is None
        (query_layer, key_layer, value_layer) = self._split_heads(
            fused_qkv,
            not use_flash_attention
            and not self.is_fp8
            and not train_with_flash_attention
            and not (self.config.num_kv_heads == 8),
        )

        batch_size, query_length, _, _ = query_layer.shape

        query_layer = query_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim)
        key_layer = key_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim)
        value_layer = value_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim)

        kv_seq_len = key_layer.shape[-2]
        if layer_past is not None:
            if token_idx is not None:
                if reuse_cache:
                    kv_seq_len = layer_past[0][-2]
                else:
                    kv_seq_len = layer_past[0].shape[-2]
            else:
                kv_seq_len += layer_past[0].shape[-2]

        if alibi is None:
            cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
            query_layer, key_layer = apply_customized_rope(
                query_layer, key_layer, cos, sin, position_ids, self.training
            )

        if use_cache:
            if self.training:
                present = None
            else:
                if reuse_cache:
                    key_layer = self.k_cache(key_layer, -2, token_idx)
                    value_layer = self.v_cache(value_layer, -2, token_idx)
                    present = (self.k_cache.get_shape(), self.v_cache.get_shape())
                else:
                    if layer_past is None:
                        past_key = torch.zeros(
                            key_layer.shape,
                            dtype=self.query_key_value.weight.dtype,
                            device=self.query_key_value.weight.device,
                        )
                        past_value = torch.zeros(
                            key_layer.shape,
                            dtype=self.query_key_value.weight.dtype,
                            device=self.query_key_value.weight.device,
                        )
                        layer_past = [past_key, past_value]
                    key_layer = self.k_cache.update(
                        layer_past[0], key_layer, -2, token_idx, self.inp_seq_len
                    )  # k_layer bs*1, q_len, head_dim
                    value_layer = self.v_cache.update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len)
                    if token_idx is None:
                        layer_past = (key_layer, value_layer)
                    present = layer_past

                if cache_idx is not None and query_length == 1:
                    key_layer = key_layer[:, :, :cache_idx, :]
                    value_layer = value_layer[:, :, :cache_idx, :]
                    attention_mask = attention_mask[:, :, :, :cache_idx]
        else:
            present = None

        if self.training or present is None:
            kv_length = key_layer.shape[-2]
        else:
            kv_length = present[0][-2] if reuse_cache else present[0].shape[-2]

        if (not reuse_cache) and (token_idx is not None) and (cache_idx is not None) and (query_length == 1):
            # Return only past key value shapes and not the tensors during decode phase (q len is 1)
            # to avoid making past key values as persistent output tensors of HPU graphs.
            present = (present[0].shape, present[1].shape)

        if alibi is None:  # both train/inference
            if output_attentions:
                attention_scores = query_layer @ key_layer.transpose(-1, -2)
                attention_scores /= math.sqrt(self.head_dim)

                attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
                # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
                attn_output = attention_scores @ value_layer
            else:
                if use_flash_attention or train_with_flash_attention:
                    is_causal = self.is_causal and query_length > 1 and flash_attention_causal_mask
                    if self.is_fp8:
                        attn_mask = None if is_causal else attention_mask
                        flash_attention_fast_softmax = True  # TODO pass this along
                        softmax_mode = "fast" if flash_attention_fast_softmax else "None"
                        enable_recompute = self.is_fp8 if query_length == 1 else flash_attention_recompute
                        with sdp_kernel(enable_recompute=enable_recompute):
                            attn_output = self.fused_scaled_dot_product_attention(
                                query_layer, key_layer, value_layer, attn_mask, 0.0, is_causal, None, softmax_mode
                            )
                    else:
                        # TODO very similar to the fp8 case above, could be merged.
                        with (
                            sdp_kernel(enable_recompute=flash_attention_recompute)
                            if SDPContext
                            else contextlib.nullcontext()
                        ):
                            attn_output = FusedSDPA.apply(
                                query_layer,
                                key_layer,
                                value_layer,
                                attention_mask,
                                0.0,
                                # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
                                is_causal and attention_mask is None,
                            )
                else:
                    if self.is_fp8:
                        attn_output = self.unfused_scaled_dot_product_attention(
                            query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False
                        )
                    else:
                        if query_layer.shape != key_layer.shape:
                            query_layer, key_layer, value_layer, attention_mask = repeat_kv(
                                query_layer,
                                key_layer,
                                value_layer,
                                attention_mask,
                                self.config.num_attention_heads // self.config.num_kv_heads,
                            )
                        # Workaround until scaled_dot_product_attention supports broadcast.
                        if self.training is True and query_layer.shape != key_layer.shape:
                            key_layer = torch.broadcast_to(key_layer, query_layer.shape)
                            value_layer = torch.broadcast_to(value_layer, query_layer.shape)
                        attn_output = F.scaled_dot_product_attention(
                            query_layer,
                            key_layer,
                            value_layer,
                            attention_mask,
                            0.0,
                            # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
                            is_causal=self.is_causal and attention_mask is None and query_length > 1,
                        )
                # Performance improvement for HPU
                if self.training is True and htcore:
                    htcore.mark_step()
                attention_scores = None

            attn_output = attn_output.view(batch_size, -1, query_length, self.head_dim)
            attn_output = attn_output.permute(0, 2, 1, 3)
            attn_output = attn_output.reshape(batch_size, query_length, -1)

            attn_output = self.dense(attn_output)

            if output_attentions:
                return attn_output, present, attention_scores
            else:
                return attn_output, present, _

        else:
            if train_with_flash_attention:
                if FusedSDPA:
                    # TODO needs to be turned into a module for quantization
                    with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext():
                        attn_output = FusedSDPA.apply(
                            query_layer,
                            key_layer,
                            value_layer,
                            attention_mask,
                            self.attention_dropout.p if self.training else 0.0,
                            self.is_causal and attention_mask is None and query_length > 1,
                        )
                else:
                    attn_output = F.scaled_dot_product_attention(
                        query_layer,
                        key_layer,
                        value_layer,
                        attn_mask=attention_mask,
                        dropout_p=self.attention_dropout.p if self.training else 0.0,
                        is_causal=self.is_causal and attention_mask is None and query_length > 1,
                    )
                attn_output = attn_output.transpose(1, 2)
                attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)

                attn_output = self.dense(attn_output)
            else:
                matmul_result = query_layer @ key_layer.transpose(-1, -2)

                # change view to [batch_size, num_heads, q_length, kv_length]
                attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)

                # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
                input_dtype = attention_scores.dtype
                # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
                if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
                    attention_scores = attention_scores.to(torch.float32)

                attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
                attention_logits *= self.inv_norm_factor
                attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
                # [batch_size, num_heads, q_length, kv_length]
                attention_probs = self.attention_dropout(attention_probs)

                if head_mask is not None:
                    attention_probs = attention_probs * head_mask

                # change view [batch_size, num_heads, q_length, kv_length]
                attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)

                # matmul: [batch_size * num_heads, q_length, head_dim]
                attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1)

                # change view [batch_size, q_length, num_heads * head_dim]
                attn_output = self._merge_heads(attn_output)

                attn_output = self.dense(attn_output)

            if output_attentions:
                return attn_output, present, attention_probs
            else:
                return attn_output, present, _

    def attention_all_reduce(self, attn_output):
        if hasattr(self.dense, "all_reduce"):
            self.dense.all_reduce(attn_output)

    def post_attn_forward(self, attn_output):
        if hasattr(self.dense, "all_reduce"):
            return self.dense.post_all_reduce(attn_output)
        return attn_output

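
# --------------------------------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a simplified stand-in for the token_idx
# style cache update used in `GaudiFalconAttention.pre_attn_forward`. With static shapes on HPU,
# the key/value cache is pre-allocated to max_seq_len and each decode step is written in place at
# `token_idx` (shown here with index_copy_), instead of growing the cache by concatenation. The
# real `KVCache` helper has additional logic (allocation, reuse_cache, bucketing); this only shows
# the idea and is never called by the model code.
# --------------------------------------------------------------------------------------------------
def _static_kv_cache_update_example():
    batch, num_kv_heads, max_seq_len, head_dim = 1, 2, 8, 16
    cache = torch.zeros(batch, num_kv_heads, max_seq_len, head_dim)

    # Prefill: write the whole prompt at once.
    prompt_len = 3
    prompt_keys = torch.randn(batch, num_kv_heads, prompt_len, head_dim)
    cache[:, :, :prompt_len, :] = prompt_keys

    # Decode: write a single new token at position token_idx along the sequence dim (-2).
    token_idx = torch.tensor([prompt_len])
    new_key = torch.randn(batch, num_kv_heads, 1, head_dim)
    cache.index_copy_(-2, token_idx, new_key)

    assert torch.equal(cache[:, :, prompt_len : prompt_len + 1, :], new_key)
    return cache
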
class GaudiFalconMLP(FalconMLP):
    """
    Inherits from FalconMLP: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
    """

    def pre_mlp_forward(self, x):
        x = self.act(self.dense_h_to_4h(x))
        x = self.dense_4h_to_h(x)
        return x

    def mlp_all_reduce(self, x):
        if hasattr(self.dense_4h_to_h, "all_reduce"):
            self.dense_4h_to_h.all_reduce(x)

    def post_mlp_forward(self, x):
        if hasattr(self.dense_4h_to_h, "all_reduce"):
            return self.dense_4h_to_h.post_all_reduce(x)
        return x


class GaudiFalconDecoderLayer(FalconDecoderLayer):
    """
    Inherits from FalconDecoderLayer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
    The only differences are:
    - add new args token_idx and position_ids
    - add token_idx and position_ids into attention inputs
    - add new args reuse_cache
    - add new args use_flash_attention
    - add new arg flash_attention_recompute
    - add new arg flash_attention_causal_mask
    """

    def __init__(self, config: FalconConfig, layer_idx=None):
        super().__init__(config, layer_idx=layer_idx)
        self.self_attention = GaudiFalconAttention(config, layer_idx)

    def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
        self.self_attention.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)

    def update_sincos_cache(self, seq_len):
        self.self_attention.update_sincos_cache(seq_len)

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        layer_past: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        token_idx: Optional[torch.Tensor] = None,
        reuse_cache: Optional[bool] = False,
        cache_idx: int = None,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
        flash_attention_causal_mask: Optional[bool] = False,
        **kwargs,
    ):
        residual = hidden_states

        (
            hidden_states,
            present,
            attn_scores,
            attention_layernorm_out,
            mlp_layernorm_out,
        ) = self.pre_attn(  # layernorm + attention before AllReduce
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            position_ids=position_ids,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            token_idx=token_idx,
            reuse_cache=reuse_cache,
            cache_idx=cache_idx,
            use_flash_attention=use_flash_attention,
            flash_attention_recompute=flash_attention_recompute,
            flash_attention_causal_mask=flash_attention_causal_mask,
        )

        self.self_attention.attention_all_reduce(hidden_states)
        hidden_states = self.self_attention.post_attn_forward(hidden_states)

        attention_output = hidden_states

        if not self.config.new_decoder_architecture:
            if self.config.parallel_attn:
                mlp_layernorm_out = attention_layernorm_out
            else:
                residual = dropout_add(
                    attention_output, residual, self.config.attention_dropout, training=self.training
                )
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        if (
            self.config.new_decoder_architecture
            and self.config.parallel_attn
            and self.config.num_ln_in_parallel_attn == 1
        ):
            mlp_layernorm_out = attention_layernorm_out

        outputs = (present, attn_scores)

        hidden_states = self.mlp.pre_mlp_forward(mlp_layernorm_out)
        self.mlp.mlp_all_reduce(hidden_states)
        hidden_states = self.mlp.post_mlp_forward(hidden_states)

        if self.config.new_decoder_architecture or self.config.parallel_attn:
            hidden_states += attention_output

        output = dropout_add(hidden_states, residual, self.config.hidden_dropout, training=self.training)

        if use_cache:
            outputs = (output,) + outputs
        else:
            outputs = (output,) + outputs[1:]

        return outputs  # hidden_states, present, attentions

    def pre_attn(
        self,
        hidden_states: torch.Tensor,
        alibi: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        token_idx: Optional[torch.Tensor] = None,
        reuse_cache: Optional[bool] = False,
        cache_idx: int = None,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
        flash_attention_causal_mask: Optional[bool] = False,
    ):
        if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
            attention_layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            attention_layernorm_out = self.input_layernorm(hidden_states)
            mlp_layernorm_out = None

        # Self attention.
        attn_outputs, present, attn_scores = self.self_attention.pre_attn_forward(
            attention_layernorm_out,
            layer_past=layer_past,
            attention_mask=attention_mask,
            position_ids=position_ids,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            token_idx=token_idx,
            reuse_cache=reuse_cache,
            cache_idx=cache_idx,
            use_flash_attention=use_flash_attention,
            flash_attention_recompute=flash_attention_recompute,
            flash_attention_causal_mask=flash_attention_causal_mask,
        )

        return attn_outputs, present, attn_scores, attention_layernorm_out, mlp_layernorm_out

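
# --------------------------------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: the residual wiring that
# `GaudiFalconDecoderLayer.forward` above implements for the parallel-attention configurations
# (`parallel_attn` / new decoder architecture). Attention and MLP both read layernormed views of
# the same residual stream and their outputs are summed back into it, instead of running
# sequentially. Plain tensor ops stand in for the real attention and MLP sub-modules.
# --------------------------------------------------------------------------------------------------
def _parallel_attn_wiring_example():
    hidden_states = torch.randn(1, 4, 32)  # (batch, seq_len, hidden_size), arbitrary example sizes
    residual = hidden_states

    attention_layernorm_out = F.layer_norm(hidden_states, hidden_states.shape[-1:])
    mlp_layernorm_out = F.layer_norm(hidden_states, hidden_states.shape[-1:])

    attention_output = attention_layernorm_out  # stand-in for the self-attention branch
    mlp_output = mlp_layernorm_out  # stand-in for the MLP branch

    # Both branch outputs are added to the shared residual (dropout omitted in this sketch).
    output = residual + attention_output + mlp_output
    return output
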
class GaudiFalconModel(FalconModel):
    """
    Inherits from FalconModel: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
    The only differences are:
    - add new args token_idx and position_ids
    - add token_idx and position_ids into decoder inputs
    - add new arg reuse_cache
    - add new args use_flash_attention
    - add new arg flash_attention_recompute
    - add new arg flash_attention_causal_mask
    """

    def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
        for layer in self.h:
            layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)

    def update_sincos_cache(self, seq_len):
        for layer in self.h:
            layer.update_sincos_cache(seq_len)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        token_idx: Optional[torch.Tensor] = None,
        reuse_cache: Optional[bool] = False,
        cache_idx: int = None,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
        flash_attention_causal_mask: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # Compute alibi tensor: check build_alibi_tensor documentation
        past_key_values_length = 0
        if past_key_values[0] is not None and token_idx is None:
            if reuse_cache:
                past_key_values_length = past_key_values[0][0][-2]
            else:
                past_key_values_length = past_key_values[0][0].shape[-2]

        if self.use_alibi:
            mask = (
                torch.ones(
                    (batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long
                )
                if attention_mask is None
                else attention_mask
            )
            alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
        else:
            alibi = None
            if position_ids is None:
                device = input_ids.device if input_ids is not None else inputs_embeds.device
                position_ids = torch.arange(
                    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
                )
                position_ids = position_ids.unsqueeze(0)

        # TODO: Due to perf degradation, disable spda_attn_mask
        use_sdpa_attn_mask = False

        if self._use_sdpa and not output_attentions and use_sdpa_attn_mask:
            # output_attentions=True can not be supported when using SDPA, and we fall back on
            # the manual implementation that requires a 4D causal mask in all cases.
            if alibi is None:
                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                    attention_mask,
                    (batch_size, seq_length),
                    inputs_embeds,
                    past_key_values_length,
                )
            elif head_mask is None:
                alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])

                # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
                attention_mask = _gaudi_prepare_4d_causal_attention_mask(
                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
                )

                # We take care to integrate alibi bias in the attention_mask here.
                min_dtype = torch.finfo(alibi.dtype).min
                attention_mask = torch.masked_fill(
                    alibi / math.sqrt(self.config.hidden_size // self.num_heads),
                    attention_mask < -1,
                    min_dtype,
                )

                # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
                # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
                if seq_length > 1:
                    attention_mask = GaudiAttentionMaskConverter._unmask_unattended(
                        attention_mask, min_dtype=min_dtype
                    )
            else:
                # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
                attention_mask = _gaudi_prepare_4d_causal_attention_mask(
                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
                )
        else:
            # 4d mask is passed through the layers
            attention_mask = _gaudi_prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape batch_size x num_heads x N x N
        # head_mask has shape n_layer x batch x num_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        position_embeddings = None

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    alibi,
                    attention_mask,
                    position_ids,
                    head_mask[i],
                    layer_past,
                    use_cache,
                    output_attentions,
                    cache_position,
                    position_embeddings,
                    None,
                    use_flash_attention,
                    flash_attention_recompute,
                    flash_attention_causal_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=alibi,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    token_idx=token_idx,
                    reuse_cache=reuse_cache,
                    cache_idx=cache_idx,
                    use_flash_attention=use_flash_attention,
                    flash_attention_recompute=flash_attention_recompute,
                    flash_attention_causal_mask=flash_attention_causal_mask,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

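
# --------------------------------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: how the two bias terms consumed by the
# decoder layers are shaped. `build_alibi_tensor` (imported above from transformers) produces a
# per-head linear bias over key positions, and the 4D additive causal mask keeps future positions
# at a large negative value. The hand-rolled mask below is a simplified stand-in for
# `_gaudi_prepare_4d_causal_attention_mask`; all sizes are arbitrary example values.
# --------------------------------------------------------------------------------------------------
def _alibi_and_causal_mask_example():
    batch_size, num_heads, seq_length = 2, 4, 5
    attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)

    # Alibi bias: (batch_size * num_heads, 1, seq_length), later viewed as (batch, num_heads, 1, kv_len).
    alibi = build_alibi_tensor(attention_mask, num_heads, dtype=torch.float32)
    assert alibi.shape == (batch_size * num_heads, 1, seq_length)

    # Simplified 4D additive causal mask: 0 where attending is allowed, dtype-min where it is not.
    min_dtype = torch.finfo(torch.float32).min
    causal = torch.triu(torch.full((seq_length, seq_length), min_dtype), diagonal=1)
    causal_4d = causal[None, None, :, :].expand(batch_size, 1, seq_length, seq_length)
    return alibi, causal_4d
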
class GaudiFalconForCausalLM(FalconForCausalLM):
    """
    Inherits from FalconForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
    The only differences are:
    - add new args token_idx and position_ids
    - add token_idx and position_ids into model inputs
    - from step 2, when KV cache is enabled, slice next_input_ids from input_ids based on the token_idx
    - from step 2, when KV cache is enabled, slice next_position_ids from position_ids based on the token_idx
    - add new args reuse_cache
    - add use_flash_attention
    - add flash_attention_recompute
    - add flash_attention_causal_mask
    """

    def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
        self.transformer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
        self.kv_cache_len = max_seq_len

    def update_sincos_cache(self, seq_len):
        self.transformer.update_sincos_cache(seq_len)

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Union[Cache, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: bool = True,
        token_idx: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        reuse_cache = kwargs.get("reuse_cache")
        bucket_internal = kwargs.get("bucket_internal")

        if past_key_values is not None:
            if token_idx is not None:
                idx = token_idx + kwargs.get("inputs_embeds_offset", 0) - 1
                input_ids = torch.index_select(input_ids, 1, idx)
            else:
                past_length = past_key_values[0][0].shape[2]

                # Some generation methods already pass only the last input ID
                if input_ids.shape[1] > past_length:
                    remove_prefix_length = past_length
                else:
                    # Default to old behavior: keep only final ID
                    remove_prefix_length = input_ids.shape[1] - 1

                input_ids = input_ids[:, remove_prefix_length:]
        elif (reuse_cache or bucket_internal) and token_idx is not None:
            # KV cache is pre allocated with reuse cache or will be padded with bucket internal
            # hence for the 1st token we can slice the inputs till token idx for the fwd pass.
            input_ids = input_ids[:, :token_idx]
            attention_mask = attention_mask[:, :token_idx]

        # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
        if (
            not self.transformer.use_alibi
            and attention_mask is not None
            and position_ids is None
            and token_idx is not None
        ):
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                if token_idx is not None:
                    position_ids = torch.index_select(position_ids, 1, token_idx - 1)
                else:
                    position_ids = position_ids[:, -input_ids.shape[1] :]

                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "token_idx": token_idx,
                "reuse_cache": reuse_cache,
                "cache_idx": kwargs.get("cache_idx"),
                "use_flash_attention": kwargs.get("use_flash_attention"),
                "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
                "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        token_idx: Optional[torch.Tensor] = None,
        reuse_cache: Optional[bool] = False,
        trim_logits: Optional[bool] = False,
        cache_idx: int = None,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
        flash_attention_causal_mask: Optional[bool] = False,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all `input_ids`
            (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if use_flash_attention:
            assert FusedSDPA, (
                "`use_flash_attention` is True, but cannot find FusedSDPA. Please import it as `from habana_frameworks.torch.hpex.kernels import FusedSDPA` or set use_flash_attention to False (at the expense of a possible performance degradation)."
            )
        if flash_attention_recompute:
            assert use_flash_attention, "flash_attention_recompute is set, but use_flash_attention is not"
        if flash_attention_causal_mask:
            assert use_flash_attention, "flash_attention_causal_mask is set, but use_flash_attention is not"

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            token_idx=token_idx,
            reuse_cache=reuse_cache,
            cache_idx=cache_idx,
            use_flash_attention=use_flash_attention,
            flash_attention_recompute=flash_attention_recompute,
            flash_attention_causal_mask=flash_attention_causal_mask,
        )
        hidden_states = transformer_outputs[0]

        _, seq_len, _ = hidden_states.shape
        if seq_len > 1 and trim_logits and not self.training:
            if token_idx is not None:
                hidden_states = hidden_states.index_select(1, token_idx - 1)
            else:
                hidden_states = hidden_states[:, -1:, :]

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        lm_logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                lm_logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
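
# --------------------------------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: the token_idx convention that
# `prepare_inputs_for_generation` above relies on. On HPU, the attention mask and KV cache keep a
# fixed (padded) length so graph shapes stay static; each decode step only advances `token_idx`
# and index-selects the single input id (and position id) to feed. No model is instantiated here;
# the tensors are arbitrary example values.
# --------------------------------------------------------------------------------------------------
def _token_idx_decode_step_example():
    max_seq_len, prompt_len = 8, 3
    # Prompt ids right-padded to the static length.
    input_ids = torch.tensor([[11, 12, 13] + [0] * (max_seq_len - prompt_len)])
    attention_mask = torch.tensor([[1] * prompt_len + [0] * (max_seq_len - prompt_len)])
    token_idx = torch.tensor([prompt_len])  # next position to be written/generated

    # Decode step: feed only the token at position token_idx - 1, with a static-size mask.
    step_input_ids = torch.index_select(input_ids, 1, token_idx - 1)
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    step_position_ids = torch.index_select(position_ids, 1, token_idx - 1)

    assert step_input_ids.shape == (1, 1) and step_position_ids.shape == (1, 1)
    return step_input_ids, step_position_ids, token_idx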