def compute_data_quality_metrics()

in assets/model_monitoring/components/src/data_quality_compute_metrics/compute_data_quality_metrics.py [0:0]


def compute_data_quality_metrics(df, data_stats_table, override_numerical_features, override_categorical_features):
    """Compute data quality metrics for a dataset against baseline statistics.

    Counts per-feature violations (nulls, data-type mismatches, numerical
    out-of-bounds values, categorical out-of-set values) against the
    baseline ``data_stats_table``, unions them into one table, appends an
    overall "RowCount" row, and derives a ``metricValue`` rate for each
    metric as violationCount / row count (rounded to 5 decimal places).

    Args:
        df: Input Spark DataFrame to profile.
        data_stats_table: Baseline statistics Spark DataFrame (reference
            dtypes, min/max bounds and allowed category sets).
        override_numerical_features: User override of which features are
            treated as numerical (forwarded to
            ``get_numerical_and_categorical_cols``).
        override_categorical_features: User override of which features are
            treated as categorical.

    Returns:
        Spark DataFrame with columns ``featureName``, ``metricName``,
        ``dataType``, ``violationCount`` and ``metricValue``. The
        "RowCount" row carries the raw row count in ``metricValue`` and a
        ``violationCount`` of 0; all other metric names get a "Rate"
        suffix.
    """
    #########################
    # PREPARE THE DATA
    #########################
    # Cache both inputs: each is scanned several times below.
    # NOTE(review): `df` is re-bound after imputation/dtype fixes, so the
    # originally cached frame is never explicitly unpersisted here; Spark
    # evicts it on demand.
    df.cache()
    data_stats_table.cache()

    data_stats_table_mod = modify_dataType(data_stats_table)
    numerical_columns, categorical_columns = get_numerical_and_categorical_cols(
        df, override_numerical_features, override_categorical_features
    )

    modified_categorical_columns = modify_categorical_columns(df, categorical_columns)
    #########################
    # COMPUTE VIOLATIONS
    #########################
    # 1. NULL TYPE — count nulls BEFORE imputation so the counts are real.
    null_count_dtype = get_null_count(df)

    # HIERARCHY 1: IMPUTE MISSING VALUES AFTER COUNTING THEM
    df = impute_numericals_with_median(df, numerical_columns)
    df = impute_categorical_with_mode(df, modified_categorical_columns)

    # 2. DATA TYPE VIOLATION
    # HIERARCHY 2: the same call also casts affected columns back to the
    # baseline schema — `df` is intentionally overwritten with the result.
    df, dtype_violation_df = compute_dtype_violation_count_modify_dataset(
        df=df, data_stats_table_mod=data_stats_table_mod
    )

    # 3. OUT OF BOUNDS — numerical range checks and categorical set checks.
    max_violation_df = compute_max_violation(df=df,
                                             data_stats_table=data_stats_table,
                                             numerical_columns=numerical_columns)
    min_violation_df = compute_min_violation(df=df,
                                             data_stats_table=data_stats_table,
                                             numerical_columns=numerical_columns)
    threshold_violation_df = compute_set_violation(
        df=df, data_stats_table=data_stats_table, categorical_columns=modified_categorical_columns
    )
    data_stats_table.unpersist()
    data_stats_table_mod.unpersist()
    #########################
    # JOIN ALL TABLES
    #########################
    # NOTE(review): the intermediate frames below were never cached, so
    # these unpersist() calls are no-ops; kept for intent, harmless.
    violation_df = max_violation_df.unionByName(min_violation_df)
    min_violation_df.unpersist()  # release pre-join data frames from memory
    max_violation_df.unpersist()

    temp_select = null_count_dtype.select(
        ["featureName", "violationCount", "metricName"]
    )
    violation_df = temp_select.unionByName(violation_df)
    temp_select.unpersist()

    violation_df = violation_df.unionByName(threshold_violation_df)
    threshold_violation_df.unpersist()

    violation_df = violation_df.unionByName(dtype_violation_df)
    dtype_violation_df.unpersist()

    # Right join keeps every violation row even if the schema lookup has no
    # entry for that feature.
    dtype_df = get_df_schema(df)
    violation_df = dtype_df.join(violation_df, ["featureName"], how="right")
    dtype_df.unpersist()

    # ADD ROW COUNT
    # Use a long (64-bit) count column: df.count() returns a Python int that
    # can exceed the 32-bit range IntegerType would enforce, and the grouped
    # sum(violationCount) below produces LongType anyway.
    df_length = (
        spark.createDataFrame(
            [(df.count(), "RowCount")],
            schema="violationCount long, metricName string",
        )
        .withColumn("featureName", lit(""))
        .withColumn("dataType", lit(""))
    )
    # add the new row to the original DataFrame using unionByName()
    violation_df = violation_df.unionByName(df_length)

    # Collapse the three out-of-bounds variants into one "OutOfBounds"
    # metric name so they aggregate together in the groupby below.
    violation_df_remapped = violation_df.withColumn(
        "metricName",
        when(
            violation_df.metricName.endswith("maxValueOutOfRange"),
            regexp_replace(
                violation_df.metricName, "maxValueOutOfRange", "OutOfBounds"
            ),
        )
        .when(
            violation_df.metricName.endswith("minValueOutOfRange"),
            regexp_replace(
                violation_df.metricName, "minValueOutOfRange", "OutOfBounds"
            ),
        )
        .when(
            violation_df.metricName.endswith("setValueOutOfRange"),
            regexp_replace(
                violation_df.metricName, "setValueOutOfRange", "OutOfBounds"
            ),
        )
        .otherwise(violation_df.metricName),
    )

    # Sum counts per (feature, metric, dtype) — merges the remapped
    # min/max/set rows into a single OutOfBounds count.
    violation_df_remapped = (
        violation_df_remapped.select(
            ["featureName", "metricName", "violationCount", "dataType"]
        )
        .groupby(["featureName", "metricName", "dataType"])
        .sum()
        .withColumnRenamed("sum(violationCount)", "violationCount")
    )

    # COMPUTE RATIOS
    # 'len_df' is the row count needed as the denominator for the ratios.
    len_df = (
        violation_df_remapped.filter(col("metricName") == "RowCount")
        .select("violationCount")
        .collect()[0][0]
    )
    # divide the violation counts by the row count
    violation_df_remapped = violation_df_remapped.withColumn(
        "metricValue", round(violation_df_remapped["violationCount"] / lit(len_df), 5)
    )

    # REMAP THE DATA TYPES
    # Categorical wins over Numerical if a feature somehow appears in both
    # lists (when() branches are evaluated in order).
    violation_df_remapped = violation_df_remapped.withColumn(
        "dataType",
        when(col("featureName").isin(categorical_columns), "Categorical")
        .when(col("featureName").isin(numerical_columns), "Numerical")
        .otherwise(col("dataType"))
    )

    #########################
    # ALIGN COLUMN NAMING
    #########################
    # RENAME METRIC VALUE: every metric except RowCount becomes a rate.
    violation_df_remapped = violation_df_remapped.withColumn(
        "metricName",
        when(
            col("metricName") != "RowCount", concat(col("metricName"), lit("Rate"))
        ).otherwise(col("metricName")),
    )

    # MOVE ROW COUNT "metricValue" TO THE RIGHT COLUMN AND SET VIOLATION COUNT TO 0
    violation_df_remapped = violation_df_remapped.withColumn(
        "metricValue",
        when(col("metricName") == "RowCount", col("violationCount")).otherwise(
            col("metricValue")
        ),
    ).withColumn(
        "violationCount",
        when(col("metricName") == "RowCount", 0).otherwise(col("violationCount")),
    )

    return violation_df_remapped