in datasets.py [0:0]
def spatiotemporal_events_collate_fn(data):
    """Collate variable-length spatiotemporal event sequences into a padded batch.

    Args:
        data: list of N tensors, each of shape (T, 1 + D), where column 0 holds
            event times and columns 1: hold spatial locations. T may differ
            between sequences; the trailing dimension must be the same for all.

    Returns:
        event_times: (N, max_T) tensor, zero-padded past each sequence's end.
        spatial_locations: (N, max_T, D) tensor, zero-padded likewise.
        mask: (N, max_T) tensor in torch's default float dtype, 1 at real
            events and 0 at padded positions, on the same device as the data.

    An empty ``data`` list yields a dummy batch of zeros with shapes
    (1, 1), (1, 1, 2) and (1, 1).
    """
    if not data:
        # Dummy batch, sometimes this occurs when using multi-GPU.
        # NOTE: D cannot be inferred from an empty batch, so the spatial
        # dimension of the dummy batch is hard-coded to 2.
        return torch.zeros(1, 1), torch.zeros(1, 1, 2), torch.zeros(1, 1)

    # Zero-pad every sequence to the longest length and stack: (N, max_T, 1 + D).
    batch = torch.nn.utils.rnn.pad_sequence(data, batch_first=True)
    event_times = batch[:, :, 0]
    spatial_locations = batch[:, :, 1:]

    # Build the validity mask with one vectorized comparison, on the data's
    # device so downstream elementwise ops don't hit a device mismatch.
    # The default float dtype matches what torch.ones/torch.zeros produced before.
    device = batch.device
    max_len = batch.shape[1]
    lengths = torch.tensor([seq.shape[0] for seq in data], device=device)
    positions = torch.arange(max_len, device=device)
    mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).to(torch.get_default_dtype())
    return event_times, spatial_locations, mask