def get_cohorts_page()

in src/responsibleai/rai_analyse/_score_card/common_components.py [0:0]


# Section key in ``data`` -> headings handed to the performance container.
# Insertion order is the order in which the sections are rendered on the page.
_COHORT_SECTION_MESSAGES = {
    "cohorts": {
        "left_heading": "My Cohorts",
        "main_heading": "My prebuilt dataset cohorts",
    },
    "error_analysis_max": {
        "left_heading": "Highest ranked cohorts",
        "main_heading": "Highest ranked cohorts",
    },
    "error_analysis_min": {
        "left_heading": "Lowest ranked cohorts",
        "main_heading": "Lowest ranked cohorts",
    },
}

# Metrics whose bars are drawn against a fixed upper bound of 1 rather than
# against the largest observed value.
_UNIT_BOUNDED_METRICS = frozenset(
    ["accuracy_score", "recall_score", "precision_score", "f1_score"]
)


def _cohort_population_label(population):
    """Format a cohort's population share as an integer percentage label.

    ``population`` is multiplied by 100, so it is presumably a fraction in
    [0, 1] -- confirm against the score card data builder.
    """
    # Round instead of truncating: float products such as 0.29 * 100 evaluate
    # to 28.999999999999996, which int() alone would display as "28% n".
    # The outer int() keeps the label integral for numpy float scalars too.
    return str(int(round(population * 100))) + "% n"


def _get_cohort_metric_bar_plot(d, m):
    """Build the bar plot of metric ``m`` for the cohort records in ``d``.

    Each record must provide ``short_label``, ``population`` and ``m``. The
    optional ``threshold`` of the first record (if any) is forwarded to the
    plot. Each bar is split into the metric value and the remainder up to the
    axis maximum: 1 for the bounded score metrics, otherwise the largest
    value of ``m`` among the records.
    """
    first_data_point = next(iter(d), None)
    threshold = None
    if first_data_point:
        threshold = first_data_point.get("threshold", None)

    y_data = [
        "<br>".join([y["short_label"], _cohort_population_label(y["population"])])
        for y in d
    ]
    x_values = [x[m] for x in d]
    max_x = 1 if m in _UNIT_BOUNDED_METRICS else max(x_values)
    x_data = [[x, max_x - x] for x in x_values]

    # NOTE(review): order is reversed before plotting, presumably so the first
    # record appears at the top of a horizontal bar chart -- confirm against
    # get_bar_plot.
    return get_bar_plot(
        list(reversed(y_data)),
        list(reversed(x_data)),
        legend=[m],
        threshold=threshold,
    )


def get_cohorts_page(data, metric_config):
    """Render the "Cohorts" page of the score card.

    Args:
        data: Mapping that may contain the section keys "cohorts",
            "error_analysis_max" and "error_analysis_min"; each maps to an
            iterable of per-cohort record dicts. Missing sections are skipped.
        metric_config: Iterable of metric names. A cohort record takes part in
            a metric's plot only if it has a key for that metric.

    Returns:
        A single string: the page heading followed by one cohort performance
        container per (section, metric) pair that has at least one matching
        record, in section order and then ``metric_config`` order.
    """
    left_elems = div(
        p("Observe evidence of model performance across your passed cohorts:"),
        _class="left",
    )
    cp_heading_main = div(_class="main")
    heading_section = str(
        div(
            get_page_divider("Cohorts"), left_elems, cp_heading_main, _class="container"
        )
    )

    containers = []
    for key, messages in _COHORT_SECTION_MESSAGES.items():
        if key not in data:
            continue
        for metric in metric_config:
            filtered_data = [d for d in data[key] if metric in d]
            if not filtered_data:
                continue
            # Pass a copy so the container builder always gets a fresh dict,
            # as it did when the lookup table was rebuilt on every call.
            containers.append(
                get_cohorts_performance_container(
                    filtered_data, metric, _get_cohort_metric_bar_plot, dict(messages)
                )
            )

    return heading_section + "".join(containers)