def reverse_indice_for_state()

in src/flint/torch_util.py [0:0]


def reverse_indice_for_state(states, reverse_indices):
    """
    Reorder the batch axis (dim 1) of one or several RNN state tensors.

    Entry ``k`` along dim 1 of each returned tensor is entry
    ``reverse_indices[k]`` of the corresponding input state.

    :param states: a single state tensor, a tuple of state tensors (e.g. the
        ``(h, c)`` pair of an LSTM), or ``None``. Each tensor is indexed as
        [rnn.num_layers, batch_size, rnn.hidden_size].
    :param reverse_indices: [batch_size] integer positions along the batch
        axis; may be a tensor or any iterable of integers.
    :return: the reordered state(s). A tuple with one tensor per input state,
        except that a result holding exactly one tensor is returned as the
        bare tensor (this also applies to a 1-tuple input). ``None`` or an
        empty tuple yields ``()``.
    """
    if states is None:
        states = ()
    elif not isinstance(states, tuple):
        states = (states,)  # rnn.num_layers, batch_size, rnn.hidden_size

    # Build one integer index tensor up front. Materialising non-tensor
    # iterables first keeps generators and other one-shot iterables usable.
    if not isinstance(reverse_indices, torch.Tensor):
        reverse_indices = list(reverse_indices)
    index = torch.as_tensor(reverse_indices, dtype=torch.long)

    # A single gather on the batch axis replaces the per-sample slice + cat
    # loop. Unlike torch.cat on an empty list, it also copes with an empty
    # index (the result then has a batch axis of size 0).
    r_states = tuple(
        state[:, index.to(state.device), :] for state in states
    )
    if len(r_states) == 1:
        r_states = r_states[0]
    return r_states