# multiple_futures_prediction/model_ngsim.py
def rbf_state_enc_get_attens(self, nbrs_enc: torch.Tensor, ref_pos: torch.Tensor, nbrs_info_this: List) -> List[torch.Tensor]:
    """Compute RBF attention weights of each agent over its neighbors.

    Args:
        nbrs_enc: Encodings of all neighbors, one row per neighbor. Rows are
            laid out agent-by-agent in the same flat order as
            ``nbrs_info_this``.
        ref_pos: Reference positions, one row per agent.
        nbrs_info_this: Per-agent list of ``(nbr_batch_ind, nbr_id,
            nbr_ctx_ind)`` tuples describing that agent's neighbors.

    Returns:
        One tensor per agent of shape ``(num_slot_keys, num_nbrs_of_agent)``
        holding the (unnormalized) RBF attention weights over that agent's
        neighbors.
    """
    assert len(nbrs_info_this) == ref_pos.shape[0]
    if self.extra_pos_dim > 0:
        # Relative position of each neighbor w.r.t. its ego agent, in the
        # same flat row order as nbrs_enc.
        rel_pos = torch.zeros(nbrs_enc.shape[0], 2, device=nbrs_enc.device)
        row = 0
        for ego, nbr_list in enumerate(nbrs_info_this):
            for nbr in nbr_list:
                rel_pos[row, :] = ref_pos[nbr[0], :] - ref_pos[ego, :]
                row += 1
        keys = self.sec_key_net(torch.cat((nbrs_enc, rel_pos), dim=1))
        # e.g. num_agents by self.sec_key_dim
    else:
        keys = self.sec_key_net(nbrs_enc)  # e.g. num_agents by self.sec_key_dim

    # RBF kernel of every key against every slot key: num_keys x num_agents.
    all_attens = torch.stack(
        [torch.exp(-self.scale * (keys - torch.t(slot)).norm(dim=1))
         for slot in self.slot_keys],
        dim=0,
    )

    # Slice the flat attention matrix back into per-agent chunks, walking
    # the neighbor lists in the same order the rows were laid out.
    attens: List[torch.Tensor] = []
    start = 0
    for nbr_list in nbrs_info_this:
        end = start + len(nbr_list)
        attens.append(all_attens[:, start:end])
        start = end
    return attens