def _collate_fn()

in src/rime/models/rnn.py [0:0]


def _collate_fn(batch, tokenize, truncated_input_steps, training):
    if truncated_input_steps > 0:
        batch = [seq[-truncated_input_steps:] for seq in batch]
    batch = [[0] + [tokenize[x] for x in seq if x in tokenize] for seq in batch]
    batch = [torch.tensor(seq, dtype=torch.int64) for seq in batch]
    batch, lengths = pad_packed_sequence(pack_sequence(batch, False))
    if training:
        return (batch[:-1].T, batch[1:].T)  # TBPTT assumes NT layout
    else:
        return (batch, lengths)  # RNN default TN layout