in src/speech_reps/models/decoar.py
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def forward(self, features, padding_mask=None):
    max_seq_len = features.shape[1]
    features = self.post_extract_proj(features)

    if padding_mask is not None:
        # Trim the frame-level mask so it divides evenly into the feature length,
        # then collapse it to a single padded/not-padded flag per output frame.
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        seq_lengths = (~padding_mask).sum(dim=-1).tolist()
    else:
        # No mask given: treat every sequence in the batch as full length.
        seq_lengths = [max_seq_len] * features.size(0)

    # Forward-direction LSTM over the packed (unpadded) sequences.
    packed_rnn_inputs = pack_padded_sequence(features, seq_lengths,
                                             batch_first=True,
                                             enforce_sorted=False)
    packed_rnn_outputs, _ = self.forward_lstm(packed_rnn_inputs)
    x_forward, _ = pad_packed_sequence(packed_rnn_outputs,
                                       batch_first=True,
                                       total_length=max_seq_len)

    # Backward-direction LSTM: reverse each sequence over its true length,
    # run the LSTM, then flip the outputs back into the original time order.
    packed_rnn_inputs = pack_padded_sequence(self.flipBatch(features, seq_lengths),
                                             seq_lengths,
                                             batch_first=True,
                                             enforce_sorted=False)
    packed_rnn_outputs, _ = self.backward_lstm(packed_rnn_inputs)
    x_backward, _ = pad_packed_sequence(packed_rnn_outputs,
                                        batch_first=True,
                                        total_length=max_seq_len)
    x_backward = self.flipBatch(x_backward, seq_lengths)

    # Concatenate the two directions along the feature dimension.
    return torch.cat((x_forward, x_backward), dim=-1)
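For context, a minimal self-contained sketch of the same pack / flip / concatenate pattern, using plain torch.nn.LSTM modules in place of forward_lstm and backward_lstm, and a hypothetical flip_batch helper standing in for the class's flipBatch (assumed here to reverse each sequence over its valid length; that helper is not shown in this excerpt). Shapes and hyperparameters are illustrative, not taken from the repository.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def flip_batch(x, seq_lengths):
    # Reverse each sequence along time, but only over its true (unpadded) length.
    flipped = x.clone()
    for i, length in enumerate(seq_lengths):
        flipped[i, :length] = x[i, :length].flip(0)
    return flipped

batch, max_len, dim, hidden = 3, 10, 8, 16
features = torch.randn(batch, max_len, dim)
seq_lengths = [10, 7, 5]

fwd_lstm = nn.LSTM(dim, hidden, batch_first=True)
bwd_lstm = nn.LSTM(dim, hidden, batch_first=True)

# Forward direction over packed sequences.
packed = pack_padded_sequence(features, seq_lengths, batch_first=True, enforce_sorted=False)
x_fwd, _ = pad_packed_sequence(fwd_lstm(packed)[0], batch_first=True, total_length=max_len)

# Backward direction: flip, run, flip back.
packed = pack_padded_sequence(flip_batch(features, seq_lengths), seq_lengths,
                              batch_first=True, enforce_sorted=False)
x_bwd, _ = pad_packed_sequence(bwd_lstm(packed)[0], batch_first=True, total_length=max_len)
x_bwd = flip_batch(x_bwd, seq_lengths)

out = torch.cat((x_fwd, x_bwd), dim=-1)   # (batch, max_len, 2 * hidden)

Running two unidirectional LSTMs with explicit flipping, rather than a single bidirectional LSTM, keeps the forward and backward states separate until the final concatenation, which matches the structure of the method above.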