assets/model_monitoring/components/src/shared_utilities/df_utils.py (167 lines of code) (raw):

# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. """This file contains additional utilities that are applicable to dataframe.""" import pyspark.sql as pyspark_sql from enum import Enum from shared_utilities.momo_exceptions import InvalidInputError from shared_utilities.event_utils import post_warning_event from pyspark.sql.utils import AnalysisException from typing import Optional class NoCommonColumnsApproach(Enum): """Enum for no common columns approach.""" IGNORE = 0 WARNING = 1 ERROR = 2 data_type_long_group = ["long", "int", "bigint", "short", "tinyint", "smallint"] data_type_numerical_group = ["float", "double", "decimal"] data_type_categorical_group = ["string", "boolean", "timestamp", "date", "binary"] def is_numerical(column, column_dtype_map: dict, feature_type_override_map: dict, df): """Check if int column should be numerical.""" is_categorical_col = is_categorical(column, column_dtype_map, feature_type_override_map, df) return None if is_categorical_col is None else not is_categorical_col def is_categorical(column, column_dtype_map: dict, feature_type_override_map: dict, df): """Check if int column should be categorical.""" if feature_type_override_map.get(column, None) == "categorical": return True if feature_type_override_map.get(column, None) == "numerical": return False if column_dtype_map[column] in data_type_categorical_group: return True if column_dtype_map[column] in data_type_numerical_group: return False if column_dtype_map[column] in data_type_long_group: distinct_value_ratio = get_distinct_ratio(df.select(column).rdd.flatMap(lambda x: x).collect()) return distinct_value_ratio < 0.05 print(f"Unknown column type: {column_dtype_map[column]}, column name: {column}") return None def get_numerical_cols_with_df_with_override( df, override_numerical_features, override_categorical_features, column_dtype_map=None) -> list: """Get numerical columns from all columns with dataframe.""" column_dtype_map = dict(df.dtypes) if column_dtype_map is None else column_dtype_map feature_type_override_map = get_feature_type_override_map(override_numerical_features, override_categorical_features) numerical_columns = [ column for column in column_dtype_map if is_numerical(column, column_dtype_map, feature_type_override_map, df) ] return numerical_columns def get_categorical_cols_with_df_with_override( df, override_numerical_features, override_categorical_features, column_dtype_map=None) -> list: """Get categorical columns from all columns with dataframe.""" column_dtype_map = dict(df.dtypes) if column_dtype_map is None else column_dtype_map feature_type_override_map = get_feature_type_override_map(override_numerical_features, override_categorical_features) categorical_columns = [ column for column in column_dtype_map if is_categorical(column, column_dtype_map, feature_type_override_map, df) ] return categorical_columns def get_numerical_and_categorical_cols( df, override_numerical_features, override_categorical_features, column_dtype_map=None): """Get numerical and categorical columns from all columns with dataframe.""" return (get_numerical_cols_with_df_with_override(df, override_numerical_features, override_categorical_features, column_dtype_map), get_categorical_cols_with_df_with_override(df, override_numerical_features, override_categorical_features, column_dtype_map)) def get_feature_type_override_map(override_numerical_features: str, override_categorical_features: str) -> dict: """Generate feature type override map with key of feature name and value of "numerical"/"categorical".""" feature_type_override_map = {} if override_categorical_features: for cat_feature in override_categorical_features.split(','): feature_type_override_map[cat_feature] = "categorical" if override_numerical_features: for num_feature in override_numerical_features.split(','): feature_type_override_map[num_feature] = "numerical" return feature_type_override_map def get_distinct_ratio(column): """Get distict ratio for values in a column.""" distinct_values = len(set(column)) total_values = len(column) return distinct_values / total_values def get_common_columns( baseline_df: pyspark_sql.DataFrame, production_df: pyspark_sql.DataFrame ) -> dict: """Get common columns from baseline and production dataframes.""" baseline_df_dtypes = dict(baseline_df.dtypes) production_df_dtypes = dict(production_df.dtypes) common_columns = {} for (column_name, data_type) in baseline_df_dtypes.items(): if production_df_dtypes.get(column_name) == data_type: common_columns[column_name] = data_type else: # if baseline and target are of different type # and both of them are in [double, float], # We consider them to be double if production_df_dtypes.get(column_name) in data_type_numerical_group \ and baseline_df_dtypes.get(column_name) in data_type_numerical_group: common_columns[column_name] = 'double' # if baseline and target are of different type # and both of them are in [int, long, short] # We consider them to be long elif production_df_dtypes.get(column_name) in data_type_long_group\ and baseline_df_dtypes.get(column_name) in data_type_long_group: common_columns[column_name] = 'long' return common_columns def modify_categorical_columns(df: pyspark_sql.DataFrame, categorical_columns: list) -> list: """ Modify categorical columns, filtering out unsupported or non-meaningful columns. Args: df (pyspark.sql.DataFrame): The input DataFrame categorical_columns: List of categorical columns Returns: modified_categorical_columns: Modified categorical column """ # Only do the data quality check for string type. Ignore all the other types # Ignore bool, time, date categorical columns because they are meaningless for data quality calculation # Ignore binary because it will throw type not supported error for mode modified_categorical_columns = [] dtype_map = dict(df.dtypes) for column in categorical_columns: if dtype_map[column] == "string": modified_categorical_columns.append(column) return modified_categorical_columns def select_columns_from_spark_df(df: pyspark_sql.DataFrame, column_list: list): """Select comlumns from given spark dataFrame.""" column_list = list(map(str.strip, column_list)) df = df.select(column_list) return df def row_has_value(row: pyspark_sql.Row, row_name: str) -> bool: """Check if a row has the given column.""" return row_name in row and row[row_name] is not None and row[row_name] != "" def add_value_if_present( row: pyspark_sql.Row, row_name: str, dict: dict, target_property_name: str ) -> dict: """Add value to a dictionary if it is present in a row.""" if row_has_value(row, row_name): dict[target_property_name] = row[row_name] return dict def try_get_common_columns_with_warning( baseline_df: pyspark_sql.DataFrame, production_df: pyspark_sql.DataFrame ) -> dict: """Get common columns. Post warning to the job and return empty dict.""" return try_get_common_columns(baseline_df, production_df, NoCommonColumnsApproach.WARNING) def try_get_common_columns_with_error( baseline_df: pyspark_sql.DataFrame, production_df: pyspark_sql.DataFrame ) -> dict: """Get common columns. Raise error if dictionary is empty.""" return try_get_common_columns(baseline_df, production_df, NoCommonColumnsApproach.ERROR) def try_get_common_columns( baseline_df: pyspark_sql.DataFrame, production_df: pyspark_sql.DataFrame, no_common_columns_approach=NoCommonColumnsApproach.IGNORE ) -> dict: """ Compute the common columns between baseline and production dataframes. If common columns are not found, conduct different error handling based on no_common_columns_approach. """ common_columns_dict = get_common_columns(baseline_df, production_df) if not common_columns_dict: error_message = ( "Found no common columns between input datasets. Try double-checking" " if there are common columns between the input datasets." " Common columns must have the same names (case-sensitive) and similar data types." ) if no_common_columns_approach == NoCommonColumnsApproach.ERROR: raise InvalidInputError( error_message ) elif no_common_columns_approach == NoCommonColumnsApproach.WARNING: post_warning_event( error_message + " Please visit aka.ms/mlmonitoringhelp for more information." ) return {} # no_common_columns_approach == NoCommonColumnsApproach.IGNORE: else: return {} # returns found common columns. return common_columns_dict def try_get_df_column(df: pyspark_sql.DataFrame, name: str) -> Optional[pyspark_sql.Column]: """Get column if it exists in DF. Return none if column does not exist.""" try: return df[name] except AnalysisException: return None def has_duplicated_columns(df: pyspark_sql.DataFrame) -> bool: """Check if a dataframe has duplicate columns.""" # spark is not case sensitive unless configured otherwise, # set column names in lowercase to check for duplicates. col_names = [col_name.lower() for col_name in df.columns] col_names_set = set(col_names) return len(col_names) > len(col_names_set) def validate_column_names(df): """Validate if column name has dot, which is by deign for accessing nested fields. Throw invalid input error.""" for column in df.columns: if "." in column: raise InvalidInputError(f"Column name {column} has a dot, which is not supported. " + "Please rename the column and retry.")