def forward()

in src/nanotron/nn/rotary.py


    def forward(self, seq_length=None, position_offset=0, position_ids=None):
        """Generate rotary position embeddings.

        Args:
            seq_length (int, optional): Sequence length to use. Defaults to max_seq_len.
            position_offset (int, optional): Offset for position ids. Defaults to 0.
            position_ids (Tensor, optional): Position ids of shape [batch_size, seq_length]. Defaults to None.

        Returns:
            Tensor: Rotary embeddings of shape [seq_length, dim], or [seq_length, dim/2] when fused.
                If position_ids is given, the leading dimension is batch_size * seq_length instead.
        """
        self.freqs_cis = self.freqs_cis.to(torch.float)  # TODO @nouamane: Fix using `DTypeInvariantTensor` ...

        # Generate position indices
        if position_ids is not None:
            assert seq_length is None, "seq_length must be None if position_ids is provided"
            assert position_offset == 0, "position_offset must be 0 if position_ids is provided"
            # TODO @nouamane: Using position_ids means we compute redundant embeddings for the same positions. Only use them in SFT
            positions = position_ids.view(-1).to(device=self.freqs_cis.device, dtype=self.freqs_cis.dtype)  # [b*s]
            self.max_seq_len = int(positions.max()) + 1
        else:
            seq_length = seq_length or self.max_seq_len
            positions = (
                torch.arange(seq_length, device=self.freqs_cis.device, dtype=self.freqs_cis.dtype) + position_offset
            )  # [seq_length]
            self.max_seq_len = seq_length

        # Apply sequence length scaling if specified
        if self.seq_len_scaling_factor is not None:
            positions = positions / self.seq_len_scaling_factor

        # Compute rotation angles as the outer product of positions and per-dimension frequencies
        position_freqs = torch.outer(positions, self.freqs_cis)  # [b*s, dim/2] or [seq_length, dim/2]

        # Organize embeddings based on interleaving strategy
        if self.fused:
            embeddings = position_freqs  # [b*s, dim/2] or [seq_length, dim/2]
        else:
            if not self.interleaved:
                embeddings = torch.cat((position_freqs, position_freqs), dim=-1)  # [b*s, dim] or [seq_length, dim]
            else:
                embeddings = torch.stack(
                    (position_freqs.view(-1, 1), position_freqs.view(-1, 1)), dim=-1
                )  # [b*s*dim/2, 1, 2] or [seq_length*dim/2, 1, 2]
                embeddings = embeddings.view(position_freqs.shape[0], -1)  # [b*s, dim] or [seq_length, dim]

        return embeddings  # [b*s, dim] or [seq_length, dim] or [b*s, dim/2] or [seq_length, dim/2]
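
For reference, the default (non-fused, non-interleaved) path reduces to an outer product of positions with per-dimension inverse frequencies, followed by a concatenation of the two halves. Below is a minimal standalone sketch of that path; the helper name `rotary_angles`, the `base=10000.0` default, and the standard frequency schedule `base ** (-2i / dim)` are illustrative assumptions, not nanotron's API.

    import torch

    def rotary_angles(positions: torch.Tensor, dim: int, base: float = 10000.0) -> torch.Tensor:
        """Sketch of the non-fused, non-interleaved branch of forward().

        positions: 1D tensor of position indices, shape [seq_length].
        Returns rotation angles of shape [seq_length, dim].
        """
        # Per-dimension inverse frequencies: base^(-2i/dim) for i in [0, dim/2)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))  # [dim/2]
        angles = torch.outer(positions.float(), inv_freq)  # [seq_length, dim/2]
        return torch.cat((angles, angles), dim=-1)  # [seq_length, dim]

    # Example: angles for an 8-token sequence with rotary dimension 64
    emb = rotary_angles(torch.arange(8), dim=64)
    print(emb.shape)  # torch.Size([8, 64])

The fused branch simply skips the final concatenation and returns the [seq_length, dim/2] angles directly, while the interleaved branch repeats each frequency into adjacent pairs along the last dimension instead of concatenating the two halves.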