in multiple_futures_prediction/model_ngsim.py [0:0]
def rbf_state_enc_hist_fwd(self, attens: List, nbrs_enc: torch.Tensor, nbrs_info_this: List) -> torch.Tensor:
    """Computes the dynamic state encoding.

    Combines precomputed attention weights with the RNN-based neighbor
    encodings. For each agent ``n`` with neighbors, its block of neighbor
    encoding rows is pooled as ``torch.mm(attens[n], nbr_feat)``. Agents
    without neighbors get an all-zero ``(num_slots, enc_dim)`` encoding. The
    per-agent results are stacked, flattened per agent and passed through
    ``self.sec_hist_net``.

    Args:
        attens: list indexed by agent; ``attens[n]`` is a
            ``[slots x num_neighbors]`` tensor for agent ``n``. It is only read
            for agents that have at least one neighbor.
        nbrs_enc: 2-D tensor of neighbor encodings ``(total_neighbors,
            enc_dim)``. The rows for agent ``n`` are the next contiguous block
            of ``len(nbrs_info_this[n])`` rows, taken in agent order.
        nbrs_info_this: list indexed by agent; entry ``n`` is agent ``n``'s
            neighbor list (only its length is used here).

    Returns:
        Output of ``self.sec_hist_net`` applied to a
        ``(num_agents, slots * enc_dim)`` tensor. With zero agents the input
        to ``self.sec_hist_net`` is an empty ``(0, num_slots * enc_dim)``
        tensor.
    """
    enc_dim = nbrs_enc.shape[1]
    out = []
    offset = 0  # first row of nbrs_enc not yet consumed
    for n, list_of_nbrs in enumerate(nbrs_info_this):
        num_nbrs = len(list_of_nbrs)
        if num_nbrs > 0:
            nbr_feat = nbrs_enc[offset:offset + num_nbrs, :]
            out.append(torch.mm(attens[n], nbr_feat))
            offset += num_nbrs
        else:
            # No neighbors found: use all zeros. new_zeros allocates directly
            # with nbrs_enc's dtype and device, instead of building a
            # default-dtype CPU tensor and copying it over.
            out.append(nbrs_enc.new_zeros(self.num_slots, enc_dim))
    if not out:
        # torch.stack rejects an empty list; feed an empty batch instead.
        st_enc = nbrs_enc.new_zeros(0, self.num_slots * enc_dim)
    else:
        # Stack to (num_agents, slots, enc_dim), then flatten each agent's
        # slot features into one row: num_agents by slots*enc_dim.
        st_enc = torch.stack(out, dim=0).view(len(out), -1)
    return self.sec_hist_net(st_enc)