in src/flint/torch_util.py [0:0]
def reverse_indice_for_state(states, reverse_indices):
    """Reorder RNN hidden state(s) along the batch dimension (dim 1).

    Output batch entry ``k`` of every state is input batch entry
    ``reverse_indices[k]``.

    :param states: a tensor shaped like [rnn.num_layers, batch_size,
        rnn.hidden_size], a tuple of such tensors (e.g. an LSTM's ``(h, c)``),
        or ``None``.
    :param reverse_indices: [new_batch_size] integer indices into the batch
        dimension -- a sequence of ints or an integer tensor. Negative
        indices are interpreted Python-style.
    :return: the reordered tensor when ``states`` is a single tensor or a
        1-tuple; a tuple of reordered tensors (same order as the input) for
        longer tuples; an empty tuple when ``states`` is ``None``.
    :raises IndexError: if an index is out of range for the batch dimension.
    """
    if states is None:
        states = ()
    elif not isinstance(states, tuple):
        states = (states,)  # rnn.num_layers, batch_size, rnn.hidden_size

    # Normalise the indices once to a long tensor so every state can be
    # gathered with a single advanced-indexing op.
    if torch.is_tensor(reverse_indices):
        index = reverse_indices.long()
    else:
        index = torch.as_tensor(list(reverse_indices), dtype=torch.long)

    # One gather per state instead of slice+unsqueeze per batch element and a
    # final cat: avoids O(batch_size) small ops and also handles an empty
    # index (torch.cat of an empty list would raise).
    r_states = tuple(state[:, index.to(state.device)] for state in states)
    if len(r_states) == 1:
        return r_states[0]
    return r_states