in src/nanotron/nn/rotary.py [0:0]
def forward(self, seq_length=None, position_offset=0, position_ids=None):
    """Generate the rotary position-embedding angle table.

    Args:
        seq_length (int, optional): Sequence length to use. Defaults to ``self.max_seq_len``.
            Must be None if ``position_ids`` is provided.
        position_offset (int, optional): Offset added to the generated position ids.
            Defaults to 0. Must be 0 if ``position_ids`` is provided.
        position_ids (Tensor, optional): Explicit position ids, flattened to 1-D ([b*s]) —
            ``torch.outer`` below requires a vector, not ``[batch_size, seq_length]``.
            Defaults to None.

    Returns:
        Tensor: per-position angles, layout depending on configuration:
            - ``self.fused``: [b*s, dim/2] (or [seq_length, dim/2]) — one angle per frequency pair.
            - not fused, not interleaved: [b*s, dim] — angles repeated as two halves (a b .. a b ..).
            - not fused, interleaved: [b*s, dim] — angles repeated pairwise (a a b b ..).

    Side effect:
        Sets ``self.max_seq_len`` to the number of positions covered (a Python int).
    """
    self.freqs_cis = self.freqs_cis.to(torch.float)  # TODO @nouamane: Fix using `DTypeInvariantTensor` ...
    # Generate position indices
    if position_ids is not None:
        assert seq_length is None, "seq_length must be None if position_ids is provided"
        assert position_offset == 0, "position_offset must be 0 if position_ids is provided"
        # TODO @nouamane: Using position_ids means we compute redundant embeddings for same positions
        positions = position_ids.to(device=self.freqs_cis.device, dtype=self.freqs_cis.dtype)  # [b*s]
        # Keep max_seq_len a Python int, consistent with the branch below
        # (previously this stored a 0-d float tensor: positions.max() + 1).
        self.max_seq_len = int(position_ids.max()) + 1
    else:
        seq_length = seq_length or self.max_seq_len
        positions = (
            torch.arange(seq_length, device=self.freqs_cis.device, dtype=self.freqs_cis.dtype) + position_offset
        )  # [seq_length]
        self.max_seq_len = seq_length
    # Apply sequence length scaling if specified
    if self.seq_len_scaling_factor is not None:
        positions = positions / self.seq_len_scaling_factor
    # Angle for every (position, frequency) pair
    # TODO @nouamane: Using position_ids means we compute redundant embeddings for same positions. Only use them in SFT
    position_freqs = torch.outer(positions, self.freqs_cis)  # [seq_length, dim/2]
    # Organize embeddings based on interleaving strategy
    if self.fused:
        # Fused rotary kernels consume the raw half-dim angle table directly.
        embeddings = position_freqs  # [b*s, dim/2] or [seq_length, dim/2]
    elif not self.interleaved:
        # (a, b, ..., a, b, ...): duplicate the angles as two contiguous halves.
        embeddings = torch.cat((position_freqs, position_freqs), dim=-1)  # [b*s, dim] or [seq_length, dim]
    else:
        # (a, a, b, b, ...): duplicate each angle for adjacent (even, odd) channel pairs.
        # Equivalent to the stack(view(-1, 1), view(-1, 1)).view(n, -1) idiom, but clearer.
        embeddings = position_freqs.repeat_interleave(2, dim=-1)  # [b*s, dim] or [seq_length, dim]
    return embeddings  # [b*s, dim] or [seq_length, dim] or [b*s, dim/2] or [seq_length, dim/2]