in rl/models/rnn_state_encoder.py
def seq_forward(self, x, hidden_states, masks):
    r"""Forward for a sequence of length T.

    Args:
        x: (T, N, -1) Tensor that has been flattened to (T * N, -1)
        hidden_states: The starting hidden state.
        masks: The masks to be applied to the hidden state at every
            timestep. A (T, N) tensor flattened to (T * N).
    """
    # x is a (T, N, -1) tensor flattened to (T * N, -1)
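    # n = number of parallel environments (the hidden state's batch dim);
    # t = sequence length, recovered by dividing the flattened batch by n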
    n = hidden_states.size(1)
    t = int(x.size(0) / n)

    # unflatten
    x = x.view(t, n, x.size(1))
    masks = masks.view(t, n)
    # Steps in the sequence that have a zero mask for any agent. Assume
    # t=0 has a zero in it so the segment logic below stays uniform.
    has_zeros = (
        (masks[1:] == 0.0)
        .any(dim=-1)
        .nonzero(as_tuple=False)
        .squeeze()
        .cpu()
    )
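    # masks is 0.0 at steps where an episode has restarted for that agent,
    # so every index found here marks a timestep at which the hidden state
    # must be reset before the RNN continues.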
    # +1 to correct for the masks[1:] offset
    if has_zeros.dim() == 0:
        has_zeros = [has_zeros.item() + 1]  # handle scalar
    else:
        has_zeros = (has_zeros + 1).numpy().tolist()

    # add t=0 and t=T to the list
    has_zeros = [0] + has_zeros + [t]
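    # Consecutive entries of has_zeros now delimit maximal sub-sequences
    # with no resets inside them, so each sub-sequence can be fed to the
    # RNN in a single call instead of stepping one timestep at a time.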
    hidden_states = self._unpack_hidden(hidden_states)
    outputs = []
    for i in range(len(has_zeros) - 1):
        # process steps that don't have any zeros in masks together
        start_idx = has_zeros[i]
        end_idx = has_zeros[i + 1]

        rnn_scores, hidden_states = self.rnn(
            x[start_idx:end_idx],
            self._mask_hidden(
                hidden_states, masks[start_idx].view(1, -1, 1)
            ),
        )

        outputs.append(rnn_scores)
    # x is a (T, N, -1) tensor
    x = torch.cat(outputs, dim=0)
    x = x.view(t * n, -1)  # flatten

    hidden_states = self._pack_hidden(hidden_states)
    return x, hidden_states
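
# ---------------------------------------------------------------------------
# Usage sketch (not part of the source file). It assumes seq_forward is
# attached to an encoder that exposes `rnn`, `_pack_hidden`,
# `_unpack_hidden`, and `_mask_hidden`; the helper bodies below are
# plausible stand-ins for the single-layer GRU case, not this file's
# actual implementations.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


class SketchEncoder:
    def __init__(self, input_size, hidden_size):
        self.rnn = nn.GRU(input_size, hidden_size, num_layers=1)

    def _unpack_hidden(self, hidden_states):
        return hidden_states  # a GRU hidden state is already one tensor

    def _pack_hidden(self, hidden_states):
        return hidden_states  # an LSTM variant would concat (h, c) here

    def _mask_hidden(self, hidden_states, masks):
        return masks * hidden_states  # zero the hidden state where mask == 0

    seq_forward = seq_forward  # reuse the method defined above


t, n, d = 8, 4, 16
enc = SketchEncoder(input_size=d, hidden_size=64)
x = torch.randn(t * n, d)        # (T, N, d) flattened to (T * N, d)
hidden = torch.zeros(1, n, 64)   # (num_layers, N, hidden_size)
masks = torch.ones(t, n)
masks[3, 1] = 0.0                # agent 1's episode restarts at t=3
out, hidden = enc.seq_forward(x, hidden, masks.view(t * n))
print(out.shape)                 # torch.Size([32, 64]) == (T * N, hidden_size)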