in multiple_futures_prediction/my_utils.py [0:0]
from typing import List

import numpy as np
import torch
from scipy import special


def nll_loss_multimodes(pred: List[torch.Tensor], data: torch.Tensor, mask: torch.Tensor,
                        modes_pred: torch.Tensor, noise: float = 0.0) -> torch.Tensor:
    """NLL loss over multiple modes for training.

    Args:
      pred: list of per-mode predictions, one tensor of shape [nSteps, batch_sz, dim] per mode.
      data: ground-truth trajectories.
      mask: validity mask for the ground truth.
      modes_pred: predicted prior probability of each mode, shape [batch_sz, modes].
      noise: std of optional Gaussian noise added to the unnormalized log-posterior.
    """
    modes = len(pred)
    nSteps, batch_sz, dim = pred[0].shape

    # E-step: with the network outputs held fixed (no gradients), compute the
    # posterior probability of each mode given the ground truth.
    log_lik = np.zeros((batch_sz, modes))
    with torch.no_grad():
        for kk in range(modes):
            nll = nll_loss_per_sample(pred[kk], data, mask)
            log_lik[:, kk] = -nll.cpu().numpy()
    priors = modes_pred.detach().cpu().numpy()
    log_posterior_unnorm = log_lik + np.log(priors).reshape((-1, modes))  # [TotalObjs, modes]
    log_posterior_unnorm += np.random.randn(*log_posterior_unnorm.shape) * noise
    log_posterior = log_posterior_unnorm - special.logsumexp(log_posterior_unnorm, axis=1).reshape((batch_sz, 1))
    post_pr = np.exp(log_posterior)  # [TotalObjs, modes]
    post_pr = torch.tensor(post_pr).float().to(data.device)

    # M-step: per-mode NLL weighted by the fixed posterior; gradients flow here.
    loss = 0.0
    for kk in range(modes):
        nll_k = nll_loss_per_sample(pred[kk], data, mask) * post_pr[:, kk]
        loss += nll_k.sum() / float(batch_sz)

    # KL term pushes the predicted mode priors toward the posterior.
    kl_loss = torch.nn.KLDivLoss(reduction='batchmean')  # type: ignore
    loss += kl_loss(torch.log(modes_pred), post_pr)
    return loss
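
# The helper nll_loss_per_sample is called above but not shown in this excerpt.
# Below is a minimal sketch, assuming a unit-variance Gaussian NLL summed over
# masked time steps; the repo's actual helper may differ (e.g. a bivariate
# Gaussian with predicted covariance). The toy call that follows only checks
# shapes; the random values are meaningless.

def nll_loss_per_sample(pred: torch.Tensor, data: torch.Tensor,
                        mask: torch.Tensor) -> torch.Tensor:
    # pred/data: [nSteps, batch_sz, dim]; mask: [nSteps, batch_sz] validity flags.
    sq_err = 0.5 * ((pred - data) ** 2).sum(dim=2)  # [nSteps, batch_sz]
    return (sq_err * mask).sum(dim=0)               # per-sample NLL, [batch_sz]


if __name__ == '__main__':
    nSteps, batch_sz, dim, modes = 12, 4, 2, 3
    pred = [torch.randn(nSteps, batch_sz, dim) for _ in range(modes)]
    data = torch.randn(nSteps, batch_sz, dim)
    mask = torch.ones(nSteps, batch_sz)
    modes_pred = torch.softmax(torch.randn(batch_sz, modes), dim=1)
    print(nll_loss_multimodes(pred, data, mask, modes_pred, noise=0.0))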