in models/language_model.py [0:0]
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, attention_mask=None, block_kv_cache=None) -> tuple[torch.Tensor, dict]:
"""
Forward pass for grouped query attention.
Args:
x (Tensor): Input tensor of shape (B, T_curr, C), where
B = batch size,
T_curr = current sequence length,
C = embedding dimension.
cos (Tensor): Rotary embedding cosines, shape compatible with q and k.
sin (Tensor): Rotary embedding sines, shape compatible with q and k.
attention_mask (Tensor, optional): Attention mask tensor of shape (B, total_kv_length),
with 1 for tokens to attend to and 0 for padding.
block_kv_cache (dict, optional): Cache dict with 'key' and 'value' tensors for autoregressive decoding.
Returns:
tuple[Tensor, dict]:
- Output tensor after attention and projection, shape (B, T_curr, C).
- Updated block_kv_cache dict for caching key-value states.
"""
    is_prefill = block_kv_cache is None
    B, T_curr, C = x.size()  # T_curr is the sequence length of the current input x

    q_curr = self.q_proj(x).view(B, T_curr, self.n_heads, self.head_dim).transpose(1, 2)     # (B, n_heads, T_curr, head_dim)
    k_curr = self.k_proj(x).view(B, T_curr, self.n_kv_heads, self.head_dim).transpose(1, 2)  # (B, n_kv_heads, T_curr, head_dim)
    v_curr = self.v_proj(x).view(B, T_curr, self.n_kv_heads, self.head_dim).transpose(1, 2)  # (B, n_kv_heads, T_curr, head_dim)

    # Apply rotary embeddings to the current q and k
    q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)
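    # Keys are cached below only after RoPE has been applied, so cached positions are never
    # rotated twice; cos/sin only need to cover the T_curr positions of the current input.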
    # Check if we can use cached keys and values
    if not is_prefill and block_kv_cache['key'] is not None:
        # Concatenate with cached K, V; k_rotated and v_curr are for the new token(s)
        k = block_kv_cache['key']
        v = block_kv_cache['value']
        k = torch.cat([k, k_rotated], dim=2)
        v = torch.cat([v, v_curr], dim=2)
        block_kv_cache['key'] = k
        block_kv_cache['value'] = v
    else:
        # Prefill, or an empty cache: start the cache from the current K, V
        k = k_rotated
        v = v_curr
        block_kv_cache = {'key': k, 'value': v}
    # Repeat K, V for Grouped Query Attention
    k_exp = k.repeat_interleave(self.n_kv_groups, dim=1)  # (B, n_heads, T_kv, head_dim)
    v_exp = v.repeat_interleave(self.n_kv_groups, dim=1)  # (B, n_heads, T_kv, head_dim)
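    # Each K/V head is shared by a group of query heads (n_heads = n_kv_heads * n_kv_groups),
    # so after repeat_interleave the K/V tensors line up with the n_heads query heads.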
    T_kv = k_exp.size(2)  # Total sequence length of keys/values
    # Prepare the padding mask for both the SDPA and manual paths.
    # attention_mask is (B, total_kv_length), with 1 for positions to attend to and 0 for padding.
    additive_attn_mask = None
    if attention_mask is not None:
        # Turn the (B, T_kv) padding mask into an additive mask of shape (B, 1, 1, T_kv):
        # 1 -> 0.0 (attend), 0 -> most negative value of q's dtype (masked out).
        mask_for_keys = attention_mask[:, :T_kv]  # Align the mask with the key length: (B, T_kv)
        additive_attn_mask = (1.0 - mask_for_keys.unsqueeze(1).unsqueeze(2).float()) * torch.finfo(q.dtype).min
    if self.sdpa and x.device.type != 'mps':
        # Causal masking is only needed during prefill; during decode the single query row
        # may attend to every cached position, so [1, T_kv] needs no extra mask.
        is_causal = (T_curr == T_kv and T_curr > 1)
        if is_causal and additive_attn_mask is not None:
            # SDPA does not accept attn_mask together with is_causal=True, so fold the
            # causal structure into the additive padding mask instead.
            causal = torch.tril(torch.ones(T_curr, T_kv, device=x.device, dtype=torch.bool))
            additive_attn_mask = additive_attn_mask.expand(B, 1, T_curr, T_kv).masked_fill(
                ~causal.view(1, 1, T_curr, T_kv), torch.finfo(q.dtype).min)
            is_causal = False
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k_exp, v_exp,
            attn_mask=additive_attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal
        )
    else:
        # Manual attention implementation
        attn = torch.matmul(q, k_exp.transpose(2, 3)) / math.sqrt(self.head_dim)  # (B, n_heads, T_curr, T_kv)
        # During decode no additional masking is needed: the single [1, T_kv] query row is naturally causal
        if T_curr == T_kv and T_curr > 1:
            causal_mask_val = torch.tril(torch.ones(T_curr, T_curr, device=x.device, dtype=torch.bool)).view(1, 1, T_curr, T_curr)
            attn = attn.masked_fill(~causal_mask_val, float('-inf'))
        if additive_attn_mask is not None:
            # Additive padding mask: (B, 1, 1, T_kv) broadcasts to (B, n_heads, T_curr, T_kv)
            attn = attn + additive_attn_mask
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)
        y = attn @ v_exp
    # Merge the heads back together: (B, n_heads, T_curr, head_dim) -> (B, T_curr, C)
    y = y.transpose(1, 2).contiguous().view(B, T_curr, C)
    y = self.out_proj(y)
    y = self.resid_dropout(y)
    return y, block_kv_cache
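
# Illustrative usage sketch (not code from this file): the returned cache dict is threaded
# through successive forward calls during generation. The names `attn`, `cos`, `sin`, `mask`
# below are assumptions for the example only, not identifiers from this repository.
#
#   y, kv = attn(x_prompt, cos, sin, attention_mask=mask, block_kv_cache=None)   # prefill
#   y_new, kv = attn(x_next_token, cos_next, sin_next,
#                    attention_mask=mask_next, block_kv_cache=kv)                # one decode step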