in multiple_futures_prediction/model_ngsim.py [0:0]
def decode(self, enc: torch.Tensor, attens: List, nbrs_info_this: List, ref_pos: torch.Tensor, fut: torch.Tensor, bStepByStep: bool, use_forcing: Any) -> List[torch.Tensor]:
"""Decode the future trajectory using RNNs.
Given computed feature vector, decode the future with multimodes, using
dynamic attention and either interactive or non-interactive rollouts.
Args:
enc: encoded features, one per agent.
attens: attentional weights, list of objs, each with dimenstion of [8 x 4] (e.g.)
nbrs_info_this: information on who are the neighbors
ref_pos: the current postion (reference position) of the agents.
fut: future trajectory (only useful for teacher or classmate forcing)
bStepByStep: interactive or non-interactive rollout
use_forcing: 0: None. 1: Teacher-forcing. 2: classmate forcing.
Returns:
fut_pred: a list of predictions, one for each mode.
modes_pred: prediction over latent modes.
"""
if not bStepByStep: # Non-interactive rollouts
    enc = enc.repeat(self.out_length, 1, 1)  # tile the encoding across all output timesteps
    pos_enc = torch.zeros(self.out_length, enc.shape[1], self.posi_enc_dim + self.posi_enc_ego_dim, device=enc.device)
    enc2 = torch.cat((enc, pos_enc), dim=2)
fut_preds = []
for k in range(self.modes):
h_dec, _ = self.dec_lstm[k](enc2)
h_dec = h_dec.permute(1, 0, 2)
fut_pred = self.op[k](h_dec)
      fut_pred = fut_pred.permute(1, 0, 2)  # torch.Size([nSteps, num_agents, 5])
fut_pred = Gaussian2d(fut_pred)
fut_preds.append(fut_pred)
return fut_preds
  else:  # Interactive (step-by-step) rollouts.
batch_sz = enc.shape[0]
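    # Flatten the per-agent neighbor lists: 'chunks' records each agent's
    # neighbor count and 'inds' collects the flattened neighbor indices
    # used to gather neighbor states.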
    inds = []
    chunks = []
    for n in range(len(nbrs_info_this)):
      chunks.append(len(nbrs_info_this[n]))
      for nbr in nbrs_info_this[n]:
        inds.append(nbr[0])
    flat_index = torch.LongTensor(inds).to(ref_pos.device)  # type: ignore
fut_preds = []
for k in range(self.modes):
      # Zero initial hidden state; the layer count doubles for a bidirectional decoder.
      direc = 2 if self.bi_direc else 1
      hidden = torch.zeros(self.num_layers * direc, batch_sz, self.decoder_size).to(fut.device)
preds: List[torch.Tensor] = []
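      # What gets fed back at each step depends on the forcing mode:
      #   0 (none):      the model's own previous prediction;
      #   1 (teacher):   the ground-truth future, for everyone;
      #   2 (classmate): ground truth for neighbors, own prediction for the ego agent.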
for t in range(self.out_length):
        if t == 0:  # Initial timestep.
          if use_forcing == 0:    # no forcing: nothing predicted yet, so feed zeros
            pred_fut_t = torch.zeros_like(fut[t, :, :])
            ego_fut_t = pred_fut_t
          elif use_forcing == 1:  # teacher forcing: feed the ground-truth future
            pred_fut_t = fut[t, :, :]
            ego_fut_t = pred_fut_t
          else:                   # classmate forcing: ground truth for neighbors, zeros for ego
            pred_fut_t = fut[t, :, :]
            ego_fut_t = torch.zeros_like(pred_fut_t)
        else:
          if use_forcing == 0:    # no forcing: feed back the previous prediction
            pred_fut_t = preds[-1][:, :, :2].squeeze(0)  # squeeze(0), not squeeze(): keeps a batch of size 1 intact
            ego_fut_t = pred_fut_t
          elif use_forcing == 1:  # teacher forcing: feed the ground-truth future
            pred_fut_t = fut[t, :, :]
            ego_fut_t = pred_fut_t
          else:                   # classmate forcing: ground truth for neighbors, own prediction for ego
            pred_fut_t = fut[t, :, :]
            ego_fut_t = preds[-1][:, :, :2]
        if attens is None:  # no attention available: use a zero positional encoding
          pos_enc = torch.zeros(batch_sz, self.posi_enc_dim, device=enc.device)
        else:               # dynamic RBF positional encoding from the current positions
          pos_enc = self.rbf_state_enc_pos_fwd(attens, ref_pos, pred_fut_t, flat_index, chunks)
        # Decoder input for this step: [static encoding | positional encoding | ego feedback].
        enc_large = torch.cat((enc.view(1, enc.shape[0], enc.shape[1]),
                               pos_enc.view(1, batch_sz, self.posi_enc_dim),
                               ego_fut_t.view(1, batch_sz, self.posi_enc_ego_dim)), dim=2)
        out, hidden = self.dec_lstm[k](enc_large, hidden)
        pred = Gaussian2d(self.op[k](out))
        preds.append(pred)
      fut_pred_k = torch.cat(preds, dim=0)  # [out_length, batch, 5]
      fut_preds.append(fut_pred_k)
return fut_preds
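
For orientation, here is a minimal, self-contained sketch of the non-interactive branch above, using made-up sizes. The plain nn.LSTM and nn.Linear stand in for self.dec_lstm[k] and self.op[k], and the Gaussian2d transform is omitted because its implementation is not shown here; every dimension below is an illustrative assumption, not the repo's configuration.

import torch
import torch.nn as nn

# Illustrative sizes only (assumptions, not values from the repo).
num_agents, enc_dim, pos_dim, out_length, decoder_size = 4, 64, 8, 25, 128

dec_lstm = nn.LSTM(enc_dim + pos_dim, decoder_size)  # stands in for self.dec_lstm[k]
op = nn.Linear(decoder_size, 5)                      # stands in for self.op[k]

enc = torch.randn(num_agents, enc_dim)                  # one feature vector per agent
enc = enc.repeat(out_length, 1, 1)                      # tile across the horizon: [T, agents, enc_dim]
pos_enc = torch.zeros(out_length, num_agents, pos_dim)  # zero positional encoding
enc2 = torch.cat((enc, pos_enc), dim=2)                 # [T, agents, enc_dim + pos_dim]

h_dec, _ = dec_lstm(enc2)  # one pass over the whole horizon: [T, agents, decoder_size]
fut_pred = op(h_dec)       # raw per-step outputs: [T, agents, 5]
# In the real code, Gaussian2d would map the five per-step values to valid
# 2-D Gaussian parameters (typically means, scales, and a correlation).
print(fut_pred.shape)      # torch.Size([25, 4, 5])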