in tzrec/modules/sequence.py [0:0]
def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Forward the module."""
    sequence = sequence_embedded[self._sequence_name]  # [B, N, E]
    sequence_length = sequence_embedded[self._sequence_length_name]  # [B]
    # max_seq_length = sequence.size(1)
    float_dtype = sequence.dtype
    # Add positional embeddings and apply dropout.
    positions = (
        fx_arange(sequence.size(1), device=sequence.device)
        .unsqueeze(0)
        .expand(sequence.size(0), -1)
    )
    sequence = sequence * (self._sequence_dim**0.5) + self.position_embed(positions)
    sequence = F.dropout(sequence, p=self.dropout_rate, training=self.training)
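    # Example (illustrative values): with N = 4 and sequence_length = [2, 3], the
    # boolean mask built below is [[1, 1, 0, 0], [1, 1, 1, 0]].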
    # Zero out embeddings at padded positions beyond each sample's true length.
    sequence_mask = fx_arange(
        sequence.size(1), device=sequence_length.device
    ).unsqueeze(0) < sequence_length.unsqueeze(1)  # [B, N] bool
    sequence = sequence * sequence_mask.unsqueeze(-1).to(float_dtype)
    # Invert and cast the registered attention mask buffer to float.
    invalid_attn_mask = 1.0 - self._attn_mask.to(float_dtype)
    # Pack the padded dense batch [B, N, E] into a jagged tensor holding only the
    # valid rows, together with per-sample offsets.
    sequence_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
        sequence_length
    )
    sequence = torch.ops.fbgemm.dense_to_jagged(sequence, [sequence_offsets])[0]
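    # Example (illustrative values): sequence_length = [2, 3] gives
    # sequence_offsets = [0, 2, 5] and a jagged `sequence` of shape [5, E].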
    # Run the stacked jagged attention layers; timestamps and KV cache are not used
    # in this path.
    all_timestamps = None
    jagged_x, cache_states = self.jagged_forward(
        x=sequence,
        x_offsets=sequence_offsets,
        all_timestamps=all_timestamps,
        invalid_attn_mask=invalid_attn_mask,
        delta_x_offsets=None,
        cache=None,
        return_cache_states=False,
    )
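    # `jagged_x` keeps the jagged layout: one row per valid position,
    # sum(sequence_length) rows in total.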
    # Post-processing: truncate to the embedding dim and L2-normalize, clamping the
    # norm to avoid division by zero.
    output_embeddings = jagged_x[..., : self._sequence_dim]
    output_embeddings = output_embeddings / torch.clamp(
        torch.linalg.norm(output_embeddings, ord=None, dim=-1, keepdim=True),
        min=1e-6,
    )
    if not self.training:
        # At inference time, keep only the embedding at each sample's last valid step.
        output_embeddings = self.get_current_embeddings(
            sequence_length, output_embeddings
        )
    return output_embeddings
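
For intuition, the snippet below is a minimal pure-PyTorch sketch of the two data-layout steps above: the length mask that zeroes padded positions and the dense-to-jagged packing. The toy sizes and the plain-torch stand-ins for the fbgemm ops (asynchronous_complete_cumsum, dense_to_jagged) are illustrative assumptions, not the module's actual implementation.

import torch

# Toy batch (illustrative sizes): B=2 samples, padded to N=4 steps, E=3-dim embeddings.
B, N, E = 2, 4, 3
sequence = torch.randn(B, N, E)
sequence_length = torch.tensor([2, 3])

# Length mask: True at valid steps, False at padding.
mask = torch.arange(N).unsqueeze(0) < sequence_length.unsqueeze(1)  # [B, N]
sequence = sequence * mask.unsqueeze(-1).float()

# Plain-torch stand-in for asynchronous_complete_cumsum:
# cumulative sum with a leading zero -> per-sample offsets.
offsets = torch.zeros(B + 1, dtype=torch.long)
offsets[1:] = torch.cumsum(sequence_length, dim=0)

# Plain-torch stand-in for dense_to_jagged: stack only the valid rows.
jagged = sequence[mask]  # [sum(sequence_length), E]

print(offsets.tolist())  # [0, 2, 5]
print(jagged.shape)      # torch.Size([5, 3])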