data_validation/config_manager.py (998 lines of code) (raw):

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import string
import random
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Tuple

import ibis.expr.datatypes as dt
import yaml

from data_validation import clients, consts, gcs_helper, state_manager
from data_validation.result_handlers.factory import build_result_handler
from data_validation.validation_builder import ValidationBuilder

if TYPE_CHECKING:
    import ibis.expr.types.Table


class ConfigManager(object):
    _config: dict = None
    _source_conn = None
    _target_conn = None
    _state_manager = None
    source_client = None
    target_client = None

    def __init__(self, config, source_client=None, target_client=None, verbose=False):
        """Initialize a ConfigManager client which supplies the source and target queries to run.

        Args:
            config (Dict): The Validation config supplied
            source_client (IbisClient): The Ibis client for the source DB
            target_client (IbisClient): The Ibis client for the target DB
            verbose (Bool): If verbose, the Data Validation client will print queries run
        """
        self._state_manager = state_manager.StateManager()
        self._config = config

        self.source_client = source_client or clients.get_data_client(
            self.get_source_connection()
        )
        self.target_client = target_client or clients.get_data_client(
            self.get_target_connection()
        )

        self.verbose = verbose
        if self.validation_type not in consts.CONFIG_TYPES:
            raise ValueError(f"Unknown Configuration Type: {self.validation_type}")
        self._comparison_max_col_length = None
        # For some engines we need to know the actual raw data type rather than the Ibis canonical type.
        self._source_raw_data_types = None
        self._target_raw_data_types = None

    @property
    def config(self):
        """Return config object."""
        return self._config

    def get_source_connection(self):
        """Return source connection object."""
        if not self._source_conn:
            if self._config.get(consts.CONFIG_SOURCE_CONN):
                self._source_conn = self._config.get(consts.CONFIG_SOURCE_CONN)
            else:
                conn_name = self._config.get(consts.CONFIG_SOURCE_CONN_NAME)
                self._source_conn = self._state_manager.get_connection_config(conn_name)
        return self._source_conn

    def get_target_connection(self):
        """Return target connection object."""
        if not self._target_conn:
            if self._config.get(consts.CONFIG_TARGET_CONN):
                self._target_conn = self._config.get(consts.CONFIG_TARGET_CONN)
            else:
                conn_name = self._config.get(consts.CONFIG_TARGET_CONN_NAME)
                self._target_conn = self._state_manager.get_connection_config(conn_name)
        return self._target_conn

    def get_source_raw_data_types(self) -> Dict[str, Tuple]:
        """Return raw data type information from source system.

        The raw data type is the source/target engine type, for example it might be
        "NCLOB" or "char" when the Ibis type simply states "string".
        The data is cached in state when fetched for the first time.

        The return value is keyed on the casefolded column name and the tuple is the
        remaining 6 elements of the DB API cursor description specification."""
        if self._source_raw_data_types is None:
            if hasattr(self.source_client, "raw_column_metadata"):
                raw_data_types = self.source_client.raw_column_metadata(
                    database=self.source_schema,
                    table=self.source_table,
                    query=self.source_query,
                )
                self._source_raw_data_types = {
                    _[0].casefold(): _[1:] for _ in raw_data_types
                }
            else:
                self._source_raw_data_types = {}
        return self._source_raw_data_types

    def get_target_raw_data_types(self) -> Dict[str, Tuple]:
        """Return raw data type information from target system.

        The raw data type is the source/target engine type, for example it might be
        "NCLOB" or "char" when the Ibis type simply states "string".
        The data is cached in state when fetched for the first time.

        The return value is keyed on the casefolded column name and the tuple is the
        remaining 6 elements of the DB API cursor description specification."""
        if self._target_raw_data_types is None:
            if hasattr(self.target_client, "raw_column_metadata"):
                raw_data_types = self.target_client.raw_column_metadata(
                    database=self.target_schema,
                    table=self.target_table,
                    query=self.target_query,
                )
                self._target_raw_data_types = {
                    _[0].casefold(): _[1:] for _ in raw_data_types
                }
            else:
                self._target_raw_data_types = {}
        return self._target_raw_data_types
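
    # Illustrative sketch of the cached raw metadata shape (hypothetical values):
    # the casefolded column name maps to the trailing 6 elements of a DB API cursor
    # description tuple, e.g.
    #   {"col_name": ("NCLOB", None, None, None, None, 1), ...}
    # Only clients exposing raw_column_metadata() populate this cache; for all other
    # clients the dict is simply empty.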
logging.warning("Exception closing connections: %s", str(exc)) @property def validation_type(self): """Return string validation type (Column|Schema).""" return self._config[consts.CONFIG_TYPE] def use_random_rows(self): """Return if the validation should use a random row filter.""" return self._config.get(consts.CONFIG_USE_RANDOM_ROWS) or False def random_row_batch_size(self): """Return batch size for random row filter.""" return int( self._config.get(consts.CONFIG_RANDOM_ROW_BATCH_SIZE) or consts.DEFAULT_NUM_RANDOM_ROWS ) def get_random_row_batch_size(self): """Return number of random rows or None.""" return self.random_row_batch_size() if self.use_random_rows() else None def trim_string_pks(self): """Return if the validation should trim string primary keys.""" return self._config.get(consts.CONFIG_TRIM_STRING_PKS) or False def case_insensitive_match(self): """Return if the validation should perform a case insensitive match.""" return self._config.get(consts.CONFIG_CASE_INSENSITIVE_MATCH) or False @property def max_recursive_query_size(self): """Return Aggregates from Config""" return self._config.get(consts.CONFIG_MAX_RECURSIVE_QUERY_SIZE, 50000) @property def aggregates(self): """Return Aggregates from Config""" return self._config.get(consts.CONFIG_AGGREGATES, []) def append_aggregates(self, aggregate_configs): """Append aggregate configs to existing config.""" self._config[consts.CONFIG_AGGREGATES] = self.aggregates + aggregate_configs @property def calculated_fields(self): return self._config.get(consts.CONFIG_CALCULATED_FIELDS, []) def append_calculated_fields(self, calculated_configs): self._config[consts.CONFIG_CALCULATED_FIELDS] = ( self.calculated_fields + calculated_configs ) @property def query_groups(self): """Return Query Groups from Config""" return self._config.get(consts.CONFIG_GROUPED_COLUMNS, []) def append_query_groups(self, grouped_column_configs): """Append grouped configs to existing config.""" self._config[consts.CONFIG_GROUPED_COLUMNS] = ( self.query_groups + grouped_column_configs ) @property def custom_query_type(self): """Return custom query type from config""" return self._config.get(consts.CONFIG_CUSTOM_QUERY_TYPE, "") def append_custom_query_type(self, custom_query_type): """Append custom query type config to existing config.""" self._config[consts.CONFIG_CUSTOM_QUERY_TYPE] = ( self.custom_query_type + custom_query_type ) @property def source_query_file(self): """Return SQL Query File from Config""" return self._config.get(consts.CONFIG_SOURCE_QUERY_FILE, []) def append_source_query_file(self, query_file_configs): """Append grouped configs to existing config.""" self._config[consts.CONFIG_SOURCE_QUERY_FILE] = ( self.source_query_file + query_file_configs ) @property def target_query_file(self): """Return SQL Query File from Config""" return self._config.get(consts.CONFIG_TARGET_QUERY_FILE, []) def append_target_query_file(self, query_file_configs): """Append grouped configs to existing config.""" self._config[consts.CONFIG_TARGET_QUERY_FILE] = ( self.target_query_file + query_file_configs ) @property def primary_keys(self): """Return Primary keys from Config""" return self._config.get(consts.CONFIG_PRIMARY_KEYS, []) def append_primary_keys(self, primary_key_configs): """Append primary key configs to existing config.""" self._config[consts.CONFIG_PRIMARY_KEYS] = ( self.primary_keys + primary_key_configs ) def get_primary_keys_list(self): """Return list of primary key column names""" return [key[consts.CONFIG_SOURCE_COLUMN] for key in self.primary_keys] 

    @property
    def comparison_fields(self):
        """Return comparison fields from Config"""
        return self._config.get(consts.CONFIG_COMPARISON_FIELDS, [])

    def append_comparison_fields(self, field_configs):
        """Append field configs to existing config."""
        self._config[consts.CONFIG_COMPARISON_FIELDS] = (
            self.comparison_fields + field_configs
        )

    @property
    def concat(self):
        """Return row concat field from Config"""
        return self._config.get(consts.CONFIG_ROW_CONCAT, [])

    @property
    def hash(self):
        """Return row hash field from Config"""
        return self._config.get(consts.CONFIG_ROW_HASH, [])

    @property
    def run_id(self):
        """Return run ID from Config"""
        return self._config.get(consts.CONFIG_RUN_ID, None)

    @property
    def filters(self):
        """Return Filters from Config"""
        return self._config.get(consts.CONFIG_FILTERS, [])

    @property
    def source_schema(self):
        """Return string value of source schema."""
        if self.source_client._source_type == "FileSystem":
            return None
        return self._config.get(consts.CONFIG_SCHEMA_NAME, None)

    @property
    def source_table(self):
        """Return string value of source table."""
        return self._config[consts.CONFIG_TABLE_NAME]

    @property
    def target_schema(self):
        """Return string value of target schema."""
        if self.target_client._source_type == "FileSystem":
            return None
        return self._config.get(consts.CONFIG_TARGET_SCHEMA_NAME, self.source_schema)

    @property
    def target_table(self):
        """Return string value of target table."""
        return self._config.get(
            consts.CONFIG_TARGET_TABLE_NAME, self._config[consts.CONFIG_TABLE_NAME]
        )

    @property
    def full_target_table(self):
        """Return string value of fully qualified target table."""
        if self.target_schema:
            return self.target_schema + "." + self.target_table
        else:
            return self.target_table

    @property
    def full_source_table(self):
        """Return string value of fully qualified source table."""
        if self.source_table and self.source_schema:
            return self.source_schema + "." + self.source_table
        elif self.source_table:
            return self.source_table
        else:
            return f"custom.{''.join(random.choice(string.ascii_lowercase) for _ in range(5))}"

    @property
    def labels(self):
        """Return labels."""
        return self._config.get(consts.CONFIG_LABELS, [])

    @property
    def result_handler_config(self):
        """Return result handler config dict."""
        return self._config.get(consts.CONFIG_RESULT_HANDLER) or {}

    @property
    def query_limit(self):
        """Return int limit for query executions."""
        return self._config.get(consts.CONFIG_LIMIT)

    @property
    def threshold(self):
        """Return threshold from Config"""
        return self._config.get(consts.CONFIG_THRESHOLD, 0.0)

    @property
    def source_query(self):
        return self._config.get(consts.CONFIG_SOURCE_QUERY, None)

    def append_source_query(self, source_query):
        self._config["source_query"] = source_query

    @property
    def target_query(self):
        return self._config.get(consts.CONFIG_TARGET_QUERY, None)

    def append_target_query(self, target_query):
        self._config["target_query"] = target_query

    @property
    def exclusion_columns(self):
        """Return the exclusion columns from Config"""
        return self._config.get(consts.CONFIG_EXCLUSION_COLUMNS, [])

    @property
    def allow_list(self):
        """Return the allow_list from Config"""
        return self._config.get(consts.CONFIG_ALLOW_LIST, "")

    @property
    def filter_status(self):
        """Return filter status list from Config"""
        return self._config.get(consts.CONFIG_FILTER_STATUS, None)

    def append_exclusion_columns(self, column_configs):
        """Append exclusion columns to existing config."""
        self._config[consts.CONFIG_EXCLUSION_COLUMNS] = (
            self.exclusion_columns + column_configs
        )

    def append_allow_list(
        self, allow_list: Union[str, None], allow_list_file: Union[str, None]
    ):
        """Append datatype allow_list to existing config."""
        full_allow_list = []
        if allow_list:
            allow_list = allow_list.replace(" ", "")
            full_allow_list.append(allow_list)
        if allow_list_file:
            try:
                allow_list_yaml = gcs_helper.read_file(allow_list_file)
            except FileNotFoundError as e:
                raise ValueError(
                    f"Cannot locate --allow-list-file: {allow_list_file}"
                ) from e
            allow_list_dict = yaml.safe_load(allow_list_yaml)
            full_allow_list.append(
                ",".join([f"{_[0]}:{_[1]}" for _ in allow_list_dict.items()])
            )
        self._config[consts.CONFIG_ALLOW_LIST] = ",".join(full_allow_list)

    def get_source_ibis_table(self):
        """Return IbisTable from source."""
        if not hasattr(self, "_source_ibis_table"):
            self._source_ibis_table = clients.get_ibis_table(
                self.source_client, self.source_schema, self.source_table
            )
        return self._source_ibis_table

    def get_source_ibis_table_from_query(self):
        """Return IbisTable from source query."""
        if not hasattr(self, "_source_ibis_table"):
            self._source_ibis_table = clients.get_ibis_query(
                self.source_client, self.source_query
            )
        return self._source_ibis_table

    def get_source_ibis_calculated_table(self, depth=None):
        """Return mutated IbisTable from source

        depth: Int the depth of subquery requested"""
        if self.validation_type == consts.CUSTOM_QUERY:
            table = self.get_source_ibis_table_from_query()
        else:
            table = self.get_source_ibis_table()
        vb = ValidationBuilder(self)
        calculated_table = table.mutate(
            vb.source_builder.compile_calculated_fields(table, n=depth)
        )
        return calculated_table

    def get_target_ibis_table(self):
        """Return IbisTable from target."""
        if not hasattr(self, "_target_ibis_table"):
            self._target_ibis_table = clients.get_ibis_table(
                self.target_client, self.target_schema, self.target_table
            )
        return self._target_ibis_table

    def get_target_ibis_table_from_query(self):
        """Return IbisTable from target query."""
        if not hasattr(self, "_target_ibis_table"):
            self._target_ibis_table = clients.get_ibis_query(
                self.target_client, self.target_query
            )
        return self._target_ibis_table

    def get_target_ibis_calculated_table(self, depth=None):
        """Return mutated IbisTable from target

        depth: Int the depth of subquery requested"""
        if self.validation_type == consts.CUSTOM_QUERY:
            table = self.get_target_ibis_table_from_query()
        else:
            table = self.get_target_ibis_table()
        vb = ValidationBuilder(self)
        calculated_table = table.mutate(
            vb.target_builder.compile_calculated_fields(table, n=depth)
        )
        return calculated_table

    def get_yaml_validation_block(self):
        """Return Dict object formatted for a Yaml file."""
        config = copy.deepcopy(self.config)
        config.pop(consts.CONFIG_SOURCE_CONN, None)
        config.pop(consts.CONFIG_TARGET_CONN, None)
        config.pop(consts.CONFIG_SOURCE_CONN_NAME, None)
        config.pop(consts.CONFIG_TARGET_CONN_NAME, None)
        config.pop(consts.CONFIG_RESULT_HANDLER, None)
        return config

    def get_result_handler(self):
        """Return ResultHandler instance from supplied config."""
        return build_result_handler(
            self.result_handler_config,
            self.config[consts.CONFIG_TYPE],
            self.filter_status,
            text_format=self._config.get(
                consts.CONFIG_FORMAT, consts.FORMAT_TYPE_TABLE
            ),
        )

    @staticmethod
    def build_config_manager(
        config_type,
        source_conn_name,
        target_conn_name,
        table_obj,
        labels,
        threshold,
        format,
        use_random_rows=None,
        random_row_batch_size=None,
        source_client=None,
        target_client=None,
        result_handler_config=None,
        filter_config=None,
        filter_status=None,
        trim_string_pks=None,
        case_insensitive_match=None,
        concat=None,
        hash=None,
        run_id=None,
        verbose=False,
    ):
        """Return a ConfigManager instance with available config."""
        if isinstance(filter_config, dict):
            filter_config = [filter_config]

        config = {
            consts.CONFIG_TYPE: config_type,
            consts.CONFIG_SOURCE_CONN_NAME: source_conn_name,
            consts.CONFIG_TARGET_CONN_NAME: target_conn_name,
            consts.CONFIG_TABLE_NAME: table_obj.get(consts.CONFIG_TABLE_NAME, None),
            consts.CONFIG_SCHEMA_NAME: table_obj.get(consts.CONFIG_SCHEMA_NAME, None),
            consts.CONFIG_TARGET_SCHEMA_NAME: table_obj.get(
                consts.CONFIG_TARGET_SCHEMA_NAME,
                table_obj.get(consts.CONFIG_SCHEMA_NAME, None),
            ),
            consts.CONFIG_TARGET_TABLE_NAME: table_obj.get(
                consts.CONFIG_TARGET_TABLE_NAME,
                table_obj.get(consts.CONFIG_TABLE_NAME, None),
            ),
            consts.CONFIG_LABELS: labels,
            consts.CONFIG_THRESHOLD: threshold,
            consts.CONFIG_FORMAT: format,
            consts.CONFIG_RESULT_HANDLER: result_handler_config,
            consts.CONFIG_FILTERS: filter_config,
            consts.CONFIG_USE_RANDOM_ROWS: use_random_rows,
            consts.CONFIG_RANDOM_ROW_BATCH_SIZE: random_row_batch_size,
            consts.CONFIG_FILTER_STATUS: filter_status,
            consts.CONFIG_TRIM_STRING_PKS: trim_string_pks,
            consts.CONFIG_CASE_INSENSITIVE_MATCH: case_insensitive_match,
            consts.CONFIG_ROW_CONCAT: concat,
            consts.CONFIG_ROW_HASH: hash,
            consts.CONFIG_RUN_ID: run_id,
        }

        return ConfigManager(
            config,
            source_client=source_client,
            target_client=target_client,
            verbose=verbose,
        )
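
    # Illustrative sketch of a table_obj accepted by build_config_manager() above
    # (hypothetical names); the target keys fall back to the source values when omitted:
    #   {
    #       consts.CONFIG_SCHEMA_NAME: "my_dataset",
    #       consts.CONFIG_TABLE_NAME: "my_table",
    #       consts.CONFIG_TARGET_SCHEMA_NAME: "my_dataset_copy",
    #       consts.CONFIG_TARGET_TABLE_NAME: "my_table",
    #   }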

    def add_rstrip_to_comp_fields(self, comparison_fields: List[str]) -> List[str]:
        """As per #1190, add an rstrip calculated field for Teradata string comparison fields.

        Parameters:
            comparison_fields: List[str] of comparison field columns

        Returns:
            comp_fields_with_aliases: List[str] of comparison field columns with rstrip aliases
        """
        source_table = self.get_source_ibis_calculated_table()
        target_table = self.get_target_ibis_calculated_table()
        source_table_schema = {k: v for k, v in source_table.schema().items()}
        target_table_schema = {k: v for k, v in target_table.schema().items()}
        casefold_source_columns = {x.casefold(): str(x) for x in source_table.columns}
        casefold_target_columns = {x.casefold(): str(x) for x in target_table.columns}

        comp_fields_with_aliases = []
        calculated_configs = []
        for field in comparison_fields:
            if field.casefold() not in casefold_source_columns:
                raise ValueError(f"Column DNE in source: {field}")
            if field.casefold() not in casefold_target_columns:
                raise ValueError(f"Column DNE in target: {field}")

            source_ibis_type = source_table[
                casefold_source_columns[field.casefold()]
            ].type()
            target_ibis_type = target_table[
                casefold_target_columns[field.casefold()]
            ].type()
            if (
                source_ibis_type.is_string() or target_ibis_type.is_string()
            ) and not self._comp_field_cast(
                # Do not add rstrip if the column is a bool or UUID hiding in a string.
                source_table_schema,
                target_table_schema,
                field,
            ):
                logging.info(
                    f"Adding rtrim() to string comparison field `{field.casefold()}` due to Teradata CHAR padding."
                )
                alias = f"rstrip__{field.casefold()}"
                calculated_configs.append(
                    self.build_config_calculated_fields(
                        [casefold_source_columns[field.casefold()]],
                        [casefold_target_columns[field.casefold()]],
                        consts.CALC_FIELD_RSTRIP,
                        alias,
                        0,
                    )
                )
                comp_fields_with_aliases.append(alias)
            else:
                comp_fields_with_aliases.append(field)

        self.append_calculated_fields(calculated_configs)
        return comp_fields_with_aliases

    def _comp_field_cast(
        self, source_table_schema: dict, target_table_schema: dict, field: str
    ) -> str:
        # We check below if the field exists because sometimes it is a computed name
        # like "concat__all" which is not in the real table.
        source_type = (
            source_table_schema[field] if field in source_table_schema else None
        )
        target_type = (
            target_table_schema[field] if field in target_table_schema else None
        )
        if self._is_bool(source_type, target_type):
            return "bool"
        elif self._is_uuid(source_type, target_type):
            return consts.CONFIG_CAST_UUID_STRING
        return None

    def _is_bool(
        self, source_type: Union[str, dt.DataType], target_type: Union[str, dt.DataType]
    ) -> bool:
        """Returns whether column is BOOLEAN based on either source or target data type.

        We do this because some engines don't have a BOOLEAN type, therefore BOOLEAN
        on one side means both sides need to be BOOLEAN aware."""
        if isinstance(source_type, str):
            return any(_ in ["bool", "!bool"] for _ in [source_type, target_type])
        else:
            return bool(
                isinstance(source_type, dt.Boolean)
                or isinstance(target_type, dt.Boolean)
            )

    def _is_uuid(
        self, source_type: Union[str, dt.DataType], target_type: Union[str, dt.DataType]
    ) -> bool:
        """Returns whether column is UUID based on either source or target data type.

        We do this because some engines don't have a UUID type, therefore UUID on one
        side means both sides are UUID. i.e. we use any() not all()."""
        if isinstance(source_type, str):
            return any(_ in ["uuid", "!uuid"] for _ in [source_type, target_type])
        else:
            return bool(
                isinstance(source_type, dt.UUID) or isinstance(target_type, dt.UUID)
            )

    def build_config_comparison_fields(self, fields, depth=None):
        """Return list of field config objects."""
        field_configs = []
        source_table = self.get_source_ibis_calculated_table()
        target_table = self.get_target_ibis_calculated_table()
        source_table_schema = {k: v for k, v in source_table.schema().items()}
        target_table_schema = {k: v for k, v in target_table.schema().items()}
        casefold_source_columns = {x.casefold(): str(x) for x in source_table.columns}
        casefold_target_columns = {x.casefold(): str(x) for x in target_table.columns}

        for field in fields:
            cast_type = self._comp_field_cast(
                source_table_schema, target_table_schema, field
            )
            column_config = {
                consts.CONFIG_SOURCE_COLUMN: casefold_source_columns.get(
                    field.casefold(), field
                ),
                consts.CONFIG_TARGET_COLUMN: casefold_target_columns.get(
                    field.casefold(), field
                ),
                consts.CONFIG_FIELD_ALIAS: field,
                consts.CONFIG_CAST: cast_type,
            }
            field_configs.append(column_config)
        return field_configs

    def build_column_configs(self, columns):
        """Return list of column config objects."""
        column_configs = []
        source_table = self.get_source_ibis_calculated_table()
        target_table = self.get_target_ibis_calculated_table()
        casefold_source_columns = {x.casefold(): str(x) for x in source_table.columns}
        casefold_target_columns = {x.casefold(): str(x) for x in target_table.columns}

        for column in columns:
            if column.casefold() not in casefold_source_columns:
                raise ValueError(f"Column DNE in source: {column}")
            if column.casefold() not in casefold_target_columns:
                raise ValueError(f"Column DNE in target: {column}")

            source_ibis_type = source_table[
                casefold_source_columns[column.casefold()]
            ].type()
            target_ibis_type = target_table[
                casefold_target_columns[column.casefold()]
            ].type()
            cast_type = self._key_column_needs_casting_to_string(
                source_ibis_type, target_ibis_type
            )

            column_config = {
                consts.CONFIG_SOURCE_COLUMN: casefold_source_columns[column.casefold()],
                consts.CONFIG_TARGET_COLUMN: casefold_target_columns[column.casefold()],
                consts.CONFIG_FIELD_ALIAS: column,
                consts.CONFIG_CAST: cast_type,
            }
            column_configs.append(column_config)
        return column_configs

    def build_config_count_aggregate(self):
        """Return dict aggregate for COUNT(*)."""
        aggregate_config = {
            consts.CONFIG_SOURCE_COLUMN: None,
            consts.CONFIG_TARGET_COLUMN: None,
            consts.CONFIG_FIELD_ALIAS: "count",
            consts.CONFIG_TYPE: "count",
        }
        return aggregate_config

    def _prefix_calc_col_name(
        self, column_name: str, prefix: str, column_number: int
    ) -> str:
        """Prefix a column name but protect final string from overflowing SQL engine identifier length limit."""
        new_name = f"{prefix}__{column_name}"
        if len(new_name) > self._get_comparison_max_col_length():
            # Use an abstract name for the calculated column to avoid composing invalid SQL.
            new_name = f"{prefix}__dvt_calc_col_{column_number}"
        return new_name
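
    # Illustrative sketch (hypothetical names): _prefix_calc_col_name("last_name", "length", 3)
    # returns "length__last_name", but if that exceeds the smaller of the two engines'
    # identifier limits it falls back to the positional name "length__dvt_calc_col_3".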

    def build_and_append_pre_agg_calc_config(
        self,
        source_column,
        target_column,
        calc_func,
        column_position,
        cast_type: str = None,
        depth: int = 0,
    ):
        """Create calculated field config used as a pre-aggregation step.

        Appends to calculated fields if it does not already exist and returns the created config."""
        calculated_config = {
            consts.CONFIG_CALCULATED_SOURCE_COLUMNS: [source_column],
            consts.CONFIG_CALCULATED_TARGET_COLUMNS: [target_column],
            consts.CONFIG_FIELD_ALIAS: self._prefix_calc_col_name(
                source_column, calc_func, column_position
            ),
            consts.CONFIG_TYPE: calc_func,
            consts.CONFIG_DEPTH: depth,
        }
        if calc_func == consts.CONFIG_CAST and cast_type is not None:
            calculated_config[consts.CONFIG_DEFAULT_CAST] = cast_type
            calculated_config[consts.CONFIG_FIELD_ALIAS] = self._prefix_calc_col_name(
                source_column, f"{calc_func}_{cast_type}", column_position
            )

        existing_calc_fields = [
            config[consts.CONFIG_FIELD_ALIAS] for config in self.calculated_fields
        ]
        if calculated_config[consts.CONFIG_FIELD_ALIAS] not in existing_calc_fields:
            self.append_calculated_fields([calculated_config])
        return calculated_config

    def append_pre_agg_calc_field(
        self,
        source_column: str,
        target_column: str,
        agg_type: str,
        column_type: str,
        target_column_type: str,
        column_position: int,
    ) -> dict:
        """Append calculated field for length() or epoch_seconds(timestamp) for preprocessing before column validation aggregation."""
        depth = 0
        cast_type = None
        final_cast_type = None
        if any(_ in ["json", "!json"] for _ in [column_type, target_column_type]):
            # JSON data which needs casting to string before we apply a length function.
            pre_calculated_config = self.build_and_append_pre_agg_calc_config(
                source_column,
                target_column,
                consts.CONFIG_CAST,
                column_position,
                cast_type="string",
                depth=depth,
            )
            source_column = target_column = pre_calculated_config[
                consts.CONFIG_FIELD_ALIAS
            ]
            depth = 1
            calc_func = consts.CALC_FIELD_LENGTH
        elif column_type in ["string", "!string"]:
            calc_func = consts.CALC_FIELD_LENGTH
        elif self._is_uuid(column_type, target_column_type):
            calc_func = consts.CONFIG_CAST
            cast_type = consts.CONFIG_CAST_UUID_STRING
        elif column_type in ["binary", "!binary"]:
            calc_func = consts.CALC_FIELD_BYTE_LENGTH
        elif column_type in ["timestamp", "!timestamp", "date", "!date"]:
            if (
                self.source_client.name == "bigquery"
                or self.target_client.name == "bigquery"
            ):
                pre_calculated_config = self.build_and_append_pre_agg_calc_config(
                    source_column,
                    target_column,
                    consts.CONFIG_CAST,
                    column_position,
                    cast_type="timestamp",
                    depth=depth,
                )
                source_column = target_column = pre_calculated_config[
                    consts.CONFIG_FIELD_ALIAS
                ]
                depth = 1
            calc_func = consts.CALC_FIELD_EPOCH_SECONDS
            if agg_type == consts.CONFIG_TYPE_SUM:
                # It is possible to exceed int64 when summing epoch_seconds therefore cast to string.
                # See issue 1391 for details.
                final_cast_type = "string"
        elif column_type == "int32" or column_type == "!int32":
            calc_func = consts.CONFIG_CAST
            cast_type = "int64"
        else:
            raise ValueError(f"Unsupported column type: {column_type}")

        calculated_config = self.build_and_append_pre_agg_calc_config(
            source_column,
            target_column,
            calc_func,
            column_position,
            cast_type=cast_type,
            depth=depth,
        )

        aggregate_config = {
            consts.CONFIG_SOURCE_COLUMN: f"{calculated_config[consts.CONFIG_FIELD_ALIAS]}",
            consts.CONFIG_TARGET_COLUMN: f"{calculated_config[consts.CONFIG_FIELD_ALIAS]}",
            consts.CONFIG_FIELD_ALIAS: self._prefix_calc_col_name(
                calculated_config[consts.CONFIG_FIELD_ALIAS],
                f"{agg_type}",
                column_position,
            ),
            consts.CONFIG_TYPE: agg_type,
        }
        if final_cast_type:
            # Adding to dict this way to avoid adding a lot of empty cast attributes.
            aggregate_config[consts.CONFIG_CAST] = final_cast_type

        return aggregate_config

    def _decimal_column_too_big_for_pandas(
        self,
        source_column_ibis_type: dt.DataType,
        target_column_ibis_type: dt.DataType,
        margin: int = 0,
    ) -> bool:
        """Identifies numeric columns that will cause problems in a Pandas Dataframe.

        i.e. are of greater precision than a 64bit int/real can hold.

        margin: Allows us to lower the precision threshold. This is helpful when summing
                column values that are okay by themselves but cumulatively could overflow
                a 64bit value.
        """
        return bool(
            (
                (isinstance(source_column_ibis_type, dt.Int64) and margin > 0)
                or (
                    isinstance(source_column_ibis_type, dt.Decimal)
                    and (
                        source_column_ibis_type.precision is None
                        or source_column_ibis_type.precision > (18 - margin)
                    )
                )
            )
            and (
                (isinstance(target_column_ibis_type, dt.Int64) and margin > 0)
                or (
                    isinstance(target_column_ibis_type, dt.Decimal)
                    and (
                        target_column_ibis_type.precision is None
                        or target_column_ibis_type.precision > (18 - margin)
                    )
                )
            )
        )

    def _key_column_needs_casting_to_string(
        self,
        source_column_ibis_type: dt.DataType,
        target_column_ibis_type: dt.DataType,
    ) -> str:
        """Return a string cast if the datatype combination requires it, otherwise None."""
        if self._is_uuid(source_column_ibis_type, target_column_ibis_type):
            # This needs to come before binary check because Oracle
            # stores UUIDs (GUID) in binary columns.
            return consts.CONFIG_CAST_UUID_STRING
        elif (
            self._decimal_column_too_big_for_pandas(
                source_column_ibis_type, target_column_ibis_type
            )
            or isinstance(source_column_ibis_type, dt.Binary)
            or isinstance(target_column_ibis_type, dt.Binary)
        ):
            return "string"
        else:
            return None

    def _type_is_supported_for_agg_validation(
        self, source_type: str, target_type: str, supported_types: list
    ) -> bool:
        if self._is_uuid(source_type, target_type):
            return bool("uuid" in supported_types)
        return bool(source_type in supported_types and target_type in supported_types)

    def build_config_column_aggregates(
        self, agg_type, arg_value, exclude_cols, supported_types, cast_to_bigint=False
    ):
        """Return list of aggregate objects of given agg_type."""

        def require_pre_agg_calc_field(
            column_type: str,
            target_column_type: str,
            agg_type: str,
            cast_to_bigint: bool,
        ) -> bool:
            if all(
                _ in ["string", "!string", "json", "!json"]
                for _ in [column_type, target_column_type]
            ):
                # These data types are aggregated using their lengths.
                return True
            elif self._is_uuid(column_type, target_column_type):
                return True
            elif column_type in ["binary", "!binary"]:
                if agg_type == "count":
                    # Oracle BLOB is invalid for use with SQL COUNT function.
                    # The expression below returns True if client is Oracle which
                    # has the effect of triggering use of byte_length transformation.
                    return bool(
                        self.source_client.name == "oracle"
                        or self.target_client.name == "oracle"
                    )
                else:
                    # Convert to length for any min/max/sum on binary columns.
                    return True
            elif cast_to_bigint and column_type in ["int32", "!int32"]:
                return True
            elif column_type in [
                "timestamp",
                "!timestamp",
                "date",
                "!date",
            ] and agg_type in (
                "sum",
                "avg",
                "bit_xor",
            ):
                # For timestamps: do not convert to epoch seconds for min/max
                return True
            return False

        aggregate_configs = []
        source_table = self.get_source_ibis_calculated_table()
        target_table = self.get_target_ibis_calculated_table()
        casefold_source_columns = {x.casefold(): str(x) for x in source_table.columns}
        casefold_target_columns = {x.casefold(): str(x) for x in target_table.columns}

        if arg_value:
            arg_value = [x.casefold() for x in arg_value]
            if exclude_cols:
                included_cols = [
                    column
                    for column in casefold_source_columns
                    if column not in arg_value
                ]
                arg_value = included_cols
            if supported_types:
                # This mutates external supported_types, making it local as part of adding more values.
                supported_types = supported_types + [
                    "string",
                    "!string",
                    "timestamp",
                    "!timestamp",
                    "date",
                    "!date",
                    "binary",
                    "!binary",
                ]
        else:
            if exclude_cols:
                raise ValueError(
                    "Exclude columns flag cannot be present with '*' column aggregation"
                )

        allowlist_columns = arg_value or casefold_source_columns
        for column_position, column in enumerate(casefold_source_columns):
            # Get column type and remove precision/scale attributes
            source_column_ibis_type = source_table[
                casefold_source_columns[column]
            ].type()
            column_type = str(source_column_ibis_type).split("(")[0]
            target_column_ibis_type = target_table[
                casefold_target_columns[column]
            ].type()
            target_column_type = str(target_column_ibis_type).split("(")[0]

            if column not in allowlist_columns:
                continue
            elif column not in casefold_target_columns:
                logging.warning(
                    f"Skipping {agg_type} on {column} as column is not present in target table"
                )
                continue
            elif supported_types and not self._type_is_supported_for_agg_validation(
                column_type, target_column_type, supported_types
            ):
                if self.verbose:
                    logging.info(
                        f"Skipping {agg_type} on {column} due to data type: {column_type}"
                    )
                continue

            if require_pre_agg_calc_field(
                column_type, target_column_type, agg_type, cast_to_bigint
            ):
                aggregate_config = self.append_pre_agg_calc_field(
                    casefold_source_columns[column],
                    casefold_target_columns[column],
                    agg_type,
                    column_type,
                    target_column_type,
                    column_position,
                )
            else:
                aggregate_config = {
                    consts.CONFIG_SOURCE_COLUMN: casefold_source_columns[column],
                    consts.CONFIG_TARGET_COLUMN: casefold_target_columns[column],
                    consts.CONFIG_FIELD_ALIAS: self._prefix_calc_col_name(
                        column, f"{agg_type}", column_position
                    ),
                    consts.CONFIG_TYPE: agg_type,
                }
                if self._decimal_column_too_big_for_pandas(
                    source_column_ibis_type,
                    target_column_ibis_type,
                    margin=(2 if agg_type == consts.CONFIG_TYPE_SUM else 0),
                ):
                    aggregate_config[consts.CONFIG_CAST] = "string"

            aggregate_configs.append(aggregate_config)

        return aggregate_configs

    def build_config_calculated_fields(
        self,
        source_reference: list,
        target_reference: list,
        calc_type: str,
        alias: str,
        depth: int,
        custom_params: Optional[dict] = None,
    ) -> dict:
        """Returns a calculated field config dict."""
        calculated_config = {
            consts.CONFIG_CALCULATED_SOURCE_COLUMNS: source_reference,
            consts.CONFIG_CALCULATED_TARGET_COLUMNS: target_reference,
            consts.CONFIG_FIELD_ALIAS: alias,
            consts.CONFIG_TYPE: calc_type,
            consts.CONFIG_DEPTH: depth,
        }
        if calc_type == consts.CONFIG_CUSTOM and custom_params:
            calculated_config.update(custom_params)
        elif calc_type == consts.CONFIG_CAST and custom_params:
            calculated_config[consts.CONFIG_DEFAULT_CAST] = custom_params

        return calculated_config
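
    # Illustrative sketch (hypothetical column name): build_config_calculated_fields(
    #     ["col_a"], ["col_a"], consts.CALC_FIELD_RSTRIP, "rstrip__col_a", 0
    # ) returns
    #   {
    #       consts.CONFIG_CALCULATED_SOURCE_COLUMNS: ["col_a"],
    #       consts.CONFIG_CALCULATED_TARGET_COLUMNS: ["col_a"],
    #       consts.CONFIG_FIELD_ALIAS: "rstrip__col_a",
    #       consts.CONFIG_TYPE: consts.CALC_FIELD_RSTRIP,
    #       consts.CONFIG_DEPTH: 0,
    #   }
    # which matches the call pattern used by add_rstrip_to_comp_fields().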

    def _get_comparison_max_col_length(self) -> int:
        if not self._comparison_max_col_length:
            self._comparison_max_col_length = min(
                [
                    clients.get_max_column_length(self.source_client),
                    clients.get_max_column_length(self.target_client),
                ]
            )
        return self._comparison_max_col_length

    def _strftime_format(
        self, column_type: Union[dt.Date, dt.Timestamp], client
    ) -> str:
        if column_type.is_timestamp():
            return "%Y-%m-%d %H:%M:%S"
        if clients.is_oracle_client(client):
            # Oracle DATE is a DateTime
            return "%Y-%m-%d %H:%M:%S"
        return "%Y-%m-%d"

    def _apply_base_cast_overrides(
        self,
        source_column: str,
        target_column: str,
        col_config: dict,
        source_table: "ibis.expr.types.Table",
        target_table: "ibis.expr.types.Table",
    ) -> dict:
        """Mutates col_config to contain any overrides.

        Also returns col_config for convenience."""
        if col_config["calc_type"] != consts.CALC_FIELD_CAST:
            return col_config

        source_table_schema = {k: v for k, v in source_table.schema().items()}
        target_table_schema = {k: v for k, v in target_table.schema().items()}
        if isinstance(
            source_table_schema[source_column], (dt.Date, dt.Timestamp)
        ) and isinstance(target_table_schema[target_column], (dt.Date, dt.Timestamp)):
            # Use strftime rather than cast for temporal comparisons.
            # Pick the most permissive format across the two engines.
            # For example Date -> Timestamp should format both source and target as Date.
            fmt = min(
                [
                    self._strftime_format(
                        source_table_schema[source_column], self.source_client
                    ),
                    self._strftime_format(
                        source_table_schema[source_column], self.target_client
                    ),
                ],
                key=len,
            )
            col_config["calc_type"] = consts.CONFIG_CUSTOM
            custom_params = {
                "calc_params": {
                    consts.CONFIG_CUSTOM_IBIS_EXPR: "ibis.expr.types.TemporalValue.strftime",
                    consts.CONFIG_CUSTOM_PARAMS: [
                        {consts.CONFIG_CUSTOM_PARAM_FORMAT_STR: fmt}
                    ],
                }
            }
            col_config.update(custom_params)
        elif isinstance(source_table_schema[source_column], dt.Boolean) or isinstance(
            target_table_schema[target_column], dt.Boolean
        ):
            custom_params = {"calc_params": consts.CONFIG_CAST_BOOL_STRING}
            col_config.update(custom_params)
        elif self._is_uuid(
            source_table_schema[source_column], target_table_schema[target_column]
        ):
            custom_params = {"calc_params": consts.CONFIG_CAST_UUID_STRING}
            col_config.update(custom_params)
        return col_config

    def _get_order_of_operations(self, calc_type: str) -> List[str]:
        """Return order of operations for row validation."""
        order_of_operations = [
            consts.CALC_FIELD_CAST,
            consts.CALC_FIELD_IFNULL,
            consts.CALC_FIELD_RSTRIP,
        ]
        if self.case_insensitive_match():
            order_of_operations.append(consts.CALC_FIELD_UPPER)

        if calc_type == consts.CALC_FIELD_HASH:
            order_of_operations.extend(
                [consts.CALC_FIELD_CONCAT, consts.CALC_FIELD_HASH]
            )
        elif calc_type == consts.CALC_FIELD_CONCAT:
            order_of_operations.append(consts.CALC_FIELD_CONCAT)
        return order_of_operations

    def _filter_columns_by_column_list(
        self, casefold_columns: list, col_list: list, exclude_cols: bool
    ) -> list:
        if col_list:
            filter_list = [_.casefold() for _ in col_list]
            if exclude_cols:
                # Exclude columns based on col_list if provided
                casefold_columns = {
                    k: v for (k, v) in casefold_columns.items() if k not in filter_list
                }
            else:
                # Include columns based on col_list if provided
                casefold_columns = {
                    k: v for (k, v) in casefold_columns.items() if k in filter_list
                }
        elif exclude_cols:
            raise ValueError(
                "Exclude columns flag cannot be present with column list '*'"
            )
        return casefold_columns

    def build_dependent_aliases(
        self, calc_type: str, col_list=None, exclude_cols=False
    ) -> List[Dict]:
        """This is a utility function for determining the required depth of all fields."""
        source_table = self.get_source_ibis_calculated_table()
        target_table = self.get_target_ibis_calculated_table()
        casefold_source_columns = {x.casefold(): str(x) for x in source_table.columns}
        casefold_target_columns = {x.casefold(): str(x) for x in target_table.columns}
        casefold_source_columns = self._filter_columns_by_column_list(
            casefold_source_columns, col_list, exclude_cols
        )
        casefold_target_columns = self._filter_columns_by_column_list(
            casefold_target_columns, col_list, exclude_cols
        )

        column_aliases = {}
        col_names = []
        for i, calc in enumerate(self._get_order_of_operations(calc_type)):
            if i == 0:
                previous_level = [x for x in casefold_source_columns.keys()]
            else:
                previous_level = [k for k, v in column_aliases.items() if v == i - 1]

            if calc in [consts.CALC_FIELD_CONCAT, consts.CALC_FIELD_HASH]:
                col = {}
                col["source_reference"] = previous_level
                col["target_reference"] = previous_level
                col["name"] = f"{calc}__all"
                col["calc_type"] = calc
                col["depth"] = i

                name = col["name"]
                # need to capture all aliases at the previous level. probably name concat__all
                column_aliases[name] = i
                col_names.append(col)
            else:
                # This needs to be the previous manifest of columns
                for j, column in enumerate(previous_level):
                    col = {}
                    col["source_reference"] = [column]
                    col["target_reference"] = [column]
                    col["name"] = self._prefix_calc_col_name(column, calc, j)
                    col["calc_type"] = calc
                    col["depth"] = i

                    if i == 0:
                        # If depth 0, get raw column name with correct casing
                        source_column = casefold_source_columns[column]
                        target_column = casefold_target_columns[column]
                        col["source_reference"] = [source_column]
                        col["target_reference"] = [target_column]
                        # If we are casting the base column (i == 0) then apply any
                        # datatype specific overrides.
                        col = self._apply_base_cast_overrides(
                            source_column,
                            target_column,
                            col,
                            source_table,
                            target_table,
                        )
                    name = col["name"]
                    column_aliases[name] = i
                    col_names.append(col)
        return col_names

    def build_comp_fields(self, col_list: list, exclude_cols: bool = False) -> dict:
        """This is a utility function processing comp-fields values like we do for hash/concat."""
        source_table = self.get_source_ibis_calculated_table()
        casefold_source_columns = {_.casefold(): str(_) for _ in source_table.columns}
        casefold_source_columns = self._filter_columns_by_column_list(
            casefold_source_columns, col_list, exclude_cols
        )
        return casefold_source_columns

    def auto_list_primary_keys(self) -> list:
        """Returns a list of primary key columns based on the source/target table.

        If neither source nor target systems have a primary key defined then [] is returned.
        """
        assert (
            self.validation_type != consts.CUSTOM_QUERY
        ), "Custom query validations should not be able to reach this method"
        primary_keys = self.source_client.list_primary_key_columns(
            self.source_schema, self.source_table
        )
        if not primary_keys:
            primary_keys = self.target_client.list_primary_key_columns(
                self.target_schema, self.target_table
            )
        return primary_keys or []
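

# A minimal usage sketch, assuming connections named below were previously saved via the
# state manager; connection, schema and table names are hypothetical. It builds a
# COUNT(*) column validation config and prints the YAML-ready block.
if __name__ == "__main__":
    example_config_manager = ConfigManager.build_config_manager(
        "Column",  # validation type, per the validation_type docstring "(Column|Schema)"
        "my_source_conn",  # hypothetical saved source connection name
        "my_target_conn",  # hypothetical saved target connection name
        {
            consts.CONFIG_SCHEMA_NAME: "my_dataset",
            consts.CONFIG_TABLE_NAME: "my_table",
        },
        [],  # labels
        0.0,  # threshold
        consts.FORMAT_TYPE_TABLE,
    )
    # Add a COUNT(*) aggregate and show the config as it would be written to YAML.
    example_config_manager.append_aggregates(
        [example_config_manager.build_config_count_aggregate()]
    )
    print(example_config_manager.get_yaml_validation_block())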