def forward()

in tzrec/modules/sequence.py


    def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Forward the module."""
        sequence = sequence_embedded[self._sequence_name]  # B, N, E
        sequence_length = sequence_embedded[self._sequence_length_name]  # B
        float_dtype = sequence.dtype

        # Add positional embeddings and apply dropout
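        # The sqrt(E) scaling below follows the standard Transformer practice
        # of balancing embedding magnitude against the position embeddings.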
        positions = (
            fx_arange(sequence.size(1), device=sequence.device)
            .unsqueeze(0)
            .expand(sequence.size(0), -1)
        )
        sequence = sequence * (self._sequence_dim**0.5) + self.position_embed(positions)
        sequence = F.dropout(sequence, p=self.dropout_rate, training=self.training)
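        # Build a (B, N) validity mask and zero out embeddings at padded
        # positions (indices >= the sample's valid length).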
        sequence_mask = fx_arange(
            sequence.size(1), device=sequence_length.device
        ).unsqueeze(0) < sequence_length.unsqueeze(1)
        sequence = sequence * sequence_mask.unsqueeze(-1).to(float_dtype)

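        # Invert the attention mask so that 1.0 marks *disallowed* positions,
        # which is the convention the `invalid_attn_mask` argument expects.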
        invalid_attn_mask = 1.0 - self._attn_mask.to(float_dtype)
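        # Complete cumulative sum of lengths -> offsets [0, l0, l0+l1, ...]
        # of shape (B + 1,), as required by the jagged ops below.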
        sequence_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
            sequence_length
        )
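        # Pack the padded dense (B, N, E) tensor into a jagged
        # (sum(lengths), E) tensor, dropping the padded positions.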
        sequence = torch.ops.fbgemm.dense_to_jagged(sequence, [sequence_offsets])[0]

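        # No timestamp features in this path; time-based attention biasing
        # (if supported by jagged_forward) stays disabled.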
        all_timestamps = None
        jagged_x, cache_states = self.jagged_forward(
            x=sequence,
            x_offsets=sequence_offsets,
            all_timestamps=all_timestamps,
            invalid_attn_mask=invalid_attn_mask,
            delta_x_offsets=None,
            cache=None,
            return_cache_states=False,
        )
        # Post-processing: keep the first `_sequence_dim` channels and
        # L2-normalize, clamping the norm to avoid division by zero.
        output_embeddings = jagged_x[..., : self._sequence_dim]
        output_embeddings = output_embeddings / torch.clamp(
            torch.linalg.norm(output_embeddings, ord=None, dim=-1, keepdim=True),
            min=1e-6,
        )
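        # In eval mode, reduce to one embedding per sample (presumably the
        # last valid position, given the name `get_current_embeddings`).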
        if not self.training:
            output_embeddings = self.get_current_embeddings(
                sequence_length, output_embeddings
            )
        return output_embeddings
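
A minimal usage sketch (not from the repository): `encoder` is assumed to be
an already constructed instance of this module with `_sequence_name` set to
"seq" and `_sequence_length_name` set to "seq.lengths"; the real key names
come from the feature config, and the jagged ops require fbgemm_gpu.

    import torch

    # `encoder` is a hypothetical, already-built instance of this module.
    batch_size, max_len, dim = 4, 10, 64
    sequence_embedded = {
        # padded dense embeddings, shape (B, N, E)
        "seq": torch.randn(batch_size, max_len, dim),
        # valid length per sample, shape (B,)
        "seq.lengths": torch.tensor([10, 7, 3, 5]),
    }

    encoder.train()
    jagged_out = encoder(sequence_embedded)   # (sum(lengths), E), normalized

    encoder.eval()
    current_out = encoder(sequence_embedded)  # one embedding per sample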