in multiple_futures_prediction/train_ngsim.py [0:0]
def eval(metric: str, net: torch.nn.Module, params: AttrDict, data_loader: DataLoader, bStepByStep: bool,
         use_forcing: int, y_mean: torch.Tensor, num_batches: int, dataset_name: str) -> torch.Tensor:
    """Evaluate an MFP network on validation or test data.

    Runs ``net.forward_mfp`` over at most ``num_batches`` batches of
    ``data_loader`` and accumulates either the negative log-likelihood or the
    squared error, together with the corresponding counts, into vectors of
    length ``params.fut_len_orig_hz // params.subsampling``.

    Args:
        metric: ``'nll'`` for negative log-likelihood; any other value selects
            RMSE, which is only supported when ``params.modes == 1``.
        net: The MFP network; must provide ``forward_mfp``.
        params: Experiment config. Reads ``fut_len_orig_hz``, ``subsampling``,
            ``use_cuda``, ``modes`` and ``remove_y_mean``.
        data_loader: Yields 7-tuples
            ``(hist, nbrs, <ignored>, fut, mask, context, nbrs_info)``;
            ``context`` may be ``None``.
        bStepByStep: Forwarded unchanged to ``net.forward_mfp``.
        use_forcing: Forwarded unchanged to ``net.forward_mfp``.
        y_mean: Torch tensor added back onto the predicted (x, y) channels when
            ``params.remove_y_mean`` is set (single-mode), or handed to the
            multimode NLL loss.
        num_batches: Maximum number of batches to evaluate.
        dataset_name: Label printed when evaluation starts.

    Returns:
        The accumulated loss divided elementwise by the accumulated count
        (NLL), or the square root of that ratio multiplied by 0.3048 to convert
        feet to meters (RMSE). Entries whose count is zero come out as NaN/inf.
        The result is also printed.

    Raises:
        ValueError: if RMSE is requested for a model with ``params.modes != 1``.
    """
    print('eval ', dataset_name)
    if metric != 'nll' and params.modes != 1:
        # Explicit check instead of `assert`, which is stripped under `python -O`
        # and would leave the RMSE path without a loss to accumulate.
        raise ValueError('RMSE evaluation requires params.modes == 1, got %r' % (params.modes,))

    feet_to_meters = 0.3048
    num = params.fut_len_orig_hz // params.subsampling
    loss_sums = torch.zeros(num)
    count_sums = torch.zeros(num)

    # Pure evaluation: no autograd graph is needed for the returned metrics.
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            if i >= num_batches:
                break
            # NOTE(review): the third element is discarded and the fifth
            # (`mask`) is fed to both forward_mfp and the losses. Confirm that
            # forward_mfp is not meant to receive the third element instead.
            hist, nbrs, _ignored_mask, fut, mask, context, nbrs_info = data
            if params.use_cuda:
                hist = hist.cuda()
                nbrs = nbrs.cuda()
                mask = mask.cuda()
                fut = fut.cuda()
                if context is not None:
                    context = context.cuda()

            fut_preds, modes_pred = net.forward_mfp(hist, nbrs, mask, context, nbrs_info, fut,
                                                    bStepByStep, use_forcing=use_forcing)
            if params.modes == 1:
                if params.remove_y_mean:
                    # Add the mean back onto the first two output channels only.
                    fut_preds[0][:, :, :2] += y_mean.unsqueeze(1).to(fut.device)
                if metric == 'nll':
                    loss, count = nll_loss_test(fut_preds[0], fut, mask)
                else:
                    loss, count = mse_loss_test(fut_preds[0], fut, mask)
            else:
                # Multimode is only reachable with metric == 'nll' (validated above).
                loss, count = nll_loss_test_multimodes(fut_preds, fut, mask, modes_pred,
                                                       y_mean.to(fut.device))
            loss_sums += loss.detach().cpu()
            count_sums += count.detach().cpu()

    if metric == 'nll':
        err = loss_sums / count_sums
    else:
        # Calculate RMSE and convert from feet to meters.
        err = torch.pow(loss_sums / count_sums, 0.5) * feet_to_meters
    print(err)
    return err