def get_cohorts_data()

in src/responsibleai/rai_analyse/_score_card/_rai_insight_data.py [0:0]


    def get_cohorts_data(self):
        """Collect per-cohort metric values for every configured metric.

        For each metric name in ``self.config["Metrics"]`` two kinds of
        entries are gathered:

        * ``"cohorts"``: one entry per name in ``self.config["Cohorts"]``
          (when that key is present) that also has a definition in
          ``self.config["cohorts_definition"]``. The metric is computed with
          ``get_metric`` on the dataset filtered by that definition.
        * ``"error_analysis_min"`` / ``"error_analysis_max"``: entries built
          from the error analysis tree nodes returned by
          ``self.data.get_min_max_nodes``. These are only gathered when
          ``"error_analysis"`` is in ``self.data.component_path_prefix``.

        Every entry is a dict with the keys ``"label"``, ``"short_label"``,
        the metric name itself (mapping to the metric value) and
        ``"population"`` (cohort size divided by ``len`` of
        ``self.data.get_y_test()``). A ``"threshold"`` key is added when the
        metric's configuration contains one.

        Error analysis is best effort: any exception raised while collecting
        it is printed and that metric's error analysis is skipped. Exceptions
        raised while processing the configured cohorts propagate.

        Returns:
            dict: Mapping with the keys ``"error_analysis_min"``,
            ``"error_analysis_max"`` and ``"cohorts"``, each a list of entry
            dicts as described above.
        """
        cohorts_data = {
            "error_analysis_min": [],
            "error_analysis_max": [],
            "cohorts": [],
        }

        def attach_threshold(entry, metric_name):
            # Looked up lazily (only when an entry is actually built) so the
            # metric configuration is touched at the same points as before.
            # NOTE(review): index 1 of the "threshold" value is used as-is;
            # the layout of that value is not visible here -- confirm.
            metric_config = self.config["Metrics"][metric_name]
            if "threshold" in metric_config:
                entry["threshold"] = metric_config["threshold"][1]

        def error_analysis_entries(treemap, nodes, metric_name):
            # Short labels restart for every list of nodes.
            entries = []
            ea_short_label_generator = AlphabetLabelIterator()
            for node in nodes:
                node_id = node["id"]
                filter_conditions = self.data.get_filter_conditions(
                    treemap, node_id
                )
                # A node without any filter condition is labelled "All Data".
                if len(filter_conditions) > 0:
                    label = " <br>AND<br>".join(filter_conditions)
                else:
                    label = "All Data"
                entry = {
                    "label": label,
                    "short_label": next(ea_short_label_generator),
                    metric_name: treemap[node_id]["metricValue"],
                    "population": treemap[node_id]["size"] /
                    len(self.data.get_y_test()),
                }
                attach_threshold(entry, metric_name)
                entries.append(entry)
            return entries

        # One generator for the whole call: short labels of configured
        # cohorts keep advancing across metrics instead of restarting.
        cohort_short_label_generator = AlphabetLabelIterator()
        for metric in self.config["Metrics"]:
            if "Cohorts" in self.config:
                for cohort_name in self.config["Cohorts"]:
                    # Cohorts without a definition are silently skipped.
                    if cohort_name not in self.config["cohorts_definition"]:
                        continue

                    code = self.config["cohorts_definition"][cohort_name]
                    filtered_dataset = self.data.get_filtered_dataset(code)

                    entry = {
                        "label": cohort_name,
                        "short_label": next(cohort_short_label_generator),
                        metric: get_metric(
                            metric,
                            filtered_dataset["y_test"],
                            filtered_dataset["y_pred"],
                            **self.get_metric_kwargs(),
                        ),
                        "population": len(filtered_dataset["y_pred"]) /
                        len(self.data.get_y_test()),
                    }
                    attach_threshold(entry, metric)
                    cohorts_data["cohorts"].append(entry)

            try:
                if "error_analysis" in self.data.component_path_prefix:
                    # Only the first element of the returned data is used.
                    ea_data = self.data.get_error_analysis_data(metric)[0]

                    treemap = self.data.to_tree_map(ea_data.tree)
                    # NOTE(review): 3 presumably limits how many nodes are
                    # returned on each side -- confirm against
                    # get_min_max_nodes.
                    min_nodes, max_nodes = self.data.get_min_max_nodes(
                        treemap, 3
                    )

                    cohorts_data["error_analysis_min"].extend(
                        error_analysis_entries(treemap, min_nodes, metric)
                    )
                    cohorts_data["error_analysis_max"].extend(
                        error_analysis_entries(treemap, max_nodes, metric)
                    )
            except Exception as ex:
                # Deliberate best effort: report the problem and carry on
                # with the next metric. The exception is formatted into the
                # message; previously the "{}" placeholder was printed
                # verbatim.
                print(
                    "Error in getting error analysis metrics: {}, "
                    "skipping".format(ex)
                )

        return cohorts_data