in multiple_futures_prediction/dataset_ngsim.py [0:0]
def collate_fn(self, samples: List[Any]) -> Tuple[Any,Any,Any,Any,Any,Union[Any,None],Any] :
    """Prepare a batch suitable for MFP training.

    Each element of ``samples`` is a ``(hist, fut, nbrs, context)`` tuple:

    * ``hist`` / ``fut`` -- per-agent numpy arrays indexed ``0..len(nbrs)-1``;
      each is copied into the leading timesteps of a zero-initialised buffer
      with a trailing dimension of 2.
    * ``nbrs`` -- dict whose keys are treated as sample-local agent indices
      (presumably ``0..len(nbrs)-1``; TODO confirm against ``__getitem__``)
      and whose values are lists of ``(batch_id, veh_id, grid_ind)`` entries.
      ``batch_id`` indexes ``hist`` of the same sample; entries with a
      negative ``batch_id`` are skipped.
    * ``context`` -- numpy array with at least 3 dims, only read when
      ``self.use_context`` is set.

    Agents from all samples are concatenated along the agent axis.

    Returns:
        ``(hist_batch, nbrs_batch, nbr_inds_batch, fut_batch, mask_batch,
        context_batch, nbrs_infos)`` where ``hist_batch`` is
        ``(t_h//d_s + 1, num_agents, 2)``; ``nbrs_batch`` is
        ``(t_h//d_s + 1, max(num_nbr_entries, 1), 2)``; ``nbr_inds_batch`` is
        a uint8 mask ``(num_agents, grid_size[1], grid_size[0],
        enc_size*enc_fac)``; ``fut_batch`` and ``mask_batch`` are
        ``(t_f//d_s, num_agents, 2)`` (mask is 1 where a future step exists);
        ``context_batch`` is ``(num_agents, C, H, W)`` or ``None``; and
        ``nbrs_infos`` is the list of per-sample ``nbrs`` dicts.

    Raises:
        ValueError: if ``self.use_context`` is set and ``samples`` is empty
            (the context shape cannot be inferred).
    """
    maxlen = self.t_h // self.d_s + 1
    fut_len = self.t_f // self.d_s
    grid_w = self.grid_size[0]
    grid_h = self.grid_size[1]
    feat_dim = self.enc_size * self.enc_fac

    # Size the outputs. num_samples counts agents (one per nbrs key) across
    # the whole batch. nbr_batch_size counts every neighbour entry, including
    # ones with a negative batch_id that are skipped below, so nbrs_batch may
    # end with unused all-zero columns.
    num_samples = sum(len(nbrs) for _, _, nbrs, _ in samples)
    nbr_batch_size = sum(
        len(nbr) for _, _, nbrs, _ in samples for nbr in nbrs.values()
    )

    # Keep at least one (all-zero) neighbour column, presumably so downstream
    # code never sees a zero-width tensor.
    nbrs_batch = torch.zeros(maxlen, max(nbr_batch_size, 1), 2)
    nbr_inds_batch = torch.zeros(num_samples, grid_h, grid_w, feat_dim).byte()
    hist_batch = torch.zeros(maxlen, num_samples, 2)  # e.g. (31, 41, 2)
    fut_batch = torch.zeros(fut_len, num_samples, 2)
    mask_batch = torch.zeros(fut_len, num_samples, 2)

    context_batch: Union[None, torch.Tensor]
    if self.use_context:
        if not samples:
            raise ValueError(
                "collate_fn received an empty batch; cannot infer the context shape"
            )
        # The buffer takes its shape from the last sample's context; every
        # context is then copied into it, so all must share that shape.
        ctx_shape = samples[-1][3].shape
        # NOTE(review): the first dim is the agent count, but rows are filled
        # per *sample* below, so rows past len(samples) stay zero -- confirm
        # this is what the model expects.
        context_batch = torch.zeros(
            num_samples, ctx_shape[0], ctx_shape[1], ctx_shape[2]
        )
    else:
        context_batch = None

    nbrs_infos = []
    count = 0   # next free column in nbrs_batch
    offset = 0  # batch row of the current sample's first agent
    for sample_id, (hist, fut, nbrs, context) in enumerate(samples):
        num = len(nbrs)
        for j in range(num):
            row = offset + j
            hist_batch[0:len(hist[j]), row, :] = torch.from_numpy(hist[j])
            fut_batch[0:len(fut[j]), row, :] = torch.from_numpy(fut[j])
            mask_batch[0:len(fut[j]), row, :] = 1

        nbrs_infos.append(nbrs)

        if self.use_context:
            context_batch[sample_id, :, :, :] = torch.from_numpy(context)

        # nbrs is a dictionary of key to list of nbr (batch_index, veh_id, grid_ind)
        for batch_ind, list_of_nbr in nbrs.items():
            for batch_id, _veh_id, grid_ind in list_of_nbr:
                if batch_id < 0:
                    continue
                nbr_hist = hist[batch_id]
                nbrs_batch[0:len(nbr_hist), count, :] = torch.from_numpy(nbr_hist)
                # grid_ind is a row-major cell index over a grid_w-wide grid.
                cell_y, cell_x = divmod(grid_ind, grid_w)
                # batch_ind is local to this sample; shift it to the agent's
                # row in the concatenated batch so later samples do not
                # overwrite the first sample's rows.
                nbr_inds_batch[offset + batch_ind, cell_y, cell_x, :] = 1
                count += 1

        offset += num

    return (hist_batch, nbrs_batch, nbr_inds_batch, fut_batch, mask_batch, context_batch, nbrs_infos)