models/language_model.py
def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute rotary positional embeddings (cosine and sine components).
Args:
position_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) containing position indices.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of two tensors (cos, sin), each of shape
(batch_size, seq_len, dim), representing rotary embeddings.
"""
    batch_size, seq_len = position_ids.shape

    # Dynamic scaling for longer sequences: if positions exceed the original
    # training context, divide the inverse frequencies so the rotation period
    # stretches and all positions fit within the original rotation range
    # (equivalent to linearly interpolating the positions).
    max_seq = position_ids.max() + 1
    if max_seq > self.original_max_seq_len:
        scale = max_seq / self.original_max_seq_len
        inv_freq = self.inv_freq / scale
    else:
        inv_freq = self.inv_freq
    # Compute theta = position * frequency
    # Flatten position_ids so every batch element is handled in one broadcasted op
    flat_position_ids = position_ids.reshape(-1).float()

    # Broadcasted outer product: [batch_size * seq_len, 1] * [1, dim/2] => [batch_size * seq_len, dim/2]
    freqs = flat_position_ids.unsqueeze(-1) * inv_freq.unsqueeze(0)

    # Restore the batch dimension: [batch_size, seq_len, dim/2]
    freqs = freqs.reshape(batch_size, seq_len, -1)
    # Duplicate the half-dimension frequencies to cover the full embedding dimension
    # (concatenated halves, not an interleaved even/odd pattern)
    emb = torch.cat([freqs, freqs], dim=-1)

    # Compute cos and sin, applying the attention scaling factor
    cos = torch.cos(emb) * self.attention_scaling
    sin = torch.sin(emb) * self.attention_scaling
    return cos, sin