def compute_metrics()

in assets/model_monitoring/components/src/generation_safety_quality/annotation_compute_metrics/run.py


def compute_metrics(histogram_df, threshold_args, metric_names):
    """Compute metrics for given histogram and return them."""
    spark = init_spark()
    # Cast to float because metric_value has been an integer so far,
    # but we are adding percentage values now.
    histogram_df = histogram_df.withColumn(
        METRIC_VALUE_COLUMN, histogram_df[METRIC_VALUE_COLUMN].cast("float")
    )
    # Reduce each requested metric name to its base metric name
    # (Groundedness/Fluency/Coherence/Relevance/Similarity) and de-duplicate.
    input_metric_names = [m.strip() for m in metric_names.split(",")]
    pruned_metric_names = [
        re.sub(r'^(.*?)(Groundedness|Fluency|Coherence|Relevance|Similarity)(.*?)$', r'\2', m)
        for m in input_metric_names
    ]
    compact_metric_names = list(set(pruned_metric_names))
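    # e.g. "AggregatedGroundednessPassRate" and "AverageGroundednessScore" both
    # reduce to "Groundedness", so each base metric is processed only once below.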

    aggregated_metrics_df = histogram_df.withColumn(GROUP_DIMENSION_COLUMN, lit(""))

    # overwrite threshold value column
    aggregated_metrics_df = aggregated_metrics_df.withColumn(THRESHOLD_COLUMN, lit("nan").cast("float"))
    metadata_schema = StructType(
            [
                StructField(GROUP_COLUMN, StringType(), True),
                StructField(METRIC_VALUE_COLUMN, StringType(), True),
                StructField(METRIC_NAME_COLUMN, StringType(), True),
                StructField(THRESHOLD_COLUMN, StringType(), True),
                StructField(GROUP_DIMENSION_COLUMN, StringType(), True)
            ]
        )
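    # Note: every column in this metadata schema is a string; the metadata rows
    # built from it are appended to aggregated_metrics_df via union() below.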
    total_pr_violation_counts = 0
    total_avg_violation_counts = 0
    for metric_name in compact_metric_names:
        passrate_threshold = threshold_args[f"{metric_name.lower()}_passrate_threshold"]
        full_pass_rate_metric_name = f"Aggregated{metric_name}PassRate"
        full_per_instance_score_metric_name = f"Acceptable{metric_name}ScorePerInstance"
        average_score_metric_name = f"Average{metric_name}Score"
        average_threshold_name = f"average_{metric_name.lower()}_threshold"
        if average_threshold_name in threshold_args:
            average_threshold = threshold_args[average_threshold_name]
        else:
            average_threshold = DEFAULT_AVERAGE_THRESHOLD
        passrate_violation_counts_metric_name = f"{metric_name}PassRateViolationCounts"
        avg_violation_counts_metric_name = f"{metric_name}AverageScoreViolationCounts"
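        # e.g. for "Groundedness" these names are "AggregatedGroundednessPassRate",
        # "AcceptableGroundednessScorePerInstance", "AverageGroundednessScore",
        # "GroundednessPassRateViolationCounts" and "GroundednessAverageScoreViolationCounts".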
        compute_full_pass_rate_metric = full_pass_rate_metric_name in input_metric_names
        if compute_full_pass_rate_metric:
            metric_df = spark.createDataFrame(
                    [
                        (
                            "",
                            _calculate_passrate(histogram_df, metric_name),
                            full_pass_rate_metric_name,
                            passrate_threshold,
                            ""
                        )
                    ],
                    metadata_schema,
                )
            aggregated_metrics_df = aggregated_metrics_df.union(metric_df)
        if full_per_instance_score_metric_name not in input_metric_names:
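            # the per-instance metric was not requested, so drop its rows that were
            # carried over from the histogram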
            aggregated_metrics_df = aggregated_metrics_df.filter(col(METRIC_NAME_COLUMN)
                                                                 != full_per_instance_score_metric_name)
        else:
            threshold = _get_per_instance_threshold(histogram_df, metric_name)
            threshold_row = spark.createDataFrame(
                    [
                        (
                            "",
                            "",
                            full_per_instance_score_metric_name,
                            threshold,
                            "",
                        )
                    ],
                    metadata_schema,
                )
            aggregated_metrics_df = aggregated_metrics_df.union(threshold_row)
        # for now, compute average score if either average score or pass rate is requested
        compute_average = average_score_metric_name in input_metric_names
        compute_average = compute_average or compute_full_pass_rate_metric
        if compute_average:
            average_metric_score = _calculate_average_metric_score(
                histogram_df, full_per_instance_score_metric_name)
            metric_df = spark.createDataFrame(
                    [
                        (
                            "",
                            str(average_metric_score),
                            average_score_metric_name,
                            str(average_threshold),
                            "",
                        )
                    ],
                    metadata_schema,
                )
            aggregated_metrics_df = aggregated_metrics_df.union(metric_df)
        compute_pr_violation_counts = passrate_violation_counts_metric_name in input_metric_names
        compute_pr_violation_counts = compute_pr_violation_counts or compute_full_pass_rate_metric
        if compute_pr_violation_counts:
            violation_counts = _calculate_violation_counts(histogram_df, metric_name)
            total_pr_violation_counts += violation_counts
            metric_df = spark.createDataFrame(
                    [
                        (
                            "",
                            str(violation_counts),
                            passrate_violation_counts_metric_name,
                            "",
                            "",
                        )
                    ],
                    metadata_schema,
                )
            aggregated_metrics_df = aggregated_metrics_df.union(metric_df)
        compute_avg_violation_counts = avg_violation_counts_metric_name in input_metric_names
        compute_avg_violation_counts = compute_avg_violation_counts or compute_average
        if compute_avg_violation_counts:
            if not compute_average:
                # compute the average here if it was not computed above, rather than reusing a stale value
                average_metric_score = _calculate_average_metric_score(
                    histogram_df, full_per_instance_score_metric_name)
            violation_counts = 1 if average_metric_score < average_threshold else 0
            total_avg_violation_counts += violation_counts
            metric_df = spark.createDataFrame(
                    [
                        (
                            "",
                            str(violation_counts),
                            avg_violation_counts_metric_name,
                            "",
                            "",
                        )
                    ],
                    metadata_schema,
                )
            aggregated_metrics_df = aggregated_metrics_df.union(metric_df)
    metric_df = spark.createDataFrame(
            [
                (
                    "",
                    str(total_pr_violation_counts),
                    TOTAL_PASS_RATE_VIOLATION_COUNTS,
                    "",
                    "",
                ),
                (
                    "",
                    str(total_avg_violation_counts),
                    TOTAL_AVERAGE_SCORE_VIOLATION_COUNTS,
                    "",
                    "",
                )
            ],
            metadata_schema,
        )
    aggregated_metrics_df = aggregated_metrics_df.union(metric_df)
    return aggregated_metrics_df
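
A minimal usage sketch follows. It is hedged: the threshold values and requested metric names are illustrative, built from the naming patterns in the function above, and histogram_df is assumed to be the output of the upstream annotation histogram step rather than constructed here.

# Hedged usage sketch; values are illustrative, not the component's actual configuration.
metrics_df = compute_metrics(
    histogram_df,
    threshold_args={
        "groundedness_passrate_threshold": 0.7,
        "average_groundedness_threshold": 4.0,
    },
    metric_names="AggregatedGroundednessPassRate,AverageGroundednessScore",
)
metrics_df.show()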