def rbf_state_enc_get_attens()

in multiple_futures_prediction/model_ngsim.py [0:0]


  def rbf_state_enc_get_attens(self, nbrs_enc: torch.Tensor, ref_pos: torch.Tensor, nbrs_info_this: List ) -> List[torch.Tensor]:
    """Computing the attention over other agents.

    Each neighbour row is turned into a key by ``self.sec_key_net`` and scored
    against every entry of ``self.slot_keys`` with an RBF kernel
    ``exp(-self.scale * ||key - slot||_2)``.

    Args:
      nbrs_enc: neighbour encodings, one row per neighbour. Rows are laid out
        agent by agent, in the same order as ``nbrs_info_this`` (row k pairs
        with the k-th (agent, neighbour) entry when flattening that list).
      ref_pos: reference positions indexed by batch/agent index; row ``i`` is
        the position of agent ``i`` (presumably (num_agents, pos_dim) —
        confirm against caller).
      nbrs_info_this: a list of list of (nbr_batch_ind, nbr_id, nbr_ctx_ind);
        entry ``n`` lists the neighbours of agent ``n``. Only
        ``nbr_batch_ind`` is used here.
    Returns:
      attention weights over the neighbors: a list with one tensor per agent,
      each of shape (len(self.slot_keys), len(nbrs_info_this[n])), assuming
      ``sec_key_net`` yields one key row per neighbour row.
    """
    assert len(nbrs_info_this) == ref_pos.shape[0]
    if self.extra_pos_dim > 0:
      # Flatten (agent, neighbour) pairs in the same order as nbrs_enc rows.
      agent_idx = [n for n, nbrs in enumerate(nbrs_info_this) for _ in nbrs]
      nbr_idx = [int(nbr[0]) for nbrs in nbrs_info_this for nbr in nbrs]
      # Match nbrs_enc's dtype/device so the concatenated key input is
      # homogeneous; width follows ref_pos instead of a hard-coded 2.
      pos_enc = torch.zeros(nbrs_enc.shape[0], ref_pos.shape[1],
                            dtype=nbrs_enc.dtype, device=nbrs_enc.device)
      if nbr_idx:
        nbr_t = torch.as_tensor(nbr_idx, dtype=torch.long, device=ref_pos.device)
        agent_t = torch.as_tensor(agent_idx, dtype=torch.long, device=ref_pos.device)
        # Neighbour position relative to the agent that owns it. Rows beyond
        # the listed neighbours (if any) stay zero, as before.
        pos_enc[:len(nbr_idx)] = ref_pos[nbr_t] - ref_pos[agent_t]
      Key = self.sec_key_net( torch.cat( (nbrs_enc,pos_enc),dim=1) )
      # e.g. num_agents by self.sec_key_dim
    else:
      Key = self.sec_key_net( nbrs_enc )  # e.g. num_agents by self.sec_key_dim

    # One RBF response per slot key; torch.t turns each slot (presumably a
    # (key_dim, 1) column — confirm) into a row that broadcasts over Key rows.
    Atten = torch.stack(
        [torch.exp(-self.scale * (Key - torch.t(slot)).norm(dim=1)) for slot in self.slot_keys],
        dim=0)  # e.g. num_keys x num_agents

    # Slice the columns back into per-agent groups.
    attens = []
    start = 0
    for nbrs in nbrs_info_this:
      stop = start + len(nbrs)
      attens.append(Atten[:, start:stop])
      start = stop
    return attens