def _compute_grouped_metrics()

in evals/elsuite/identifying_variables/eval.py


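The method relies on names imported at module level in eval.py; a minimal sketch of those
imports, with paths assumed from the call sites rather than verified:

    import networkx as nx
    import numpy as np

    from evals.elsuite.identifying_variables import constants, graph_utils
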
    def _compute_grouped_metrics(self, metrics: List[Dict]) -> Dict[str, float]:
        """
        Computes metrics aggregated across samples grouped by
          - number of variables
          - number of roots in random forest
          - number of control variables
          - number of hypotheses
          - max correlation depth
        """
        metric_to_agg_func = {
            "hyp_valid_acc": np.mean,
            "violation_count": np.sum,
            "violation_rate": np.mean,
            "ctrl_nDCG": np.nanmean,
            "ctrl_recall": np.nanmean,
            "ctrl_fallout": np.nanmean,
            "ind_acc": np.nanmean,
            "dep_acc": np.nanmean,
        }
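        # Per-sample log fields backing each aggregate metric above, in the same order
        # as metric_to_agg_func; "violation" appears twice because both violation_count
        # and violation_rate are aggregated from the same per-sample field.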
        raw_metric_names = [
            "hyp_valid_correct",
            "violation",
            "violation",
            "ctrl_nDCG",
            "ctrl_recall",
            "ctrl_fallout",
            "ind_correct",
            "dep_correct",
        ]
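        # Possible bin values for each grouping dimension, bounded by the eval's constants.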
        group_to_bins = {
            "n_vars": np.arange(constants.MIN_VARS, constants.MAX_VARS + 1),
            "n_roots": np.arange(1, constants.MAX_VARS + 1),
            "n_ctrl_vars": np.arange(0, (constants.MAX_VARS - 2) + 1),
            "n_hyps": np.arange(constants.MIN_HYPS, constants.MAX_HYPS + 1),
            "max_corr_depth": np.arange(1, constants.MAX_VARS),
        }
        grouped_metrics = {
            f"{metric}-{group}-{g_bin}": []
            for metric in metric_to_agg_func.keys()
            for group in group_to_bins.keys()
            for g_bin in group_to_bins[group]
        }
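        # Assign each sample's raw metrics to its bin along every grouping dimension;
        # dimensions whose bin is None (e.g. no gold control variables or dependent
        # variable) are skipped for that sample.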
        for log_entry in metrics:
            causal_graph = nx.from_dict_of_lists(log_entry["causal_graph"], create_using=nx.DiGraph)
            ctrl_vars = log_entry["gold_answer"]["ctrl_vars"]
            dep_var = log_entry["gold_answer"]["dep_var"]
            group_to_bin = {
                "n_vars": causal_graph.number_of_nodes(),
                "n_roots": len(graph_utils.find_graph_roots(causal_graph)),
                "n_ctrl_vars": len(ctrl_vars) if ctrl_vars is not None else None,
                "n_hyps": log_entry["n_hyps"],
                "max_corr_depth": graph_utils.find_farthest_node(causal_graph, dep_var)[1]
                if dep_var is not None
                else None,
            }
            for group, g_bin in group_to_bin.items():
                if g_bin is not None:
                    for metric, raw_metric in zip(metric_to_agg_func.keys(), raw_metric_names):
                        grouped_metrics[f"{metric}-{group}-{g_bin}"].append(log_entry[raw_metric])

        # aggregate
        grouped_metrics = {
            # signal empty groups with np.nan
            k: metric_to_agg_func[k.split("-")[0]](v) if len(v) > 0 else np.nan
            for k, v in grouped_metrics.items()
        }
        return grouped_metrics
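
For reference, a toy, self-contained sketch of the key-naming and aggregation pattern this
method produces. Everything here (sample dicts, values, bin range) is made up for
illustration and is not part of the eval:

    import numpy as np

    # Hypothetical per-sample records: one raw metric and one grouping dimension.
    samples = [
        {"hyp_valid_correct": 1.0, "n_vars": 3},
        {"hyp_valid_correct": 0.0, "n_vars": 3},
        {"hyp_valid_correct": 1.0, "n_vars": 4},
    ]

    # One bucket per (metric, group, bin), mirroring the f"{metric}-{group}-{g_bin}" keys.
    buckets = {f"hyp_valid_acc-n_vars-{b}": [] for b in range(2, 6)}
    for sample in samples:
        buckets[f"hyp_valid_acc-n_vars-{sample['n_vars']}"].append(sample["hyp_valid_correct"])

    # Aggregate each bucket; empty bins are reported as np.nan, as in the method above.
    report = {key: np.mean(vals) if vals else np.nan for key, vals in buckets.items()}
    # report == {"hyp_valid_acc-n_vars-2": nan, "hyp_valid_acc-n_vars-3": 0.5,
    #            "hyp_valid_acc-n_vars-4": 1.0, "hyp_valid_acc-n_vars-5": nan}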