data_validation/data_validation.py (261 lines of code) (raw):

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import warnings
from concurrent.futures import ThreadPoolExecutor

import ibis.backends.pandas
import pandas
import uuid

from data_validation import combiner, consts, metadata, util
from data_validation.config_manager import ConfigManager
from data_validation.query_builder.random_row_builder import RandomRowBuilder
from data_validation.schema_validation import SchemaValidation
from data_validation.validation_builder import ValidationBuilder


class DataValidation(object):
    """Run a configured validation against a source and a target client.

    The DataValidation class is where the code becomes source/target aware.
    It works with specific source and target clients (managed by its
    ConfigManager) and is likely where someone would go to customize their
    validation process.

    Example:
        with DataValidation(config, result_handler=None, verbose=False) as dv:
            dv.execute()
    """

    def __init__(
        self,
        config,
        validation_builder=None,
        schema_validator=None,
        result_handler=None,
        verbose=False,
        source_client: ibis.backends.base.BaseBackend = None,
        target_client: ibis.backends.base.BaseBackend = None,
    ):
        """Initialize a DataValidation client

        Args:
            config (dict): The validation config used for the comparison.
            validation_builder (ValidationBuilder): Optional instance of a ValidationBuilder.
            schema_validator (SchemaValidation): Optional instance of a SchemaValidation.
            result_handler (ResultHandler): Optional instance of a ResultHandler client.
            verbose (bool): If verbose, the Data Validation client will print the queries run.
            source_client: Optional client to avoid unnecessary connections.
            target_client: Optional client to avoid unnecessary connections.
        """
        self.verbose = verbose
        # Connections are only closed on exit when this instance did not
        # receive both clients from the caller.
        self._fresh_connections = not bool(source_client and target_client)

        # Data Client Management
        self.config = config
        self.config_manager = ConfigManager(
            config,
            source_client=source_client,
            target_client=target_client,
            verbose=self.verbose,
        )

        self.run_metadata = metadata.RunMetadata()
        self.run_metadata.labels = self.config_manager.labels
        # Use a generated uuid for the run_id if None was supplied via config
        self.run_metadata.run_id = self.config_manager.run_id or str(uuid.uuid4())

        # Initialize Validation Builder if None was supplied
        self.validation_builder = validation_builder or ValidationBuilder(
            self.config_manager
        )

        self.schema_validator = schema_validator or SchemaValidation(
            self.config_manager, run_metadata=self.run_metadata, verbose=self.verbose
        )

        # Initialize the default Result Handler if None was supplied
        self.result_handler = result_handler or self.config_manager.get_result_handler()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Close client connections this instance opened itself."""
        if self._fresh_connections and hasattr(self, "config_manager"):
            self.config_manager.close_client_connections()

    # TODO(dhercher) we planned on shifting this to use an Execution Handler.
    # Leaving to swast on the design of how this should look.
    def execute(self):
        """Execute Queries and Store Results

        Returns:
            Whatever the configured result handler's ``execute`` returns.
        """
        # Apply random row filter before validations run
        if self.config_manager.use_random_rows():
            util.timed_call("Random row filter", self._add_random_row_filter)

        # Run correct execution for the given validation type
        if self.config_manager.validation_type == consts.ROW_VALIDATION:
            grouped_fields = self.validation_builder.pop_grouped_fields()
            result_df = self.execute_recursive_validation(
                self.validation_builder, grouped_fields
            )
        elif self.config_manager.validation_type == consts.SCHEMA_VALIDATION:
            # Perform only schema validation
            result_df = util.timed_call(
                "Schema validation", self.schema_validator.execute
            )
        else:
            result_df = self._execute_validation(self.validation_builder)

        # Call Result Handler to Manage Results
        return self.result_handler.execute(result_df)

    def _add_random_row_filter(self):
        """Add random row filters to the validation builder.

        Samples primary-key values from the source and adds an IN-list filter
        on the first primary key to both source and target queries.

        Raises:
            ValueError: If no primary keys are configured.
        """
        if not self.config_manager.primary_keys:
            raise ValueError("Primary Keys are required for Random Row Filters")

        # Filter for only first primary key (multi-pk filter not supported)
        source_pk_column = self.config_manager.primary_keys[0][
            consts.CONFIG_SOURCE_COLUMN
        ]
        target_pk_column = self.config_manager.primary_keys[0][
            consts.CONFIG_TARGET_COLUMN
        ]

        random_row_builder = RandomRowBuilder(
            [source_pk_column],
            self.config_manager.random_row_batch_size(),
        )

        if (self.config_manager.validation_type == consts.CUSTOM_QUERY) and (
            self.config_manager.custom_query_type == consts.ROW_VALIDATION.lower()
        ):
            query = random_row_builder.compile_custom_query(
                self.config_manager.source_client,
                self.config_manager.source_query,
            )
        else:
            query = random_row_builder.compile(
                self.config_manager.source_client,
                self.config_manager.source_schema,
                self.config_manager.source_table,
                self.validation_builder.source_builder,
            )

        # Check if source table's primary key is BINARY, if so then
        # force cast the id columns to STRING (HEX).
        binary_conversion_required = False
        if query[source_pk_column].type().is_binary():
            binary_conversion_required = True
            query = query.mutate(
                **{source_pk_column: query[source_pk_column].cast("string")}
            )

        if self.config_manager.trim_string_pks():
            query = query.mutate(**{source_pk_column: query[source_pk_column].rstrip()})

        random_rows = self.config_manager.source_client.execute(query)
        if len(random_rows) == 0:
            # Nothing sampled: leave the builder unfiltered.
            return

        random_values = list(random_rows[source_pk_column])
        if binary_conversion_required:
            # For binary ids we have a list of hex strings for our IN list.
            # Each of these needs to be cast back to binary.
            random_values = [ibis.literal(_).cast("binary") for _ in random_values]

        filter_field = {
            consts.CONFIG_TYPE: consts.FILTER_TYPE_ISIN,
            consts.CONFIG_FILTER_SOURCE_COLUMN: source_pk_column,
            consts.CONFIG_FILTER_SOURCE_VALUE: random_values,
            consts.CONFIG_FILTER_TARGET_COLUMN: target_pk_column,
            consts.CONFIG_FILTER_TARGET_VALUE: random_values,
        }
        self.validation_builder.add_filter(filter_field)

    def query_too_large(self, rows_df, grouped_fields):
        """Return bool to dictate if another level of recursion would create a too large result set.

        Rules to define too large are:
            - If grouped fields remain after the current one, return False.
              (assumes user added logical sized groups)
            - Else, if the next group size (the larger of the source and
              target COUNT values) exceeds ``max_recursive_query_size``,
              return True.
            - Finally return False if no covered case occurred, including
              when the COUNT values cannot be cast to float.
        """
        # grouped_fields[0] is the group currently applied; anything beyond
        # it means there are still coarser groups left to recurse into.
        if len(grouped_fields) > 1:
            return False

        try:
            count_df = rows_df[
                rows_df[consts.AGGREGATION_TYPE] == consts.CONFIG_TYPE_COUNT
            ]
            # "records" yields one dict per row; the former short name "row"
            # is rejected by pandas >= 2.0 and would be swallowed below.
            for row in count_df.to_dict(orient="records"):
                recursive_query_size = max(
                    float(row[consts.SOURCE_AGG_VALUE]),
                    float(row[consts.TARGET_AGG_VALUE]),
                )
                if recursive_query_size > self.config_manager.max_recursive_query_size:
                    logging.warning("Query result is too large for recursion: %s", row)
                    return True
        except Exception:
            logging.warning("Recursive values could not be cast to float.")
            return False

        return False

    def execute_recursive_validation(self, validation_builder, grouped_fields):
        """Recursive execution for Row validations.

        This method executes aggregate queries, such as sum-of-hashes, on the
        source and target tables. Where they differ, add to the GROUP BY
        clause recursively until the individual row differences can be
        identified.

        Args:
            validation_builder (ValidationBuilder): Builder for this recursion level.
            grouped_fields (list): Group-by fields still to apply; the first
                one is applied at this level.

        Returns:
            A DataFrame concatenating the results of every group, or None
            (after a UserWarning) when there are no grouped fields left and
            no primary keys are configured.
        """
        past_results = []
        if len(grouped_fields) > 0:
            validation_builder.add_query_group(grouped_fields[0])
            result_df = self._execute_validation(validation_builder)

            for grouped_key in result_df[consts.GROUP_BY_COLUMNS].unique():
                # Validations are viewed separately, but queried together.
                # We must treat them as a single item which failed or succeeded.
                group_succeeded = True
                grouped_key_df = result_df[
                    result_df[consts.GROUP_BY_COLUMNS] == grouped_key
                ]

                if self.query_too_large(grouped_key_df, grouped_fields):
                    past_results.append(grouped_key_df)
                    continue

                for row in grouped_key_df.to_dict(orient="records"):
                    if row[consts.SOURCE_AGG_VALUE] == row[consts.TARGET_AGG_VALUE]:
                        continue
                    group_succeeded = False
                    break

                if group_succeeded:
                    past_results.append(grouped_key_df)
                else:
                    # ``row`` is the first mismatching row; its group-by
                    # values become equality filters for the next level.
                    recursive_validation_builder = validation_builder.clone()
                    self._add_recursive_validation_filter(
                        recursive_validation_builder, row
                    )
                    past_results.append(
                        self.execute_recursive_validation(
                            recursive_validation_builder, grouped_fields[1:]
                        )
                    )

            if not past_results:
                # No groups came back (e.g. empty tables). pandas.concat
                # raises on an empty list, so return the empty result as-is.
                return result_df
        elif self.config_manager.primary_keys and len(grouped_fields) == 0:
            past_results.append(self._execute_validation(validation_builder))

        # elif self.config_manager.primary_keys:
        #     validation_builder.add_config_query_groups(self.config_manager.primary_keys)
        #     validation_builder.add_config_query_groups(grouped_fields)
        else:
            warnings.warn(
                "WARNING: No Primary Keys Suppplied in Row Validation", UserWarning
            )
            return None

        return pandas.concat(past_results)

    def _add_recursive_validation_filter(self, validation_builder, row):
        """Configure ``validation_builder`` for the next recursive search.

        Each alias/value pair in the row's JSON-encoded GROUP_BY_COLUMNS is
        added as an equality filter on both source and target columns.
        The builder is mutated in place.
        """
        group_by_columns = json.loads(row[consts.GROUP_BY_COLUMNS])
        for alias, value in group_by_columns.items():
            filter_field = {
                consts.CONFIG_TYPE: consts.FILTER_TYPE_EQUALS,
                consts.CONFIG_FILTER_SOURCE_COLUMN: validation_builder.get_grouped_alias_source_column(
                    alias
                ),
                consts.CONFIG_FILTER_SOURCE_VALUE: value,
                consts.CONFIG_FILTER_TARGET_COLUMN: validation_builder.get_grouped_alias_target_column(
                    alias
                ),
                consts.CONFIG_FILTER_TARGET_VALUE: value,
            }
            validation_builder.add_filter(filter_field)

    def _execute_validation(self, validation_builder):
        """Execute Against a Supplied Validation Builder

        Runs the source and target queries concurrently and combines them
        into a report DataFrame via ``combiner.generate_report``.
        """
        self.run_metadata.validations = validation_builder.get_metadata()

        source_query = validation_builder.get_source_query()
        target_query = validation_builder.get_target_query()

        # Row validations (standard, or custom-query of type "row") compare
        # source and target values directly and join on the primary keys;
        # every other validation joins on the group-by aliases.
        is_value_comparison = (
            self.config_manager.validation_type == consts.ROW_VALIDATION
            or (
                self.config_manager.validation_type == consts.CUSTOM_QUERY
                and self.config_manager.custom_query_type == "row"
            )
        )
        join_on_fields = (
            set(validation_builder.get_primary_keys())
            if is_value_comparison
            else set(validation_builder.get_group_aliases())
        )

        with ThreadPoolExecutor() as executor:
            # Submit the two query network calls concurrently
            source_future = executor.submit(
                util.timed_call,
                "Source query",
                self.config_manager.source_client.execute,
                source_query,
            )
            target_future = executor.submit(
                util.timed_call,
                "Target query",
                self.config_manager.target_client.execute,
                target_query,
            )
            source_df = source_future.result()
            target_df = target_future.result()

        try:
            result_df = util.timed_call(
                "Generate report",
                combiner.generate_report,
                self.run_metadata,
                source_df,
                target_df,
                join_on_fields=join_on_fields,
                is_value_comparison=is_value_comparison,
                verbose=self.verbose,
            )
        except Exception:
            if self.verbose:
                logging.error("-- ** Logging Source DF ** --")
                logging.error(source_df.dtypes)
                logging.error(source_df)
                logging.error("-- ** Logging Target DF ** --")
                logging.error(target_df.dtypes)
                logging.error(target_df)
            raise

        return result_df

    def combine_data(self, source_df, target_df, join_on_fields):
        """Outer-combine source and target DataFrames.

        TODO: Return List of Dictionaries (currently returns a DataFrame).

        When ``join_on_fields`` is truthy, the frames are merged on those
        columns. Otherwise they are joined on their index. Overlapping
        column names get the INPUT/OUTPUT suffixes.
        """
        # Clean Data to Standardize
        if join_on_fields:
            df = source_df.merge(
                target_df,
                how="outer",
                on=join_on_fields,
                suffixes=(consts.INPUT_SUFFIX, consts.OUTPUT_SUFFIX),
            )
        else:
            df = source_df.join(
                target_df,
                how="outer",
                lsuffix=consts.INPUT_SUFFIX,
                rsuffix=consts.OUTPUT_SUFFIX,
            )

        return df