def nll_loss_test_multimodes()

in multiple_futures_prediction/my_utils.py [0:0]


def nll_loss_test_multimodes(pred: List[torch.Tensor], data: torch.Tensor, mask: torch.Tensor, modes_pred: torch.Tensor, y_mean: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor] :
  """NLL loss multimodes for test time.

  Scores a mixture of per-mode bivariate-Gaussian predictions against the
  ground truth. It returns the masked mixture NLL, summed over dim 1, together
  with the matching mask counts.

  Args:
    pred: list with one tensor per mode. Each must be 3-D
      (``nSteps, batch_sz, dim = pred[0].shape``). Its last axis is read as
      (mu_x, mu_y, x_sigma, y_sigma, rho). x_sigma / y_sigma multiply the
      squared residuals, so they act as inverse standard deviations.
    data: ground truth. Channels 0 and 1 of the last axis are the x / y targets.
    mask: channel 0 of the last axis weights each element (presumably 0/1
      validity -- confirm with callers). Its first two dims size the output grid.
    modes_pred: mode weights, indexed ``modes_pred[:, k]`` for mode k. Their log
      is added to each mode's log-likelihood before the log-sum-exp.
    y_mean: offset added to the predicted means (``y_mean[:, 0]`` to x,
      ``y_mean[:, 1]`` to y), or None for no offset.

  Returns:
    (lossVal, counts): the masked per-element mixture NLL summed over dim 1,
    and the mask (channel 0) summed over dim 1.
  """
  modes = len(pred)
  nSteps, _, _ = pred[0].shape
  # Allocate on pred's device: y_mean may legitimately be None (handled below),
  # so reading y_mean.device here would raise AttributeError.
  total = torch.zeros(mask.shape[0], mask.shape[1], modes, device=pred[0].device)
  x = data[:, :, 0]
  y = data[:, :, 1]
  for k, y_pred in enumerate(pred):
    # Mode-k weight tiled along the first axis to line up with the per-step scores.
    wts = modes_pred[:, k].repeat(nSteps, 1)
    x_pred_mean = y_pred[:, :, 0]
    y_pred_mean = y_pred[:, :, 1]
    if y_mean is not None:
      # NOTE(review): view(-1, 1) broadcasts y_mean's leading dim against the
      # first axis of y_pred; confirm this matches the callers' layout.
      x_pred_mean = x_pred_mean + y_mean[:, 0].view(-1, 1)
      y_pred_mean = y_pred_mean + y_mean[:, 1].view(-1, 1)
    x_sigma = y_pred[:, :, 2]
    y_sigma = y_pred[:, :, 3]
    rho = y_pred[:, :, 4]
    ohr = torch.pow(1 - torch.pow(rho, 2), -0.5)  # type: ignore
    dx = x - x_pred_mean
    dy = y - y_pred_mean
    # NOTE(review): a textbook bivariate-Gaussian log-density scales this
    # quadratic by 0.5 and subtracts log(2*pi). Both are omitted here and kept
    # as-is so previously reported numbers do not shift.
    quad = (torch.pow(x_sigma, 2) * torch.pow(dx, 2)
            + torch.pow(y_sigma, 2) * torch.pow(dy, 2)
            - 2 * rho * x_sigma * y_sigma * dx * dy)
    log_lik = -(torch.pow(ohr, 2) * quad - torch.log(x_sigma * y_sigma * ohr))
    total[:, :, k] = log_lik + torch.log(wts)
  # Negative log of the weighted mixture likelihood, per element.
  total = -torch.logsumexp(total, dim=2)
  m = mask[:, :, 0]
  # Masked-out entries may be +/-inf or NaN (e.g. zero mode weights on padding).
  # Zero them before weighting, since inf * 0 would be NaN and poison the sum.
  total = torch.where(m != 0, total, torch.zeros_like(total)) * m
  lossVal = torch.sum(total, dim=1)
  counts = torch.sum(m, dim=1)
  return lossVal, counts