in ax/plot/diagnostic.py [0:0]
def interact_empirical_model_validation(batch: BatchTrial, data: Data) -> AxPlotConfig:
    """Plot the predictions stored on a batch against its observed data.

    Walks every generator run attached to ``batch``; runs whose
    ``model_predictions`` is ``None`` are skipped. For each arm of the
    remaining runs, the rows of ``data.df`` whose ``arm_name`` matches the arm
    are paired, metric by metric, with the prediction stored at the arm's
    position within its generator run.

    Args:
        batch: Batch whose generator runs carry the model predictions.
        data: Observed data for the batch. Its dataframe is read through the
            ``arm_name``, ``metric_name``, ``mean`` and ``sem`` columns.

    Returns:
        AxPlotConfig wrapping the observed-vs-predicted dropdown figure.

    Raises:
        ValueError: If no arm was collected, i.e. no generator run on the
            batch both carries model predictions and has arms.
    """
    observations = data.df
    metric_names = list(observations["metric_name"].unique())
    insample_data: Dict[str, PlotInSampleArm] = {}

    for struct in batch.generator_run_structs:
        generator_run = struct.generator_run
        predictions = generator_run.model_predictions
        if predictions is None:
            # Nothing to compare against for this generator run.
            continue

        # `position` is the arm's index within its generator run; the stored
        # predictions are looked up by that same index.
        for position, arm in enumerate(generator_run.arms):
            arm_key = arm.name_or_short_signature
            arm_parameters = arm.parameters

            observed_mean: Dict[str, float] = {}
            observed_sem: Dict[str, float] = {}
            predicted_mean: Dict[str, float] = {}
            predicted_se: Dict[str, float] = {}

            arm_rows = observations[observations["arm_name"] == arm_key]
            for _, row in arm_rows.iterrows():
                metric = row["metric_name"]
                observed_mean[metric] = row["mean"]
                observed_sem[metric] = row["sem"]
                predicted_mean[metric] = predictions[0][metric][position]
                # NOTE(review): the double lookup by metric name suggests
                # predictions[1] is a metric-by-metric covariance structure,
                # in which case this value is a variance rather than a
                # standard error -- confirm whether a square root is needed.
                predicted_se[metric] = predictions[1][metric][metric][position]

            # An arm key seen in several generator runs keeps only the entry
            # from the last run processed.
            insample_data[arm_key] = PlotInSampleArm(
                name=arm_key,
                y=observed_mean,
                se=observed_sem,
                parameters=arm_parameters,
                y_hat=predicted_mean,
                se_hat=predicted_se,
                context_stratum=None,
            )

    if not insample_data:
        raise ValueError("No model predictions present on the batch.")

    fig = _obs_vs_pred_dropdown_plot(
        data=PlotData(
            metrics=metric_names,
            in_sample=insample_data,
            out_of_sample=None,
            status_quo_name=None,
        ),
        rel=False,
    )
    fig["layout"]["title"] = "Cross-validation"
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)