in src/rime/models/rnn.py [0:0]
def _collate_fn(batch, tokenize, truncated_input_steps, training):
    """Encode a batch of item sequences as one zero-padded int64 tensor.

    Each sequence is optionally cut down to its last
    ``truncated_input_steps`` entries (only when that value is positive),
    entries that are not found in ``tokenize`` are dropped, the remaining
    ones are replaced by ``tokenize[entry]``, and a leading ``0`` token is
    prepended. Truncation is applied before the vocabulary filter, so a
    sequence may end up with fewer than ``truncated_input_steps`` tokens
    after the ``0`` prefix.

    Returns:
        When ``training`` is truthy, a pair ``(inputs, targets)`` where both
        are the padded batch transposed to (sequence, time) order and
        ``targets`` is ``inputs`` shifted one step ahead in time.
        Otherwise a pair ``(padded, lengths)`` exactly as produced by
        ``pad_packed_sequence`` (time-major), with ``lengths`` holding the
        token count of every sequence including the ``0`` prefix.
    """
    encoded = []
    for history in batch:
        if truncated_input_steps > 0:
            # Keep only the most recent steps of the history.
            history = history[-truncated_input_steps:]
        token_ids = [0]
        token_ids.extend(tokenize[item] for item in history if item in tokenize)
        encoded.append(torch.tensor(token_ids, dtype=torch.int64))

    # enforce_sorted=False lets sequences arrive in any length order; the
    # padded result and lengths come back in the original batch order.
    packed = pack_sequence(encoded, enforce_sorted=False)
    padded, lengths = pad_packed_sequence(packed)

    if not training:
        return (padded, lengths)  # RNN default TN layout
    # Inputs drop the last step, targets drop the first one.
    return (padded[:-1].T, padded[1:].T)  # TBPTT assumes NT layout