in assets/model_monitoring/components/src/data_quality_compute_metrics/compute_data_quality_metrics.py [0:0]
def compute_data_quality_metrics(df, data_stats_table, override_numerical_features, override_categorical_features):
    """Compute data quality metrics for ``df`` against a baseline statistics table.

    Counts per-feature violations (null values, dtype mismatches,
    out-of-bounds numerical values, out-of-set categorical values), unions
    them into one table, appends an overall ``RowCount`` row, and converts
    each violation count into a rate relative to the row count.

    Args:
        df: Production Spark DataFrame to evaluate. Cached internally and
            released before return.
        data_stats_table: Baseline statistics DataFrame (expected dtypes,
            min/max bounds, allowed categorical value sets).
        override_numerical_features: User override controlling which columns
            are treated as numerical.
        override_categorical_features: User override controlling which columns
            are treated as categorical.

    Returns:
        Spark DataFrame with columns ``featureName``, ``metricName``,
        ``dataType``, ``violationCount`` and ``metricValue`` (rate rounded
        to 5 decimal places; for the ``RowCount`` row, ``metricValue`` holds
        the count and ``violationCount`` is 0).
    """
    #########################
    # PREPARE THE DATA
    #########################
    # Cache both inputs: each is scanned several times by the violation
    # computations below.
    df.cache()
    data_stats_table.cache()

    data_stats_table_mod = modify_dataType(data_stats_table)
    numerical_columns, categorical_columns = get_numerical_and_categorical_cols(
        df, override_numerical_features, override_categorical_features
    )
    modified_categorical_columns = modify_categorical_columns(df, categorical_columns)

    #########################
    # COMPUTE VIOLATIONS
    #########################
    # 1. NULL TYPE — nulls must be counted before imputation removes them.
    null_count_dtype = get_null_count(df)

    # HIERARCHY 1: IMPUTE MISSING VALUES AFTER COUNTING THEM
    df = impute_numericals_with_median(df, numerical_columns)
    df = impute_categorical_with_mode(df, modified_categorical_columns)

    # 2. DATA TYPE VIOLATION
    # HIERARCHY 2: CHANGE D-TYPE OF AFFECTED COLUMN TO BASELINE SCHEMA.
    # This happens inside `compute_dtype_violation_count_modify_dataset`,
    # which returns the modified df (hence the overwrite).
    df, dtype_violation_df = compute_dtype_violation_count_modify_dataset(
        df=df, data_stats_table_mod=data_stats_table_mod
    )

    # 3. OUT OF BOUNDS
    max_violation_df = compute_max_violation(
        df=df, data_stats_table=data_stats_table, numerical_columns=numerical_columns
    )
    min_violation_df = compute_min_violation(
        df=df, data_stats_table=data_stats_table, numerical_columns=numerical_columns
    )
    threshold_violation_df = compute_set_violation(
        df=df,
        data_stats_table=data_stats_table,
        categorical_columns=modified_categorical_columns,
    )
    # The baseline tables are not referenced past this point.
    data_stats_table.unpersist()
    data_stats_table_mod.unpersist()

    #########################
    # JOIN ALL TABLES
    #########################
    violation_df = max_violation_df.unionByName(min_violation_df)
    min_violation_df.unpersist()  # release pre-join data frames from memory
    max_violation_df.unpersist()
    temp_select = null_count_dtype.select(
        ["featureName", "violationCount", "metricName"]
    )
    violation_df = temp_select.unionByName(violation_df)
    temp_select.unpersist()
    violation_df = violation_df.unionByName(threshold_violation_df)
    threshold_violation_df.unpersist()
    violation_df = violation_df.unionByName(dtype_violation_df)
    dtype_violation_df.unpersist()

    # Right join keeps every violation row even when the schema lookup has
    # no entry for that feature.
    dtype_df = get_df_schema(df)
    violation_df = dtype_df.join(violation_df, ["featureName"], how="right")
    dtype_df.unpersist()

    # ADD ROW COUNT
    # NOTE(review): IntegerType caps the row count at 2**31 - 1; datasets
    # beyond that would need LongType here (and in the unioned tables) —
    # confirm against expected data volumes.
    df_length = (
        spark.createDataFrame(
            [(df.count(), "RowCount")],
            schema=StructType(
                [
                    StructField("violationCount", IntegerType(), True),
                    StructField("metricName", StringType(), True),
                ]
            ),
        )
        .withColumn("featureName", lit(""))
        .withColumn("dataType", lit(""))
    )
    # FIX: df was cached at the top but never released. The row count above
    # is the last action on df inside this function, so release it here to
    # match how every other cached frame is unpersisted.
    df.unpersist()

    # add the new row to the original DataFrame using unionByName()
    violation_df = violation_df.unionByName(df_length)

    # Collapse the three out-of-range metric variants into one
    # "OutOfBounds" metric name so min/max/set violations aggregate together.
    violation_df_remapped = violation_df.withColumn(
        "metricName",
        when(
            violation_df.metricName.endswith("maxValueOutOfRange"),
            regexp_replace(
                violation_df.metricName, "maxValueOutOfRange", "OutOfBounds"
            ),
        )
        .when(
            violation_df.metricName.endswith("minValueOutOfRange"),
            regexp_replace(
                violation_df.metricName, "minValueOutOfRange", "OutOfBounds"
            ),
        )
        .when(
            violation_df.metricName.endswith("setValueOutOfRange"),
            regexp_replace(
                violation_df.metricName, "setValueOutOfRange", "OutOfBounds"
            ),
        )
        .otherwise(violation_df.metricName),
    )
    # Sum the now-shared metric names per feature.
    violation_df_remapped = (
        violation_df_remapped.select(
            ["featureName", "metricName", "violationCount", "dataType"]
        )
        .groupby(["featureName", "metricName", "dataType"])
        .sum()
        .withColumnRenamed("sum(violationCount)", "violationCount")
    )

    # COMPUTE RATIOS
    # 'len_df' is the total row count (denominator for the rates).
    len_df = (
        violation_df_remapped.filter(col("metricName") == "RowCount")
        .select("violationCount")
        .collect()[0][0]
    )
    # divide the violation counts by the row count
    violation_df_remapped = violation_df_remapped.withColumn(
        "metricValue", round(violation_df_remapped["violationCount"] / lit(len_df), 5)
    )

    # REMAP THE DATA TYPES to the user-facing Categorical/Numerical labels.
    violation_df_remapped = violation_df_remapped.withColumn(
        "dataType",
        when(col("featureName").isin(categorical_columns), "Categorical")
        .when(col("featureName").isin(numerical_columns), "Numerical")
        .otherwise(col("dataType"))
    )

    #########################
    # ALIGN COLUMN NAMING
    #########################
    # RENAME METRIC VALUE: every metric except RowCount is a rate.
    violation_df_remapped = violation_df_remapped.withColumn(
        "metricName",
        when(
            col("metricName") != "RowCount", concat(col("metricName"), lit("Rate"))
        ).otherwise(col("metricName")),
    )
    # MOVE ROW COUNT "metricValue" TO THE RIGHT COLUMN AND SET VIOLATION COUNT TO 0
    violation_df_remapped = violation_df_remapped.withColumn(
        "metricValue",
        when(col("metricName") == "RowCount", col("violationCount")).otherwise(
            col("metricValue")
        ),
    ).withColumn(
        "violationCount",
        when(col("metricName") == "RowCount", 0).otherwise(col("violationCount")),
    )
    return violation_df_remapped