in multiple_futures_prediction/model_ngsim.py [0:0]
def forward_mfp(self, hist: torch.Tensor, nbrs: torch.Tensor, masks: torch.Tensor, context: Any,
                nbrs_info: List, fut: torch.Tensor, bStepByStep: bool,
                use_forcing: Optional[int] = None) -> Tuple[List[torch.Tensor], Any]:
"""Forward propagation function for the MFP
Computes dynamic state encoding with precomputed attention tensor and the
RNN based encoding.
Args:
hist: Trajectory history.
nbrs: Neighbors.
masks: Neighbors mask.
context: contextual information in image form (if used).
nbrs_info: information as to which other agents are neighbors.
fut: Future Trajectory.
bStepByStep: During rollout, interactive or independent.
use_forcing: Teacher-forcing or classmate forcing.
Returns:
fut_pred: a list of predictions, one for each mode.
modes_pred: prediction over latent modes.
"""
    use_forcing = self.use_forcing if use_forcing is None else use_forcing
# Normalize to reference position.
ref_pos = hist[-1,:,:]
hist = hist - ref_pos.view(1,-1,2)
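    # ref_pos is retained (and later passed to the decoder) so predictions stay
    # anchored to each agent's frame; subtracting it makes the encoded history
    # translation-invariant.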
# Encode history trajectories.
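    # nn.GRU returns (output, h_n) while nn.LSTM returns (output, (h_n, c_n)),
    # hence the two unpacking branches below.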
if isinstance(self.enc_lstm, torch.nn.modules.rnn.GRU):
_, hist_enc = self.enc_lstm(self.leaky_relu(self.ip_emb(hist)))
else:
        _, (hist_enc, _) = self.enc_lstm(self.leaky_relu(self.ip_emb(hist)))  # hist: torch.Size([16, 128, 2])
    if self.use_gru:
        hist_enc = hist_enc.permute(1, 0, 2).contiguous()
        hist_enc = self.leaky_relu(self.dyn_emb(hist_enc.view(hist_enc.shape[0], -1)))
    else:
        hist_enc = self.leaky_relu(self.dyn_emb(hist_enc.view(hist_enc.shape[1], hist_enc.shape[2])))  # torch.Size([128, 32])
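    # hist_enc is now one dynamics embedding per agent, e.g. [num_agents, 32].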
    num_nbrs = sum(len(nbs) for nb_id, nbs in nbrs_info[0].items())
if num_nbrs > 0:
nbrs_ref_pos = nbrs[-1,:,:]
nbrs = nbrs - nbrs_ref_pos.view(1,-1,2) # normalize
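        # Each neighbor history is normalized to its own last observed position,
        # mirroring the ego normalization above.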
# Forward pass for all neighbors.
if isinstance(self.enc_lstm, torch.nn.modules.rnn.GRU):
_, nbrs_enc = self.enc_lstm(self.leaky_relu(self.ip_emb(nbrs)))
else:
            _, (nbrs_enc, _) = self.enc_lstm(self.leaky_relu(self.ip_emb(nbrs)))
        if self.use_gru:
            nbrs_enc = nbrs_enc.permute(1, 0, 2).contiguous()
            nbrs_enc = nbrs_enc.view(nbrs_enc.shape[0], -1)
        else:
            nbrs_enc = nbrs_enc.view(nbrs_enc.shape[1], nbrs_enc.shape[2])
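        # RBF-based attention: compute attention weights from the neighbor
        # encodings and reference positions, then pool the neighbor states into
        # a fixed-size social encoding per agent.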
attens = self.rbf_state_enc_get_attens(nbrs_enc, ref_pos, nbrs_info[0])
nbr_atten_enc = self.rbf_state_enc_hist_fwd(attens, nbrs_enc, nbrs_info[0])
    else:  # no neighbors present
        attens = None  # type: ignore
        nbr_atten_enc = torch.zeros(1, self.nbr_atten_embedding_size, dtype=torch.float32, device=masks.device)
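    # With no neighbors, the social context is all zeros; presumably this path
    # is only hit when a single agent is encoded, so the shapes still align in
    # the concatenation below.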
    if self.use_context:  # context encoding
        context_enc = self.relu(self.context_conv(context))
        context_enc = self.context_maxpool(self.context_conv2(context_enc))
        context_enc = self.relu(self.context_conv3(context_enc))
        context_enc = self.context_fc(context_enc.view(context_enc.shape[0], -1))
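        # Three conv layers with max-pooling reduce the context image to a flat
        # feature vector, which is fused with the trajectory encodings below.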
        enc = torch.cat((nbr_atten_enc, hist_enc, context_enc), 1)
    else:
        enc = torch.cat((nbr_atten_enc, hist_enc), 1)
    # e.g. nbr_atten_enc: [num_agents, 80], hist_enc: [num_agents, 32], so enc is [num_agents, 112].
######################################################################################################
    modes_pred = None if self.modes == 1 else self.softmax(self.op_modes(enc))
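    # With more than one latent mode, op_modes maps the fused encoding to a
    # categorical distribution over modes.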
fut_pred = self.decode(enc, attens, nbrs_info[0], ref_pos, fut, bStepByStep, use_forcing)
return fut_pred, modes_pred
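
# Illustrative call (shapes are assumptions for illustration, not from this file):
# with a 16-step history, 128 agents, and 2-D coordinates,
#   hist: [16, 128, 2], nbrs: [16, total_nbrs, 2], fut: [fut_len, 128, 2]
#   fut_pred, modes_pred = model.forward_mfp(hist, nbrs, masks, context,
#                                            nbrs_info, fut, bStepByStep=False)
# fut_pred is a list with one tensor of decoded trajectory parameters per latent
# mode; modes_pred holds per-agent mode probabilities (None when self.modes == 1).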