def forward_mfp()

in multiple_futures_prediction/model_ngsim.py [0:0]


  def forward_mfp(self, hist:torch.Tensor, nbrs:torch.Tensor, masks:torch.Tensor, context:Any,
                  nbrs_info:List, fut:torch.Tensor, bStepByStep:bool,
                  use_forcing:Optional[int]=None) -> Tuple[List[torch.Tensor], Any]:
    """Forward propagation function for the MFP.

    Encodes every agent's history with the shared RNN encoder. When any
    neighbors exist, it computes an attention encoding over their (also
    RNN-encoded) trajectories. It optionally adds a context-image encoding
    and then decodes future trajectories.

    Args:
      hist: Trajectory history. The last dimension is (x, y). ``hist[-1]`` is
        taken as each agent's reference position. This is presumably laid out
        as (time, agent, 2); confirm against the data loader.
      nbrs: Neighbor trajectories, with the same layout as ``hist``.
      masks: Neighbors mask. Not read by this method; kept for interface
        compatibility.
      context: Contextual information in image form. Only used when
        ``self.use_context`` is set.
      nbrs_info: Information as to which other agents are neighbors. Only
        ``nbrs_info[0]`` is used; its values are per-agent neighbor lists.
      fut: Future trajectory, forwarded unchanged to ``self.decode``.
      bStepByStep: During rollout, interactive or independent.
      use_forcing: Teacher-forcing or classmate forcing. ``None`` means use
        ``self.use_forcing``.

    Returns:
      fut_pred: The result of ``self.decode``, a list of predictions with one
        entry per mode. When there are no neighbors at all, ``self.decode``
        receives ``attens=None``.
      modes_pred: Softmax over latent modes, or ``None`` when ``self.modes == 1``.
    """
    if use_forcing is None:
      use_forcing = self.use_forcing

    def _encode(tracks: torch.Tensor) -> torch.Tensor:
      """Run the shared trajectory encoder; return its final hidden state as (batch, features)."""
      emb = self.leaky_relu(self.ip_emb(tracks))
      # A GRU returns only h_n, while an LSTM returns (h_n, c_n).
      if isinstance(self.enc_lstm, torch.nn.modules.rnn.GRU):
        _, h = self.enc_lstm(emb)
      else:
        _, (h, _) = self.enc_lstm(emb)
      if self.use_gru:
        # h_n is (layers*directions, batch, H). Put batch first and
        # concatenate all layers/directions for each sample.
        h = h.permute(1, 0, 2).contiguous()
        return h.view(h.shape[0], -1)
      # This view only succeeds when h_n has a single layer/direction.
      return h.view(h.shape[1], h.shape[2])

    # Express each history relative to that agent's last observed position.
    ref_pos = hist[-1, :, :]
    hist = hist - ref_pos.view(1, -1, 2)
    hist_enc = self.leaky_relu(self.dyn_emb(_encode(hist)))
    num_agents = hist_enc.shape[0]

    nbr_map = nbrs_info[0]
    num_nbrs = sum(len(nbs) for nbs in nbr_map.values())
    if num_nbrs > 0:
      # Normalize neighbors to their own last positions before encoding.
      nbrs_ref_pos = nbrs[-1, :, :]
      nbrs_enc = _encode(nbrs - nbrs_ref_pos.view(1, -1, 2))
      attens = self.rbf_state_enc_get_attens(nbrs_enc, ref_pos, nbr_map)
      nbr_atten_enc = self.rbf_state_enc_hist_fwd(attens, nbrs_enc, nbr_map)
    else:  # no agent has any neighbor
      attens = None  # type: ignore
      # One zero row per agent, so the concatenation below matches the batch
      # size and device of hist_enc.
      nbr_atten_enc = torch.zeros(num_agents, self.nbr_atten_embedding_size,
                                  dtype=torch.float32, device=hist_enc.device)

    parts = [nbr_atten_enc, hist_enc]
    if self.use_context:  # context encoding
      context_enc = self.relu(self.context_conv(context))
      context_enc = self.context_maxpool(self.context_conv2(context_enc))
      context_enc = self.relu(self.context_conv3(context_enc))
      parts.append(self.context_fc(context_enc.view(context_enc.shape[0], -1)))
    # Example: nbr_atten_enc [num_agents x 80] and hist_enc [num_agents x 32]
    # give enc [num_agents x 112].
    enc = torch.cat(parts, 1)

    modes_pred = None if self.modes == 1 else self.softmax(self.op_modes(enc))
    fut_pred = self.decode(enc, attens, nbr_map, ref_pos, fut, bStepByStep, use_forcing)
    return fut_pred, modes_pred