def interact_empirical_model_validation()

in ax/plot/diagnostic.py [0:0]


def interact_empirical_model_validation(batch: BatchTrial, data: Data) -> AxPlotConfig:
    """Compare the model predictions for the batch arms against observed data.

    Relies on the model predictions stored on the generator_runs of batch.
    Generator runs that carry no stored predictions are skipped.

    Args:
        batch: Batch on which to perform analysis.
        data: Observed data for the batch.

    Returns:
        AxPlotConfig for the plot.

    Raises:
        ValueError: If none of the batch's generator runs has model
            predictions, so there is nothing to plot.
    """
    # Function-scope import so the block does not depend on the file header.
    import math

    observed_df = data.df
    metric_names = list(observed_df["metric_name"].unique())
    insample_data: Dict[str, PlotInSampleArm] = {}

    for struct in batch.generator_run_structs:
        generator_run = struct.generator_run
        predictions = generator_run.model_predictions
        if predictions is None:
            continue
        # predictions[0] is indexed [metric][arm position]; predictions[1] is
        # indexed [metric][metric][arm position], i.e. a per-arm covariance
        # structure between metrics.
        predicted_means = predictions[0]
        predicted_covariances = predictions[1]

        for i, arm in enumerate(generator_run.arms):
            arm_name = arm.name_or_short_signature
            y: Dict[str, float] = {}
            se: Dict[str, float] = {}
            y_hat: Dict[str, float] = {}
            se_hat: Dict[str, float] = {}

            arm_rows = observed_df[observed_df["arm_name"] == arm_name]
            for _, row in arm_rows.iterrows():
                metric_name = row["metric_name"]
                y[metric_name] = row["mean"]
                se[metric_name] = row["sem"]
                y_hat[metric_name] = predicted_means[metric_name][i]
                # The diagonal covariance entry is the predictive *variance*;
                # take the square root so that se_hat is a standard error on
                # the same scale as the observed "sem".
                se_hat[metric_name] = math.sqrt(
                    predicted_covariances[metric_name][metric_name][i]
                )

            # An arm is recorded even when `data` holds no rows for it; its
            # per-metric dicts are then simply empty.
            insample_data[arm_name] = PlotInSampleArm(
                name=arm_name,
                y=y,
                se=se,
                parameters=arm.parameters,
                y_hat=y_hat,
                se_hat=se_hat,
                context_stratum=None,
            )

    if not insample_data:
        raise ValueError("No model predictions present on the batch.")

    plot_data = PlotData(
        metrics=metric_names,
        in_sample=insample_data,
        out_of_sample=None,
        status_quo_name=None,
    )

    fig = _obs_vs_pred_dropdown_plot(data=plot_data, rel=False)
    fig["layout"]["title"] = "Cross-validation"
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)