in multiple_futures_prediction/train_ngsim.py [0:0]
def get_mean( train_data_loader: DataLoader, batches: Optional[int]=200 ) -> np.ndarray:
    """Compute the masked mean of the future targets over some training batches.

    Each item yielded by ``train_data_loader`` must unpack into seven fields;
    only the fourth (``fut``) and fifth (``fut_mask``) are used. Both must
    expose ``.cpu().numpy()``. Entries where the mask equals 0 are excluded
    from the mean. Axis 1 of each array is reduced (presumably the batch axis
    -- confirm against the dataset's collate function).

    Args:
        train_data_loader: Iterable of training batches.
        batches: Largest batch index to consume. The cut-off is inclusive, so
            indices ``0..batches`` (``batches + 1`` batches) are used; this
            matches the historical behaviour. ``None`` consumes every batch.

    Returns:
        Sum of the valid targets divided by the number of valid entries, with
        axis 1 reduced away. Positions that never had a valid entry are
        divided by a zero count (NaN for floating-point data).

    Raises:
        ValueError: If the loader yields no batch at all.
    """
    total = None
    counters = None
    for i, data in enumerate(train_data_loader):
        if batches is not None and i > batches:
            break
        _, _, _, fut, fut_mask, _, _ = data
        mask = fut_mask.cpu().numpy()
        # Copy first: ``.numpy()`` shares memory with a CPU tensor, so masking
        # in place would overwrite the batch owned by the data loader.
        target = np.array(fut.cpu().numpy(), copy=True)
        # Index with the CPU NumPy mask (not the raw tensor) so that this also
        # works when the mask tensor does not live on the CPU.
        target[mask == 0] = 0
        batch_sum = target.sum(axis=1)
        batch_valid = mask.sum(axis=1)
        if total is None:
            total, counters = batch_sum, batch_valid
        else:
            # Running sums avoid keeping every batch in memory.
            total = total + batch_sum
            counters = counters + batch_valid
    if total is None:
        raise ValueError("get_mean() needs at least one batch from train_data_loader")
    return np.divide(total, counters)