def pack_for_rnn_seq()

in src/flint/torch_util.py [0:0]


def _length_sort_permutation(lengths):
    """Return the descending-length ordering of a batch and its inverse.

    :param lengths: 1-D tensor of sequence lengths, one per batch element.
    :return: ``(order, sorted_lengths, reverse_indices)`` where ``order`` is an
        index tensor (on ``lengths``' device) such that ``lengths[order]`` is
        non-increasing, ``sorted_lengths`` is that sorted tensor moved to the
        CPU, and ``reverse_indices`` is a list of ``np.int64`` with
        ``reverse_indices[i]`` equal to the position of original batch
        element ``i`` inside the sorted batch.
    """
    sorted_lengths, order = lengths.sort(descending=True)
    order_np = order.cpu().numpy()
    reverse_indices = np.empty_like(order_np)
    # Invert the permutation: original element order[j] lives at sorted position j.
    reverse_indices[order_np] = np.arange(order_np.shape[0], dtype=order_np.dtype)
    # pack_padded_sequence wants lengths on the CPU; move them once here.
    return order, sorted_lengths.cpu(), list(reverse_indices)


def pack_for_rnn_seq(inputs, lengths, batch_first=True, states=None):
    """Sort a padded batch by decreasing length and pack it for an RNN.

    :param inputs: Padded batch of shape [B, T, D] if ``batch_first`` else
        [T, B, D]. The batch dimension is reordered into decreasing-length order.
    :param lengths: 1-D tensor [B] of sequence lengths (any device).
    :param batch_first: Whether ``inputs`` has the batch on dim 0 (True) or
        dim 1 (False).
    :param states: Optional initial RNN state(s), each
        [rnn.num_layers, batch_size, rnn.hidden_size]. Pass a single tensor
        or a tuple of tensors, e.g. ``(h0, c0)``. Only used when
        ``batch_first`` is True.
    :return: If ``batch_first`` is False: ``(packed_seq, reverse_indices)``.
        If ``batch_first`` is True:
        ``(packed_seq, reverse_indices, r_states)``.

        - ``reverse_indices[i]`` is the position of original batch element
          ``i`` within the sorted/packed batch. Index the unpacked outputs
          with it along the batch dim to restore the original order.
        - ``r_states`` holds ``states`` reordered along dim 1 to match the
          sorted batch. It is a single tensor when exactly one state was
          given, a tuple otherwise, and ``()`` when ``states`` is None.
    """
    order, sorted_lengths, reverse_indices = _length_sort_permutation(lengths)

    batch_dim = 0 if batch_first else 1
    # One gather instead of a per-element Python loop; the index must live
    # on the same device as the tensor being indexed.
    s_inputs = inputs.index_select(batch_dim, order.to(inputs.device))
    packed_seq = nn.utils.rnn.pack_padded_sequence(s_inputs, sorted_lengths, batch_first=batch_first)

    if not batch_first:
        # NOTE(review): `states` is not reordered on this path and only two
        # values are returned; kept as-is so existing callers that unpack a
        # 2-tuple keep working. Reorder states yourself with reverse_indices.
        return packed_seq, reverse_indices

    if states is None:
        states = ()
    elif not isinstance(states, tuple):
        states = (states,)  # rnn.num_layers, batch_size, rnn.hidden_size

    # States carry the batch on dim 1, so reorder them along that axis.
    r_states = tuple(state.index_select(1, order.to(state.device)) for state in states)
    if len(r_states) == 1:
        r_states = r_states[0]

    return packed_seq, reverse_indices, r_states