in evals/elsuite/identifying_variables/eval.py [0:0]
def _compute_grouped_metrics(self, metrics: List[Dict]) -> Dict[str, float]:
"""
Computes metrics aggregated across samples grouped by
- number of variables
- number of roots in random forest
- number of control variables
- number of hypotheses
- max correlation depth
"""
metric_to_agg_func = {
"hyp_valid_acc": np.mean,
"violation_count": np.sum,
"violation_rate": np.mean,
"ctrl_nDCG": np.nanmean,
"ctrl_recall": np.nanmean,
"ctrl_fallout": np.nanmean,
"ind_acc": np.nanmean,
"dep_acc": np.nanmean,
}
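# Per-sample raw metric names, positionally aligned with the keys of
# metric_to_agg_func above. "violation" appears twice on purpose: both
# violation_count (sum) and violation_rate (mean) aggregate the same
# per-sample violation value.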
raw_metric_names = [
"hyp_valid_correct",
"violation",
"violation",
"ctrl_nDCG",
"ctrl_recall",
"ctrl_fallout",
"ind_correct",
"dep_correct",
]
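# Possible bin values for each grouping dimension, derived from the eval's
# configured variable and hypothesis ranges.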
group_to_bins = {
"n_vars": np.arange(constants.MIN_VARS, constants.MAX_VARS + 1),
"n_roots": np.arange(1, constants.MAX_VARS + 1),
"n_ctrl_vars": np.arange(0, (constants.MAX_VARS - 2) + 1),
"n_hyps": np.arange(constants.MIN_HYPS, constants.MAX_HYPS + 1),
"max_corr_depth": np.arange(1, constants.MAX_VARS),
}
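# One accumulator list per (metric, group, bin) combination, keyed as
# "{metric}-{group}-{bin}".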
grouped_metrics = {
f"{metric}-{group}-{g_bin}": []
for metric in metric_to_agg_func.keys()
for group in group_to_bins.keys()
for g_bin in group_to_bins[group]
}
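# Sort each sample's raw metric values into the bin it falls under for
# every grouping dimension.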
for log_entry in metrics:
causal_graph = nx.from_dict_of_lists(log_entry["causal_graph"], create_using=nx.DiGraph)
ctrl_vars = log_entry["gold_answer"]["ctrl_vars"]
dep_var = log_entry["gold_answer"]["dep_var"]
group_to_bin = {
"n_vars": causal_graph.number_of_nodes(),
"n_roots": len(graph_utils.find_graph_roots(causal_graph)),
"n_ctrl_vars": len(ctrl_vars) if ctrl_vars is not None else None,
"n_hyps": log_entry["n_hyps"],
"max_corr_depth": graph_utils.find_farthest_node(causal_graph, dep_var)[1]
if dep_var is not None
else None,
}
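# A bin of None means the grouping is undefined for this sample
# (e.g. the gold answer has no ctrl_vars or dep_var), so skip it.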
for group, g_bin in group_to_bin.items():
if g_bin is not None:
for metric, raw_metric in zip(metric_to_agg_func.keys(), raw_metric_names):
grouped_metrics[f"{metric}-{group}-{g_bin}"].append(log_entry[raw_metric])
# Aggregate each bin's list of raw values into a single scalar,
# signalling empty bins with np.nan
grouped_metrics = {
k: metric_to_agg_func[k.split("-")[0]](v) if len(v) > 0 else np.nan
for k, v in grouped_metrics.items()
}
return grouped_metrics
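# Illustrative shape of the returned dict (hypothetical values, not from a
# real run): keys follow the "{metric}-{group}-{bin}" pattern and empty bins
# are reported as np.nan, e.g.
#   {
#       "hyp_valid_acc-n_vars-2": 0.75,   # mean accuracy over samples with 2 variables
#       "violation_count-n_hyps-3": 4,    # total violations over samples with 3 hypotheses
#       "ctrl_nDCG-n_roots-5": np.nan,    # no samples fell into this bin
#       ...
#   }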