local_gemma/attention.py (109 lines of code) (raw):

import logging
from typing import Optional, Tuple

import torch
from torch import nn
from transformers import Cache
from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
from transformers.models.gemma2.modeling_gemma2 import (
    Gemma2RotaryEmbedding,
    apply_rotary_pos_emb,
    repeat_kv,
    GEMMA2_ATTENTION_CLASSES,
)

logger = logging.getLogger(__name__)


class Gemma2FusedAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified from the original implementation to
    include fused q/k/v projection.

    A single ``qkv_proj`` linear layer produces the query, key and value states in one matmul. Its output is split
    into ``[num_heads * head_dim, num_key_value_heads * head_dim, num_key_value_heads * head_dim]``. Checkpoints
    that store separate ``q_proj`` / ``k_proj`` / ``v_proj`` parameters are fused at load time by ``load_hook``.
    """

    def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            # NOTE: the stdlib `logging.Logger` has no `warning_once`; this call relies on transformers patching
            # that method onto `logging.Logger` when it is imported.
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads that share each key/value head (grouped-query attention).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        # Gemma2 scales logits by query_pre_attn_scalar ** -0.5 rather than head_dim ** -0.5.
        self.scaling = config.query_pre_attn_scalar**-0.5

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # fused attention proj: outputs q, k and v concatenated along the feature dimension (in that order)
        total_head_dim = (2 * self.num_key_value_heads + self.num_heads) * self.head_dim
        self.qkv_proj = nn.Linear(self.hidden_size, total_head_dim, bias=config.attention_bias)
        # conversion from un-fused to fused checkpoints
        self._register_load_state_dict_pre_hook(self.load_hook)

        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self.rotary_emb = Gemma2RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

        # Even-indexed layers use the configured sliding window; odd-indexed layers use none. Without a
        # `layer_idx` the parity is unknown, so fall back to no sliding window instead of failing on `None % 2`.
        if layer_idx is not None and layer_idx % 2 == 0:
            self.sliding_window = config.sliding_window
        else:
            self.sliding_window = None

    def load_hook(self, state_dict, prefix, *args):
        """
        Fuse separate q/k/v projection parameters into ``qkv_proj`` in ``state_dict`` (in place).

        Registered as a ``load_state_dict`` pre-hook. If ``<prefix>q_proj.weight`` is present, the q/k/v weights
        (and biases, when ``config.attention_bias`` is set) are popped and concatenated along dim 0 in q, k, v
        order. This matches the split order used in ``forward``. Extra hook arguments are ignored.

        Args:
            state_dict: The state dict being loaded; mutated in place.
            prefix: Key prefix of this module inside ``state_dict``.
        """
        if prefix + "q_proj.weight" not in state_dict:
            # Already fused (or nothing to convert).
            return

        param_kinds = ["weight"]
        if self.config.attention_bias:
            param_kinds.append("bias")

        for kind in param_kinds:
            q_param = state_dict.pop(prefix + "q_proj." + kind)
            k_param = state_dict.pop(prefix + "k_proj." + kind)
            v_param = state_dict.pop(prefix + "v_proj." + kind)
            state_dict[prefix + "qkv_proj." + kind] = torch.cat([q_param, k_param, v_param])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """
        Run eager attention using the fused q/k/v projection.

        Args:
            hidden_states: Input activations, unpacked as ``(bsz, q_len, hidden_size)``.
            attention_mask: Optional additive mask. It is sliced to the current key length on its last dimension
                and added to the attention logits.
            position_ids: Positions fed to the rotary embedding.
            past_key_value: Optional cache. When given, the new key/value states are written into it for
                ``self.layer_idx``, and the keys/values it returns are used for attention.
            output_attentions: When False, the returned attention weights are ``None``.
            use_cache: Accepted for interface compatibility; not read here.
            cache_position: Forwarded to the cache update (needed by the static cache).

        Returns:
            ``(attn_output, attn_weights, past_key_value)``. ``attn_output`` is the ``o_proj`` output of shape
            ``(bsz, q_len, hidden_size)``. ``attn_weights`` holds the post-softmax, post-dropout attention
            probabilities, or ``None`` unless ``output_attentions``. ``past_key_value`` is the cache passed in.

        Raises:
            ValueError: If the attention output does not have shape ``(bsz, num_heads, q_len, head_dim)``.
        """
        bsz, q_len, _ = hidden_states.size()

        # One matmul for q, k and v, split back in the same q/k/v order that `load_hook` concatenates.
        query_size = self.num_heads * self.head_dim
        key_value_size = self.num_key_value_heads * self.head_dim
        query_key_value_states = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = query_key_value_states.split(
            [query_size, key_value_size, key_value_size], dim=-1
        )

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "sliding_window": self.sliding_window,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand the key/value heads by num_key_value_groups so they line up with the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        # Soft-cap the logits: cap * tanh(logits / cap) keeps them within (-cap, cap).
        softcap = self.config.attn_logit_softcapping
        if softcap is not None:
            attn_weights = attn_weights / softcap
            attn_weights = torch.tanh(attn_weights)
            attn_weights = attn_weights * softcap

        if attention_mask is not None:
            # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32 for a numerically stable softmax, then cast back
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, heads * head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value