in tzrec/modules/sequence.py [0:0]
def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Forward the module.

    Args:
        sequence_embedded (dict): dict containing the query embedding
            [B, C_q], the sequence embeddings [B, T, C], and the valid
            sequence lengths [B].

    Returns:
        torch.Tensor: pooled features of shape [B, (L+1)*C], where L is
            the number of pooling windows.
    """
    query = sequence_embedded[self._query_name]
    sequence = sequence_embedded[self._sequence_name]
    sequence_length = sequence_embedded[self._sequence_length_name]
    max_seq_length = sequence.size(1)
    # Boolean mask [B, T]: True where the position index is within
    # each sequence's valid length.
    sequence_mask = fx_arange(
        max_seq_length, device=sequence_length.device
    ).unsqueeze(0) < sequence_length.unsqueeze(1)
    # Zero-pad the query so its last dim matches the sequence embedding dim.
    if self._query_dim < self._sequence_dim:
        query = F.pad(query, (0, self._sequence_dim - self._query_dim))
    queries = query.unsqueeze(1).expand(-1, max_seq_length, -1)  # [B, T, C]
    # DIN-style attention input: sequence, query-sequence interaction, query.
    attn_input = torch.cat([sequence, queries * sequence, queries], dim=-1)
    attn_output = self.mlp(attn_input)
    attn_output = self.linear(attn_output)
    attn_output = self.active(attn_output)  # [B, T, 1]
    # Weight each step by its attention score and zero out padded steps.
    att_sequences = attn_output * sequence_mask.unsqueeze(2) * sequence
    # Pad the time dim up to the total window span, then sum-reduce each
    # consecutive window along the time axis.
    pad = (0, 0, 0, self._sum_windows_len - max_seq_length)
    pad_att_sequences = F.pad(att_sequences, pad).transpose(0, 1)
    result = torch.segment_reduce(
        pad_att_sequences, reduce="sum", lengths=self.windows_len, axis=0
    ).transpose(0, 1)  # [B, L, C]
    # Number of valid (unpadded) steps that fall into each window.
    segment_length = torch.min(
        sequence_length.unsqueeze(1) - self.cumsum_windows_len.unsqueeze(0),
        self.windows_len,
    )
    # Turn the window sums into means; clamp the divisor to at least 1 so
    # empty windows never divide by zero.
    result = result / torch.max(
        segment_length, torch.ones_like(segment_length)
    ).unsqueeze(2)
    # Append the (padded) query as one more "window" and flatten.
    return torch.cat([result, query.unsqueeze(1)], dim=1).reshape(
        result.shape[0], -1
    )  # [B, (L+1)*C]