def rbf_state_enc_hist_fwd()

in multiple_futures_prediction/model_ngsim.py [0:0]


  def rbf_state_enc_hist_fwd(self, attens: List, nbrs_enc: torch.Tensor, nbrs_info_this: List) -> torch.Tensor:
    """Computes the dynamic state encoding.

    For every agent, pools its neighbors' encodings with the precomputed
    attention matrix (``attens[n] @ neighbor_encodings``), flattens the pooled
    slots into one row, and feeds the stacked rows through
    ``self.sec_hist_net``. Agents without neighbors get an all-zero row.

    Args:
      attens: list indexed like ``nbrs_info_this``. For an agent with k > 0
        neighbors, ``attens[n]`` is left-multiplied onto a (k x enc_dim) slice
        of ``nbrs_enc``, so it is expected to be (num_slots x k). Entries for
        agents with no neighbors are never read.
      nbrs_enc: (total_neighbors x enc_dim) neighbor encodings, packed
        contiguously in the same order as the neighbor lists in
        ``nbrs_info_this``.
      nbrs_info_this: one neighbor list per agent; only each list's length is
        used here.
    Returns:
      ``self.sec_hist_net`` applied to a (num_agents x num_slots*enc_dim)
      tensor with the same dtype and device as ``nbrs_enc``.
    """
    enc_dim = nbrs_enc.shape[1]
    out = []
    counter = 0
    for n, list_of_nbrs in enumerate(nbrs_info_this):
      num_nbrs = len(list_of_nbrs)
      if num_nbrs > 0:
        # Neighbor encodings are packed back to back; take this agent's rows.
        nbr_feat = nbrs_enc[counter:counter + num_nbrs, :]
        out.append(torch.mm(attens[n], nbr_feat))
        counter += num_nbrs
      else:
        # No neighbors: use all zeros. Allocate directly with nbrs_enc's dtype
        # and device so the padding matches the real encodings (e.g. under
        # half/double precision) and avoids a CPU allocation plus copy.
        out.append(nbrs_enc.new_zeros((self.num_slots, enc_dim)))
    if not out:
      # torch.stack rejects an empty list; return an empty batch instead.
      st_enc = nbrs_enc.new_zeros((0, self.num_slots * enc_dim))
    else:
      st_enc = torch.stack(out, dim=0).view(len(out), -1)  # num_agents x (num_slots*enc_dim)
    return self.sec_hist_net(st_enc)