in multiple_futures_prediction/my_utils.py [0:0]
def nll_loss_test_multimodes(pred: List[torch.Tensor], data: torch.Tensor,
                             mask: torch.Tensor, modes_pred: torch.Tensor,
                             y_mean: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """NLL loss multimodes for test time.

    Scores ground truth against a mixture of bivariate Gaussians (one per
    mode) and returns the masked negative log-likelihood, summed over dim 1
    of the ``(nSteps, batch_sz)`` grid.

    Args:
        pred: One tensor per mode. Each is unpacked as
            ``(nSteps, batch_sz, dim)``, and its last-axis channels 0-4 are
            ``mu_x, mu_y, s_x, s_y, rho``. ``s_x`` and ``s_y`` multiply the
            residuals and appear inside ``log``, so they act as inverse
            standard deviations and must be positive. ``|rho| < 1`` is
            required for ``1 - rho**2`` to be positive.
        data: Ground truth. Channels 0 and 1 of the last axis are x and y,
            and they must broadcast against ``pred[k][:, :, 0]``.
        mask: Channel 0 of the last axis weights each grid entry. Its first
            two dims must match ``pred[k].shape[:2]``.
        modes_pred: Mode weights indexed ``[:, k]``. Column ``k`` is
            broadcast along the step axis, so it needs one entry per
            ``batch_sz`` element. The weights are log-transformed.
        y_mean: An offset added to the predicted means, or ``None``.
            ``y_mean[:, c].view(-1, 1)`` is broadcast against
            ``(nSteps, batch_sz)``, so it must have 1 or ``nSteps`` rows.

    Returns:
        ``(lossVal, counts)``:
        - ``lossVal`` is the masked NLL summed over dim 1.
        - ``counts`` is the sum of ``mask[:, :, 0]`` over dim 1.

    Raises:
        ValueError: If ``pred`` is empty.
    """
    if len(pred) == 0:
        raise ValueError("pred must contain at least one mode")
    modes = len(pred)

    # Allocate on the predictions' device. Reading y_mean.device here would
    # raise AttributeError for y_mean=None, which the loop explicitly allows.
    total = torch.zeros(mask.shape[0], mask.shape[1], modes, device=pred[0].device)

    for k, y_pred in enumerate(pred):
        x_pred_mean = y_pred[:, :, 0]
        y_pred_mean = y_pred[:, :, 1]
        if y_mean is not None:
            # view(-1, 1) offsets rows of the FIRST (nSteps) axis.
            # NOTE(review): if y_mean is meant as a per-agent (batch_sz, 2)
            # offset, this should be view(1, -1). Confirm against callers.
            x_pred_mean = x_pred_mean + y_mean[:, 0].view(-1, 1)
            y_pred_mean = y_pred_mean + y_mean[:, 1].view(-1, 1)
        x_sigma = y_pred[:, :, 2]
        y_sigma = y_pred[:, :, 3]
        rho = y_pred[:, :, 4]
        # ohr = 1 / sqrt(1 - rho^2)
        ohr = torch.pow(1 - torch.pow(rho, 2), -0.5)  # type: ignore
        dx = data[:, :, 0] - x_pred_mean
        dy = data[:, :, 1] - y_pred_mean
        quad = (torch.pow(x_sigma, 2) * torch.pow(dx, 2)
                + torch.pow(y_sigma, 2) * torch.pow(dy, 2)
                - 2 * rho * x_sigma * y_sigma * dx * dy)
        # NOTE(review): the exact bivariate-Gaussian log-density is
        # log(s_x*s_y*ohr) - 0.5*ohr^2*quad - log(2*pi). This expression
        # omits the 0.5 factor, which also shifts the mixture posterior.
        # Kept as-is to preserve reported metrics; confirm the intent.
        log_lik = torch.log(x_sigma * y_sigma * ohr) - torch.pow(ohr, 2) * quad
        # Adding log(w_k) yields log(w_k * p_k). The (batch_sz,) weight column
        # broadcasts over the step axis.
        total[:, :, k] = log_lik + torch.log(modes_pred[:, k])

    # -log(sum_k w_k p_k), computed stably over the mode axis.
    nll = -torch.logsumexp(total, dim=2)
    weight = mask[:, :, 0]
    weighted = nll * weight
    # Zero masked entries explicitly. Multiplying alone turns an inf NLL at a
    # padded position into NaN (inf * 0), which would poison the sum.
    weighted = torch.where(weight != 0, weighted, torch.zeros_like(weighted))
    loss_val = torch.sum(weighted, dim=1)
    counts = torch.sum(weight, dim=1)
    return loss_val, counts