def compute_diagnostics()

in ax/modelbridge/cross_validation.py


from collections import defaultdict
from typing import Dict, List

import numpy as np

# Note: `CVResult`, `CVDiagnostics`, the diagnostic-name constants
# (MEAN_PREDICTION_CI, MAPE, ...), and the per-metric helpers
# (_mean_prediction_ci, _mape, ...) are defined elsewhere in this module.


def compute_diagnostics(result: List[CVResult]) -> CVDiagnostics:
    """Computes diagnostics for the given cross-validation results.

    It provides a dictionary with values for each of the following
    diagnostics, for each metric:

    - 'Mean prediction CI': the average width of the confidence intervals
      of the CV predictions, relative to the observed mean.
    - 'MAPE': mean absolute percentage error of the estimated mean relative
      to the observed mean.
    - 'Total raw effect': the percent change from the smallest observed
      mean to the largest observed mean.
    - 'Correlation coefficient': the Pearson correlation of the estimated
      and observed means.
    - 'Rank correlation': the Spearman correlation of the estimated
      and observed means.
    - 'Fisher exact test p': tests whether the model is able to distinguish
      the bottom half of the observations from the top half, using Fisher's
      exact test and the observed/estimated means. A low p-value indicates
      that the model has some ability to identify good arms; a high p-value
      indicates that the model cannot identify arms better than chance, or
      that the observations are too noisy to be able to tell.
    - 'Log likelihood': the log likelihood of the observed means under the
      predictive distribution given by the estimated means and SEs.

    Each of these is returned as a dictionary from metric name to value for
    that metric.

    Args:
        result: Output of cross_validate.

    Returns:
        A dictionary keyed by diagnostic name with results as described above.
    """
    # Extract per-metric outcomes from CVResults.
    y_obs = defaultdict(list)
    y_pred = defaultdict(list)
    se_pred = defaultdict(list)
    for res in result:
        for j, metric_name in enumerate(res.observed.data.metric_names):
            y_obs[metric_name].append(res.observed.data.means[j])
            # Find the matching prediction; the predictive SE is the square
            # root of the corresponding diagonal element of the covariance.
            k = res.predicted.metric_names.index(metric_name)
            y_pred[metric_name].append(res.predicted.means[k])
            se_pred[metric_name].append(np.sqrt(res.predicted.covariance[k, k]))

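    # Map each diagnostic name (a module-level string constant) to the
    # function that computes it from y_obs, y_pred, and se_pred.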
    diagnostic_fns = {
        MEAN_PREDICTION_CI: _mean_prediction_ci,
        MAPE: _mape,
        TOTAL_RAW_EFFECT: _total_raw_effect,
        CORRELATION_COEFFICIENT: _correlation_coefficient,
        RANK_CORRELATION: _rank_correlation,
        FISHER_EXACT_TEST_P: _fisher_exact_test_p,
        LOG_LIKELIHOOD: _log_likelihood,
    }

    diagnostics: Dict[str, Dict[str, float]] = defaultdict(dict)
    # Get all per-metric diagnostics.
    for metric_name in y_obs:
        for name, fn in diagnostic_fns.items():
            diagnostics[name][metric_name] = fn(
                y_obs=np.array(y_obs[metric_name]),
                y_pred=np.array(y_pred[metric_name]),
                se_pred=np.array(se_pred[metric_name]),
            )
    return diagnostics
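
The per-metric helper functions (_mean_prediction_ci, _mape, and so on) are
defined elsewhere in this module and are not shown on this page. The sketches
below illustrate the kind of computation each performs; they are plausible
reconstructions, not the module's exact code. All of the helpers share the
keyword signature used in the loop above.

import numpy as np
from scipy.stats import fisher_exact, norm, pearsonr


def _mape(y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray) -> float:
    # Mean absolute percentage error of the estimated means relative to the
    # observed means.
    return float(np.mean(np.abs((y_pred - y_obs) / y_obs)))


def _correlation_coefficient(
    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
    # Pearson correlation of the estimated and observed means.
    r, _ = pearsonr(y_pred, y_obs)
    return float(r)


def _log_likelihood(
    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
    # Log likelihood of the observations under a Gaussian predictive
    # distribution centered at the predicted means (an assumption here).
    return float(np.sum(norm.logpdf(y_obs, loc=y_pred, scale=se_pred)))


def _fisher_exact_test_p(
    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
    # Split arms into top/bottom halves by observed and by predicted means,
    # then test the agreement of the two splits with Fisher's exact test.
    n_half = len(y_obs) // 2
    top_obs = set(np.argsort(y_obs)[-n_half:])
    top_est = set(np.argsort(y_pred)[-n_half:])
    tp = len(top_est & top_obs)    # in the top half of both
    fp = n_half - tp               # predicted top, observed bottom
    fn = n_half - tp               # observed top, predicted bottom
    tn = len(y_obs) - n_half - fn  # in the bottom half of both
    _, p_value = fisher_exact([[tp, fp], [fn, tn]])
    return float(p_value)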
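
For context, a minimal usage sketch. It assumes `model` is a model bridge that
has already been fit to experiment data; `cross_validate` is the companion
function defined in this same file.

from ax.modelbridge.cross_validation import compute_diagnostics, cross_validate

cv_results = cross_validate(model)  # leave-one-out CV by default
diagnostics = compute_diagnostics(result=cv_results)
for diag_name, per_metric in diagnostics.items():
    for metric_name, value in per_metric.items():
        print(f"{diag_name} / {metric_name}: {value:.3f}")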