in src/flint/torch_util.py [0:0]
def pack_for_rnn_seq(inputs, lengths, batch_first=True, states=None):
    """
    Sort a padded batch into decreasing-length order and pack it for an RNN.

    :param inputs: [B, T, D] if batch_first else [T, B, D].
    :param lengths: [B] tensor of valid sequence lengths (integers).
    :param batch_first: layout flag for `inputs`.
    :param states: optional hidden state(s) shaped
        [rnn.num_layers, batch_size, rnn.hidden_size], or a tuple of such
        tensors (e.g. an LSTM's (h, c)); reordered alongside the batch.
        NOTE(review): `states` is only honored when batch_first=True; the
        batch_first=False branch ignores it and returns 2 values, not 3 —
        preserved as-is for backward compatibility.
    :return: (packed_seq, reverse_indices) when not batch_first, else
        (packed_seq, reverse_indices, r_states). `reverse_indices[i]` is the
        position of original sample i inside the sorted/packed batch, so
        `unpacked[reverse_indices]` restores the original batch order.
    """
    # Ascending sort, then iterate backwards => decreasing-length order,
    # as required by pack_padded_sequence (enforce_sorted default).
    _, sorted_indices = lengths.sort()
    r_index = reversed(list(sorted_indices))

    s_inputs_list = []
    lengths_list = []
    reverse_indices = np.zeros(lengths.size(0), dtype=np.int64)

    if not batch_first:
        for j, i in enumerate(r_index):
            s_inputs_list.append(inputs[:, i, :].unsqueeze(1))
            # Plain ints, not 0-d tensors: pack_padded_sequence accepts them
            # directly and avoids a per-element tensor->scalar conversion.
            lengths_list.append(int(lengths[i]))
            reverse_indices[i] = j
        s_inputs = torch.cat(s_inputs_list, 1)
        packed_seq = nn.utils.rnn.pack_padded_sequence(s_inputs, lengths_list)
        return packed_seq, reverse_indices.tolist()
    else:
        # Normalize `states` to a tuple so single-state RNNs (GRU/vanilla)
        # and multi-state RNNs (LSTM) share one code path.
        if states is None:
            states = ()
        elif not isinstance(states, tuple):
            states = (states,)  # rnn.num_layers, batch_size, rnn.hidden_size
        states_lists = tuple([] for _ in states)

        for j, i in enumerate(r_index):
            s_inputs_list.append(inputs[i, :, :])
            lengths_list.append(int(lengths[i]))
            reverse_indices[i] = j
            # Reorder each state's batch dimension (dim 1) in lockstep.
            for state_list, state in zip(states_lists, states):
                state_list.append(state[:, i, :].unsqueeze(1))

        s_inputs = torch.stack(s_inputs_list, dim=0)
        packed_seq = nn.utils.rnn.pack_padded_sequence(
            s_inputs, lengths_list, batch_first=batch_first)
        r_states = tuple(torch.cat(state_list, dim=1)
                         for state_list in states_lists)
        if len(r_states) == 1:
            # Single state in => single state out (not a 1-tuple).
            r_states = r_states[0]
        return packed_seq, reverse_indices.tolist(), r_states