def get_mean()

in multiple_futures_prediction/train_ngsim.py [0:0]


def get_mean( train_data_loader: DataLoader, batches: Optional[int]=200 ) -> np.ndarray:
  """Compute the masked mean of the future targets over the first batches of a loader.

  Each item yielded by ``train_data_loader`` is unpacked as a 7-tuple; only the
  4th (``fut``) and 5th (``fut_mask``) entries are used, and both must expose
  ``.cpu().numpy()`` (e.g. torch tensors). Target entries whose mask value is 0
  are excluded from the sum. The mask itself is summed to obtain the
  denominator, so this assumes mask entries are 0/1 (TODO confirm).

  Args:
    train_data_loader: Iterable of 7-tuples, e.g. a torch ``DataLoader``.
    batches: Maximum number of batches to read. ``None`` reads the whole loader.

  Returns:
    ``sum(masked fut, axis=1) / sum(fut_mask, axis=1)``, accumulated over every
    batch read. Axis 1 is presumably the batch/agent axis -- TODO confirm
    against the dataset's collate function. Positions with no valid samples
    divide by zero and yield NaN/inf (NumPy emits a RuntimeWarning).

  Raises:
    ValueError: if no batch was read (empty loader or ``batches <= 0``).
  """
  sums = None      # running sum over axis 1 of the masked targets
  counters = None  # running sum over axis 1 of the mask (valid-sample counts)
  for i, data in enumerate(train_data_loader):
    # Read exactly `batches` batches (indices 0 .. batches-1); None means no limit.
    if batches is not None and i >= batches:
      break
    _, _, _, fut, fut_mask, _, _ = data
    # Convert the mask to NumPy once; it drives both the count and the zeroing,
    # and indexing a NumPy array with a (possibly CUDA) torch tensor is invalid.
    mask = fut_mask.cpu().numpy()
    # Copy: for CPU tensors .numpy() shares memory with the tensor, so zeroing
    # in place would silently corrupt the caller's / loader's data.
    target = fut.cpu().numpy().copy()
    target[mask == 0] = 0

    batch_sum = target.sum(axis=1)
    valid = mask.sum(axis=1)
    if sums is None:
      # Both arrays were freshly produced by .sum(), so accumulating into them
      # in place below cannot alias caller data.
      sums, counters = batch_sum, valid
    else:
      sums += batch_sum
      counters += valid

  if sums is None:
    raise ValueError("get_mean: no batches were read from train_data_loader")
  return np.divide(sums, counters)