def nll_loss_test()

in multiple_futures_prediction/my_utils.py [0:0]


def nll_loss_test(pred: torch.Tensor, data: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  """NLL for testing cases, returns a vector over future timesteps.

  Args:
    pred: Predicted bivariate-Gaussian parameters, read along the last axis as
      (x_mean, y_mean, x_sigma, y_sigma, rho). The formula multiplies squared
      residuals by x_sigma**2 / y_sigma**2 and subtracts log(x_sigma*y_sigma),
      so channels 2 and 3 behave as *inverse* standard deviations.
    data: Ground truth; channel 0 of the last axis is x, channel 1 is y.
    mask: Validity weights; only channel 0 of the last axis is used to weight
      the per-element loss and to build the counts. The whole mask is used for
      the non-empty check.

  Returns:
    (nll_sum, counts): the masked per-element loss summed over dim 1, and
    mask[:, :, 0] summed over dim 1. The layout is presumably
    (nSteps, nBatch), per the original inline comment, which gives one value
    per future timestep.

  Raises:
    ValueError: if the mask has no positive total (nothing to evaluate).

  NOTE(review): compared with the textbook bivariate Gaussian NLL, this omits
  the constant log(2*pi) and the 1/2 factor on the quadratic term. It is kept
  unchanged so the reported metric stays the same.
  """
  # Validate up front with an explicit raise; `assert` is stripped under -O.
  if not torch.sum(mask) > 0.0:
    raise ValueError("nll_loss_test: mask has no positive entries")

  x_mean = pred[:, :, 0]
  y_mean = pred[:, :, 1]
  x_sigma = pred[:, :, 2]
  y_sigma = pred[:, :, 3]
  rho = pred[:, :, 4]
  # ohr = 1 / sqrt(1 - rho^2)
  ohr = torch.pow(1 - torch.pow(rho, 2), -0.5) # type: ignore

  dx = data[:, :, 0] - x_mean
  dy = data[:, :, 1] - y_mean
  # Quadratic form of the correlated Gaussian, with x_sigma/y_sigma as inverse stds.
  quad = (torch.pow(x_sigma, 2) * torch.pow(dx, 2)
          + torch.pow(y_sigma, 2) * torch.pow(dy, 2)
          - 2 * rho * x_sigma * y_sigma * dx * dy)
  results = torch.pow(ohr, 2) * quad - torch.log(x_sigma * y_sigma * ohr)

  step_mask = mask[:, :, 0]  # nSteps by nBatch
  # Keep the multiplicative weighting, but force masked-out entries to exactly
  # zero. Padded steps may carry degenerate parameters (sigma == 0,
  # |rho| == 1) whose inf loss would otherwise become NaN via 0 * inf and
  # poison the per-step sum.
  weighted = results * step_mask
  results = torch.where(step_mask != 0, weighted, torch.zeros_like(weighted))

  counts = torch.sum(step_mask, dim=1)
  return torch.sum(results, dim=1), counts