# models/language_model.py

import torch
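
# `rotate_half` is assumed to be defined in this module (the docstring below
# refers to it). A minimal sketch following the standard implementation: split
# the last dimension into halves (x1, x2) and return their rotation (-x2, x1).
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]  # first half of the feature dimension
    x2 = x[..., x.shape[-1] // 2 :]  # second half
    return torch.cat((-x2, x1), dim=-1)
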
def apply_rotary_pos_embd(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    unsqueeze_dim: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Applies rotary positional embeddings to query and key tensors in attention mechanisms.
Rotary positional embeddings inject position-dependent rotations into query and key vectors,
enabling transformers to encode positional information effectively without explicit positional encoding.
Args:
q (torch.Tensor): Query tensor with shape [batch_size, num_heads, seq_len, head_dim].
k (torch.Tensor): Key tensor with shape [batch_size, num_heads, seq_len, head_dim].
cos (torch.Tensor): Precomputed cosine positional embeddings with shape [batch_size, seq_len, head_dim].
sin (torch.Tensor): Precomputed sine positional embeddings with shape [batch_size, seq_len, head_dim].
unsqueeze_dim (int, optional): Dimension index to unsqueeze `cos` and `sin` to enable broadcasting.
Defaults to 1 (typically the heads dimension).
Returns:
tuple[torch.Tensor, torch.Tensor]: The rotated query and key tensors (`q_embed`, `k_embed`),
each with the same shape as the input tensors.
How it works:
- `cos` and `sin` tensors are unsqueezed at `unsqueeze_dim` to broadcast across attention heads.
- Rotary embeddings apply a complex number rotation in the embedding space using:
rotated = (original * cos) + (rotate_half(original) * sin)
- `rotate_half` performs a specific half-dimension rotation on the input tensor.
- This operation encodes relative position information in q and k without adding explicit positional vectors.
Example:
q_embed, k_embed = apply_rotary_pos_embd(q, k, cos, sin)
"""
    # Unsqueeze cos and sin at the heads dimension so they can be
    # broadcast to the shape of q and k.
    cos = cos.unsqueeze(unsqueeze_dim)  # [batch_size, 1, seq_len, head_dim]
    sin = sin.unsqueeze(unsqueeze_dim)  # [batch_size, 1, seq_len, head_dim]
    # Apply the rotation (the real-valued form of complex multiplication):
    # rotated = (x * cos) + (rotate_half(x) * sin)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
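

# Usage sketch with hypothetical shapes, following the docstring. In the real
# model, cos/sin would come from the rotary embedding module; here they are
# built from the standard RoPE inverse frequencies purely for illustration.
if __name__ == "__main__":
    batch_size, num_heads, seq_len, head_dim = 2, 8, 16, 64
    q = torch.randn(batch_size, num_heads, seq_len, head_dim)
    k = torch.randn(batch_size, num_heads, seq_len, head_dim)
    # Per-position rotation angles: outer product of positions and frequencies.
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # [seq_len, head_dim/2]
    emb = torch.cat((angles, angles), dim=-1)                      # [seq_len, head_dim]
    cos = emb.cos().unsqueeze(0).expand(batch_size, -1, -1)  # [batch_size, seq_len, head_dim]
    sin = emb.sin().unsqueeze(0).expand(batch_size, -1, -1)
    q_embed, k_embed = apply_rotary_pos_embd(q, k, cos, sin)
    assert q_embed.shape == q.shape and k_embed.shape == k.shape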