in assets/model_monitoring/components/src/generation_safety_quality/annotation_compute_metrics/run.py [0:0]
def compute_metrics(histogram_df, threshold_args, metric_names):
"""Compute metrics for given histogram and return them."""
spark = init_spark()
# Cast to float because metric_value was integer so far
# but we're adding percentages now.
histogram_df = histogram_df.withColumn(
METRIC_VALUE_COLUMN, histogram_df[METRIC_VALUE_COLUMN].cast("float")
)
    # Reduce each requested metric name to its base metric name
    # (Groundedness/Fluency/Coherence/Relevance/Similarity) and drop duplicates.
input_metric_names = [m.strip() for m in metric_names.split(",")]
pruned_metric_names = [re.sub(r'^(.*?)(Groundedness|Fluency|Coherence|Relevance|Similarity)(.*?)$', r'\2', m) for
m in input_metric_names]
compact_metric_names = list(set(pruned_metric_names))
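    # e.g. ["AggregatedGroundednessPassRate", "AverageGroundednessScore"]
    # collapses to ["Groundedness"]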
    # Start from the histogram rows: add an empty group dimension and overwrite
    # the threshold column with NaN; real thresholds travel with the aggregate
    # rows appended below.
    aggregated_metrics_df = histogram_df.withColumn(GROUP_DIMENSION_COLUMN, lit(""))
    aggregated_metrics_df = aggregated_metrics_df.withColumn(THRESHOLD_COLUMN, lit("nan").cast("float"))
metadata_schema = StructType(
[
StructField(GROUP_COLUMN, StringType(), True),
StructField(METRIC_VALUE_COLUMN, StringType(), True),
StructField(METRIC_NAME_COLUMN, StringType(), True),
StructField(THRESHOLD_COLUMN, StringType(), True),
StructField(GROUP_DIMENSION_COLUMN, StringType(), True)
]
)
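    # Every aggregate row added below uses this five-column, all-string layout
    # and is unioned onto aggregated_metrics_df by position.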
total_pr_violation_counts = 0
total_avg_violation_counts = 0
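    # For each base metric, emit the aggregate rows that were requested (or are
    # implied by a requested row), keeping running totals of the violation counts.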
for metric_name in compact_metric_names:
passrate_threshold = threshold_args[f"{metric_name.lower()}_passrate_threshold"]
full_pass_rate_metric_name = f"Aggregated{metric_name}PassRate"
full_per_instance_score_metric_name = f"Acceptable{metric_name}ScorePerInstance"
average_score_metric_name = f"Average{metric_name}Score"
average_threshold_name = f"average_{metric_name.lower()}_threshold"
        average_threshold = threshold_args.get(
            average_threshold_name, DEFAULT_AVERAGE_THRESHOLD)
passrate_violation_counts_metric_name = f"{metric_name}PassRateViolationCounts"
avg_violation_counts_metric_name = f"{metric_name}AverageScoreViolationCounts"
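        # Aggregated pass rate: a single-row DataFrame computed from the
        # histogram and tagged with the configured pass-rate threshold.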
compute_full_pass_rate_metric = full_pass_rate_metric_name in input_metric_names
if compute_full_pass_rate_metric:
metric_df = spark.createDataFrame(
[
(
"",
_calculate_passrate(histogram_df, metric_name),
full_pass_rate_metric_name,
passrate_threshold,
""
)
],
metadata_schema,
)
aggregated_metrics_df = aggregated_metrics_df.union(metric_df)
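        # Per-instance score rows come from the input histogram itself: drop
        # them when that metric was not requested, otherwise append a row
        # carrying the per-instance threshold.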
if full_per_instance_score_metric_name not in input_metric_names:
aggregated_metrics_df = aggregated_metrics_df.filter(col(METRIC_NAME_COLUMN)
!= full_per_instance_score_metric_name)
else:
threshold = _get_per_instance_threshold(histogram_df, metric_name)
threshold_row = spark.createDataFrame(
[
(
"",
"",
full_per_instance_score_metric_name,
threshold,
"",
)
],
metadata_schema,
)
aggregated_metrics_df = aggregated_metrics_df.union(threshold_row)
# for now, compute average score if either average score or pass rate is requested
compute_average = average_score_metric_name in input_metric_names
compute_average = compute_average or compute_full_pass_rate_metric
if compute_average:
average_metric_score = _calculate_average_metric_score(
histogram_df, full_per_instance_score_metric_name)
metric_df = spark.createDataFrame(
[
(
"",
average_metric_score,
average_score_metric_name,
str(average_threshold),
"",
)
],
metadata_schema,
)
aggregated_metrics_df = aggregated_metrics_df.union(metric_df)
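        # Pass-rate violation counts are derived from the histogram and also
        # feed the run-level total emitted after the loop.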
compute_pr_violation_counts = passrate_violation_counts_metric_name in input_metric_names
compute_pr_violation_counts = compute_pr_violation_counts or compute_full_pass_rate_metric
if compute_pr_violation_counts:
violation_counts = _calculate_violation_counts(histogram_df, metric_name)
total_pr_violation_counts += violation_counts
metric_df = spark.createDataFrame(
[
(
"",
violation_counts,
passrate_violation_counts_metric_name,
"",
"",
)
],
metadata_schema,
)
aggregated_metrics_df = aggregated_metrics_df.union(metric_df)
compute_avg_violation_counts = avg_violation_counts_metric_name in input_metric_names
compute_avg_violation_counts = compute_avg_violation_counts or compute_average
        if compute_avg_violation_counts:
            if not compute_average:
                # The average score is normally computed above; compute it here
                # so the comparison below never sees a stale or undefined value.
                average_metric_score = _calculate_average_metric_score(
                    histogram_df, full_per_instance_score_metric_name)
            violation_counts = 1 if average_metric_score < average_threshold else 0
total_avg_violation_counts += violation_counts
metric_df = spark.createDataFrame(
[
(
"",
violation_counts,
avg_violation_counts_metric_name,
"",
"",
)
],
metadata_schema,
)
aggregated_metrics_df = aggregated_metrics_df.union(metric_df)
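    # After the loop, add one row with the total pass-rate violation count and
    # one with the total average-score violation count across all metrics.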
metric_df = spark.createDataFrame(
[
(
"",
total_pr_violation_counts,
TOTAL_PASS_RATE_VIOLATION_COUNTS,
"",
"",
),
(
"",
total_avg_violation_counts,
TOTAL_AVERAGE_SCORE_VIOLATION_COUNTS,
"",
"",
)
],
metadata_schema,
)
aggregated_metrics_df = aggregated_metrics_df.union(metric_df)
return aggregated_metrics_df
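
# A minimal sketch of how compute_metrics might be invoked; the threshold value
# and metric names below are illustrative assumptions, only the key and name
# formats are taken from the function above:
#
#     metrics_df = compute_metrics(
#         histogram_df,
#         threshold_args={"groundedness_passrate_threshold": 0.7},
#         metric_names="AggregatedGroundednessPassRate,AverageGroundednessScore",
#     )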