in gala/model.py [0:0]
def _forward_gru(self, x, hxs, masks):
    """Run ``self.gru`` over one step or over a flattened multi-step batch.

    Two input layouts are handled, selected by comparing leading sizes:

    * Single step, ``x.size(0) == hxs.size(0)``: ``x`` and ``hxs`` each get
      a leading length-1 axis, the hidden state is multiplied by ``masks``,
      and ``self.gru`` is called once.
    * Sequence, otherwise: ``x`` is treated as a ``(T, N, -1)`` tensor that
      was flattened to ``(T * N, -1)``, with ``N = hxs.size(0)``. ``masks``
      is reshaped to ``(T, N)``. The sequence is cut at every step ``t >= 1``
      where any entry of ``masks[t]`` is zero, and each resulting chunk is
      fed to ``self.gru`` in a single call.

    Args:
        x: 2-D input tensor, ``(N, features)`` or ``(T * N, features)``.
        hxs: Hidden state whose first dimension is ``N``.
        masks: Multiplied into the hidden state, so a zero entry zeroes the
            matching row of ``hxs`` before the GRU sees it. In the sequence
            case it must hold exactly ``T * N`` elements; in the single-step
            case it must broadcast against ``hxs``.

    Returns:
        ``(x, hxs)``: the GRU outputs, flattened back to ``(T * N, -1)`` in
        the sequence case, and the final hidden state with the leading
        length-1 axis removed.

    Note:
        ``self.gru`` is presumably a single-layer, unidirectional
        ``torch.nn.GRU`` using the default (seq, batch, feature) layout --
        TODO confirm against the constructor. In the sequence case
        ``x.size(0)`` is assumed to be a multiple of ``N``; otherwise the
        reshape below raises.
    """
    if x.size(0) == hxs.size(0):
        # Single time step: add a length-1 sequence axis, run, remove it.
        x, hxs = self.gru(x.unsqueeze(0), (hxs * masks).unsqueeze(0))
        return x.squeeze(0), hxs.squeeze(0)

    # x is a (T, N, -1) tensor that has been flattened to (T * N, -1).
    N = hxs.size(0)
    T = x.size(0) // N  # integer division; no float round-trip

    # Unflatten. reshape (not view) so non-contiguous inputs are accepted.
    x = x.reshape(T, N, x.size(1))
    masks = masks.reshape(T, N)

    # Steps t >= 1 at which at least one mask entry is zero. t=0 is always
    # treated as a chunk start, which keeps the logic below uniform.
    # squeeze(-1) keeps the index tensor 1-D even when there is exactly one
    # hit, so no scalar special case is needed. +1 corrects for masks[1:].
    reset_steps = (masks[1:] == 0.0).any(dim=-1).nonzero().squeeze(-1)
    boundaries = [0] + (reset_steps + 1).tolist() + [T]

    hxs = hxs.unsqueeze(0)
    outputs = []
    for start_idx, end_idx in zip(boundaries[:-1], boundaries[1:]):
        # Steps inside a chunk have no zero masks after the first one, so
        # the whole chunk goes through the GRU in one call. The mask of the
        # first step is applied to the incoming hidden state.
        rnn_scores, hxs = self.gru(
            x[start_idx:end_idx],
            hxs * masks[start_idx].reshape(1, -1, 1),
        )
        outputs.append(rnn_scores)

    # Concatenate the chunks back into (T, N, -1), then flatten.
    x = torch.cat(outputs, dim=0).view(T * N, -1)
    return x, hxs.squeeze(0)