in ax/modelbridge/cross_validation.py [0:0]
def compute_diagnostics(result: List[CVResult]) -> CVDiagnostics:
"""Computes diagnostics for given cross validation results.
It provides a dictionary with values for the following diagnostics, for
each metric:
- 'Mean prediction CI': the average width of the CIs at each of the CV
predictions, relative to the observed mean.
- 'MAPE': mean absolute percentage error of the estimated mean relative
to the observed mean.
- 'Total raw effect': the percent change from the smallest observed
mean to the largest observed mean.
- 'Correlation coefficient': the Pearson correlation of the estimated
and observed means.
- 'Rank correlation': the Spearman correlation of the estimated
and observed means.
- 'Fisher exact test p': we test if the model is able to distinguish the
bottom half of the observations from the top half, using Fisher's
exact test and the observed/estimated means. A low p value indicates
that the model has some ability to identify good arms. A high p value
indicates that the model cannot identify arms better than chance, or
that the observations are too noisy to be able to tell.
Each of these is returned as a dictionary from metric name to value for
that metric.
Args:
result: Output of cross_validate
Returns:
A dictionary keyed by diagnostic name with results as described above.
"""
    # Extract per-metric outcomes from CVResults.
    y_obs = defaultdict(list)
    y_pred = defaultdict(list)
    se_pred = defaultdict(list)
    for res in result:
        for j, metric_name in enumerate(res.observed.data.metric_names):
            y_obs[metric_name].append(res.observed.data.means[j])
            # Find the matching prediction.
            k = res.predicted.metric_names.index(metric_name)
            y_pred[metric_name].append(res.predicted.means[k])
            se_pred[metric_name].append(np.sqrt(res.predicted.covariance[k, k]))
    diagnostic_fns = {
        MEAN_PREDICTION_CI: _mean_prediction_ci,
        MAPE: _mape,
        TOTAL_RAW_EFFECT: _total_raw_effect,
        CORRELATION_COEFFICIENT: _correlation_coefficient,
        RANK_CORRELATION: _rank_correlation,
        FISHER_EXACT_TEST_P: _fisher_exact_test_p,
        LOG_LIKELIHOOD: _log_likelihood,
    }
    diagnostics: Dict[str, Dict[str, float]] = defaultdict(dict)
    # Get all per-metric diagnostics.
    for metric_name in y_obs:
        for name, fn in diagnostic_fns.items():
            diagnostics[name][metric_name] = fn(
                y_obs=np.array(y_obs[metric_name]),
                y_pred=np.array(y_pred[metric_name]),
                se_pred=np.array(se_pred[metric_name]),
            )
    return diagnostics
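
# --- Illustrative usage (a hedged sketch, not part of the module) ---
# Assuming a fitted ModelBridge named `model` and a metric named "objective"
# (both hypothetical names), the diagnostics are computed from the output of
# cross_validate; the outer key of the result is the diagnostic name (the
# module-level constants above) and the inner key is the metric name:
#
#     cv_results = cross_validate(model=model)
#     diagnostics = compute_diagnostics(result=cv_results)
#     mape_for_objective = diagnostics[MAPE]["objective"]
#
# Each helper in `diagnostic_fns` is called with the same keyword arguments
# (y_obs, y_pred, se_pred) for one metric. A hypothetical helper with that
# signature might look like the following (the real _mape is defined
# elsewhere in this module and may differ):
#
#     def _mape(y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray) -> float:
#         # Mean absolute percentage error of predicted vs. observed means;
#         # se_pred is accepted only to match the shared interface.
#         return float(np.mean(np.abs((y_pred - y_obs) / y_obs)))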