in ax/modelbridge/base.py [0:0]
def unwrap_observation_data(observation_data: List[ObservationData]) -> TModelPredict:
    """Converts observation data to the format for model prediction outputs.

    That format assumes each observation data has the same set of metrics.

    Args:
        observation_data: Observations to unwrap. Every element must report
            the same set of metric names as the first element (their order
            may differ between elements), with no name repeated within an
            element.

    Returns:
        A tuple ``(f, cov)``. ``f[m]`` is the list of means for metric ``m``,
        with one entry per element of ``observation_data`` in input order.
        ``cov[m1][m2]`` is the list of covariance entries between ``m1`` and
        ``m2``, aligned the same way. An empty input yields ``({}, {})``.

    Raises:
        ValueError: If an element's metric names differ from those of the
            first element, or if an element repeats a metric name.
    """
    if not observation_data:
        # No observations means no metrics: return empty containers instead
        # of failing on ``observation_data[0]`` below.
        return {}, {}
    metrics = set(observation_data[0].metric_names)
    f: TModelMean = {metric: [] for metric in metrics}
    cov: TModelCov = {m1: {m2: [] for m2 in metrics} for m1 in metrics}
    for od in observation_data:
        od_metrics = set(od.metric_names)
        if od_metrics != metrics:
            raise ValueError(
                "Each ObservationData should use same set of metrics. "
                "Expected {exp}, got {got}.".format(exp=metrics, got=od_metrics)
            )
        if len(od.metric_names) != len(od_metrics):
            # A repeated name passes the set comparison above but would append
            # several values to one metric for a single observation, leaving
            # the per-metric output lists misaligned.
            raise ValueError(
                "ObservationData metric names must be unique, got {got}.".format(
                    got=od.metric_names
                )
            )
        # Metrics are looked up by name, so per-element ordering of
        # ``metric_names`` does not matter; ``means`` and ``covariance`` are
        # indexed positionally to match that element's ``metric_names``.
        for i, m1 in enumerate(od.metric_names):
            f[m1].append(od.means[i])
            for j, m2 in enumerate(od.metric_names):
                cov[m1][m2].append(od.covariance[i, j])
    return f, cov