def decode()

in multiple_futures_prediction/model_ngsim.py [0:0]


  def decode(self, enc: torch.Tensor, attens:List, nbrs_info_this:List, ref_pos:torch.Tensor, fut:torch.Tensor, bStepByStep:bool, use_forcing:Any ) -> List[torch.Tensor]:
    """Decode the future trajectory using RNNs.

    Given computed feature vector, decode the future with multimodes, using
    dynamic attention and either interactive or non-interactive rollouts.

    In the non-interactive case the whole horizon is decoded in one RNN call
    per mode, with an all-zero positional encoding appended to the features.
    In the interactive case the decoder is unrolled one step at a time so that
    the positional encoding (and the ego input) can depend on the previous
    step, according to `use_forcing`.

    Args:
      enc: encoded features, one per agent (2-D: [num_agents, feature_dim],
        as implied by `enc.view(1, enc.shape[0], enc.shape[1])`).
      attens: attentional weights, list of objs, each with dimension of
        [8 x 4] (e.g.). May be None, in which case an all-zero positional
        encoding is used for the interactive rollout.
      nbrs_info_this: information on who are the neighbors. Accessed with
        `len()` and integer indexing; element [0] of every neighbor entry is
        collected as an index.
      ref_pos: the current position (reference position) of the agents.
      fut: future trajectory, indexed as fut[t, :, :]. Its values are only
        used for teacher or classmate forcing, but its shape and device are
        also used to build the zero initial input and the hidden state.
      bStepByStep: True for interactive (step-by-step) rollout, False for
        non-interactive rollout.
      use_forcing: 0: None. 1: Teacher-forcing. Any other value (2 by
        convention): classmate forcing. Ignored when bStepByStep is False.

    Returns:
      fut_preds: a list with one prediction tensor per mode
        (self.modes entries), each covering self.out_length time steps.
    """
    if not bStepByStep: # Non-interactive rollouts
      enc = enc.repeat(self.out_length, 1, 1)
      pos_enc = torch.zeros( self.out_length, enc.shape[1], self.posi_enc_dim+self.posi_enc_ego_dim, device=enc.device )
      enc2 = torch.cat( (enc, pos_enc), dim=2)
      fut_preds = []
      for k in range(self.modes):
        h_dec, _ = self.dec_lstm[k](enc2)
        h_dec = h_dec.permute(1, 0, 2)
        fut_pred = self.op[k](h_dec)
        fut_pred = fut_pred.permute(1, 0, 2) #torch.Size([nSteps, num_agents, 5])

        fut_pred = Gaussian2d(fut_pred)
        fut_preds.append(fut_pred)
      return fut_preds

    # Interactive rollouts: unroll the decoder one time step at a time.
    batch_sz = enc.shape[0]

    # Flatten the neighbor indices and remember how many belong to each agent.
    # Index-based access is kept on purpose: only len() and integer indexing
    # are known to be supported by nbrs_info_this.
    inds = []
    chunks = []
    for n in range(len(nbrs_info_this)):
      nbrs = nbrs_info_this[n]
      chunks.append( len(nbrs) )
      inds.extend( nbr[0] for nbr in nbrs )
    flat_index = torch.LongTensor(inds).to(ref_pos.device) # type: ignore

    # Loop-invariant pieces, computed once instead of per mode / per step.
    direc = 2 if self.bi_direc else 1
    enc_step = enc.view(1, enc.shape[0], enc.shape[1])
    zero_pos_enc = None
    if attens is None:
      # Without attention weights the positional encoding is all zeros.
      zero_pos_enc = torch.zeros(batch_sz, self.posi_enc_dim, device=enc.device )

    fut_preds = []
    for k in range(self.modes):
      hidden = torch.zeros(self.num_layers*direc, batch_sz, self.decoder_size, device=fut.device)
      preds: List[torch.Tensor] = []
      for t in range(self.out_length):
        if t == 0: # Initial timestep: nothing has been predicted yet.
          prev_pred_t = torch.zeros_like(fut[t,:,:])
        else:
          # First two output channels of the previous step are fed back.
          # squeeze(0) removes only the leading step axis; a bare squeeze()
          # would also drop the agent axis when there is a single agent.
          prev_pred_t = preds[-1][:,:,:2].squeeze(0)

        if use_forcing == 0: # No forcing: feed back the model's own output.
          pred_fut_t = prev_pred_t
          ego_fut_t = prev_pred_t
        elif use_forcing == 1: # Teacher forcing: ground truth everywhere.
          pred_fut_t = fut[t,:,:]
          ego_fut_t = pred_fut_t
        else: # Classmate forcing: ground truth for the positional encoding, own output for ego.
          pred_fut_t = fut[t,:,:]
          ego_fut_t = prev_pred_t

        if zero_pos_enc is not None:
          pos_enc = zero_pos_enc
        else:
          pos_enc = self.rbf_state_enc_pos_fwd(attens, ref_pos, pred_fut_t, flat_index, chunks )

        enc_large = torch.cat( ( enc_step,
                                 pos_enc.view(1, batch_sz, self.posi_enc_dim),
                                 ego_fut_t.view(1, batch_sz, self.posi_enc_ego_dim ) ), dim=2 )

        out, hidden = self.dec_lstm[k]( enc_large, hidden)
        pred = Gaussian2d(self.op[k](out))
        preds.append( pred )
      # Stack the per-step outputs along the time axis.
      fut_pred_k = torch.cat(preds, dim=0)
      fut_preds.append(fut_pred_k)
    return fut_preds