def collate_fn()

in multiple_futures_prediction/dataset_ngsim.py [0:0]


  def collate_fn(self, samples: List[Any]) -> Tuple[Any,Any,Any,Any,Any,Union[Any,None],Any] :
    """Assemble a list of dataset samples into batched tensors for MFP training.

    Each sample is a 4-tuple ``(hist, fut, nbrs, im_crop)`` where ``nbrs`` maps
    a batch index to a list of ``(batch_id, veh_id, grid_ind)`` neighbor
    records, and ``hist``/``fut`` are per-agent numpy arrays of (x, y) points.

    Returns a 7-tuple ``(hist_batch, nbrs_batch, nbr_inds_batch, fut_batch,
    mask_batch, context_batch, nbrs_infos)`` where ``context_batch`` is
    ``None`` when ``self.use_context`` is disabled.
    """
    # First pass: count neighbors and agents over the whole batch so the
    # output tensors can be allocated up front.
    total_nbrs = 0
    total_agents = 0
    for _, _, nbrs, im_crop in samples:
      total_nbrs += sum(len(lst) for lst in nbrs.values())
      total_agents += len(nbrs)

    hist_len = self.t_h // self.d_s + 1
    fut_len = self.t_f // self.d_s

    # Keep at least one neighbor column so downstream code never sees an
    # empty tensor.
    nbrs_batch = torch.zeros(hist_len, max(total_nbrs, 1), 2)

    # Byte occupancy mask over the social grid: ones mark cells where a
    # neighbor encoding should be scattered.
    nbr_inds_batch = torch.zeros(
        total_agents, self.grid_size[1], self.grid_size[0],
        self.enc_size * self.enc_fac).byte()

    hist_batch = torch.zeros(hist_len, total_agents, 2)  # e.g. (31, 41, 2)
    fut_batch = torch.zeros(fut_len, total_agents, 2)
    mask_batch = torch.zeros(fut_len, total_agents, 2)
    if self.use_context:
      # NOTE(review): `im_crop` here is the one from the *last* iteration of
      # the counting loop; this assumes all crops share a shape — confirm.
      context_batch = torch.zeros(
          total_agents, im_crop.shape[0], im_crop.shape[1], im_crop.shape[2])
    else:
      context_batch: Union[None, torch.Tensor] = None # type: ignore

    nbrs_infos = []
    nbr_col = 0      # next free column in nbrs_batch
    agent_base = 0   # column offset of the current sample's agents
    for sample_idx, (hist, fut, nbrs, context) in enumerate(samples):
      n_agents = len(nbrs)
      for j in range(n_agents):
        h = torch.from_numpy(hist[j])
        f = torch.from_numpy(fut[j])
        hist_batch[:h.shape[0], agent_base + j, :] = h
        fut_batch[:f.shape[0], agent_base + j, :] = f
        # Mask marks which future timesteps are actually observed.
        mask_batch[:f.shape[0], agent_base + j, :] = 1
      agent_base += n_agents

      nbrs_infos.append(nbrs)

      if self.use_context:
        context_batch[sample_idx, :, :, :] = torch.from_numpy(context)

      # nbrs maps key -> list of (batch_index, veh_id, grid_ind); scatter each
      # valid neighbor's history and flag its grid cell.
      # NOTE(review): nbr_inds_batch is indexed by the dict key `batch_ind`,
      # presumably a batch-global agent index — verify against the dataset.
      for batch_ind, nbr_list in nbrs.items():
        for batch_id, veh_id, grid_ind in nbr_list:
          if batch_id < 0:
            continue
          nbr_hist = hist[batch_id]
          nbrs_batch[:len(nbr_hist), nbr_col, :] = torch.from_numpy(nbr_hist)
          col = grid_ind % self.grid_size[0]
          row = grid_ind // self.grid_size[0]
          nbr_inds_batch[batch_ind, row, col, :] = torch.ones(
              self.enc_size * self.enc_fac).byte()
          nbr_col += 1

    return (hist_batch, nbrs_batch, nbr_inds_batch, fut_batch, mask_batch,
            context_batch, nbrs_infos)