# data_validation/validation_builder.py (334 lines of code) (raw):
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import logging
from typing import TYPE_CHECKING
from data_validation import consts, metadata
from data_validation.clients import get_max_in_list_size
from data_validation.query_builder.query_builder import (
AggregateField,
CalculatedField,
ComparisonField,
FilterField,
GroupedField,
QueryBuilder,
)
if TYPE_CHECKING:
import ibis
def list_to_sublists(id_list: list, max_size: int) -> list:
    """Split id_list into consecutive sublists of at most max_size items.

    Order is preserved; only the final sublist may be shorter than max_size.
    An empty id_list yields an empty list.

    Args:
        id_list (list): The items to split.
        max_size (int): Maximum length of each sublist, must be >= 1.

    Returns:
        list: A list of slices of id_list.

    Raises:
        ValueError: If max_size is smaller than 1.
    """
    # A negative step would make range() empty and silently drop every item,
    # and a zero step fails with an unrelated-looking range() error, so reject
    # both up front with a clear message.
    if max_size < 1:
        raise ValueError(
            "max_size must be a positive integer, got {}".format(max_size)
        )
    return [
        id_list[start : start + max_size]
        for start in range(0, len(id_list), max_size)
    ]
class ValidationBuilder(object):
    """Build matching source and target queries for one validation.

    The builder reads its settings from a ConfigManager and mirrors every
    configured aggregate, grouped column, calculated field, comparison field,
    filter, primary key and row limit onto a source query builder and a
    target query builder. ValidationMetadata is recorded per alias for
    aggregates and comparison fields.
    """

    def __init__(self, config_manager):
        """Initialize the builder and apply the whole validation config.

        Args:
            config_manager (ConfigManager): A validation ConfigManager
                instance. The source and target clients, the validation type
                and the verbose flag are all read from it.

        Raises:
            ValueError: If the configured validation type is not supported.
        """
        self._metadata = {}

        self.config_manager = config_manager
        self.verbose = self.config_manager.verbose
        self.validation_type = self.config_manager.validation_type

        self.source_client = self.config_manager.source_client
        self.target_client = self.config_manager.target_client

        self.source_builder = self.get_query_builder(self.validation_type)
        self.target_builder = self.get_query_builder(self.validation_type)

        # Registries keyed by alias; values are the raw config dicts.
        self.primary_keys = {}
        self.group_aliases = {}
        self.calculated_aliases = {}
        self.comparison_fields = {}

        self.add_config_aggregates()
        self.add_config_query_groups()
        self.add_config_calculated_fields()
        self.add_comparison_fields()
        self.add_config_filters()
        self.add_primary_keys()
        self.add_query_limit()

    def _construct_isin_filter(
        self, client, column_name: str, in_list: list
    ) -> "FilterField":
        """Return a FilterField that is isin(...) or (isin(...) OR isin(...) OR ...).

        The list is split into several OR-ed isin filters when the client
        reports a maximum IN-list size and in_list exceeds it.

        Args:
            client: The client the filter will run against.
            column_name (str): Column the filter applies to.
            in_list (list): Values (or expressions) to match.

        Returns:
            FilterField: A single isin filter or an OR of isin filters.
        """
        # The limit lookup is told whether the first item exposes a `cast`
        # attribute; presumably that marks an expression rather than a
        # literal value - TODO confirm against get_max_in_list_size.
        max_in_list_size = get_max_in_list_size(
            client,
            in_list_over_expressions=bool(
                in_list and getattr(in_list[0], "cast", None)
            ),
        )
        if max_in_list_size and len(in_list) > max_in_list_size:
            chunked_filters = [
                FilterField.isin(column_name, chunk)
                for chunk in list_to_sublists(in_list, max_in_list_size)
            ]
            return FilterField.or_(chunked_filters)
        return FilterField.isin(column_name, in_list)

    def _is_padded_char(
        self,
        client: "ibis.backends.base.BaseBackend",
        raw_column_metadata: dict,
        column_name: str,
    ) -> bool:
        """Return whether the client reports column_name as a padded char type.

        False is returned when the client has no is_char_type_padded method,
        when no raw metadata is available, or when the column has no (truthy)
        entry in the raw metadata.
        """
        # Clients only need to implement the is_char_type_padded method if
        # they support padded strings.
        if not hasattr(client, "is_char_type_padded") or not raw_column_metadata:
            return False
        column_metadata = raw_column_metadata.get(column_name)
        if not column_metadata:
            return False
        return client.is_char_type_padded(column_metadata)

    def _build_validation_metadata(
        self, aggregation_type, source_column_name, target_column_name
    ):
        """Return the ValidationMetadata describing one validated column pair."""
        config = self.config_manager
        return metadata.ValidationMetadata(
            validation_type=self.validation_type,
            aggregation_type=aggregation_type,
            source_table_schema=config.source_schema,
            source_table_name=config.source_table,
            target_table_schema=config.target_schema,
            target_table_name=config.target_table,
            source_column_name=source_column_name,
            target_column_name=target_column_name,
            primary_keys=config.get_primary_keys_list(),
            num_random_rows=config.get_random_row_batch_size(),
            threshold=config.threshold,
        )

    def clone(self):
        """Return a new ValidationBuilder carrying deep copies of this one's state.

        A fresh builder is created from the same config_manager and its query
        builders, alias registries and metadata are then replaced by deep
        copies, so changes to the clone do not affect this instance.
        """
        cloned_builder = ValidationBuilder(self.config_manager)

        cloned_builder.source_builder = deepcopy(self.source_builder)
        cloned_builder.target_builder = deepcopy(self.target_builder)
        # Keep the primary key registry in step with the copied query
        # builders, like the other registries below.
        cloned_builder.primary_keys = deepcopy(self.primary_keys)
        cloned_builder.group_aliases = deepcopy(self.group_aliases)
        cloned_builder.calculated_aliases = deepcopy(self.calculated_aliases)
        cloned_builder.comparison_fields = deepcopy(self.comparison_fields)
        cloned_builder._metadata = deepcopy(self._metadata)

        return cloned_builder

    @staticmethod
    def get_query_builder(validation_type):
        """Return Query Builder object given validation type.

        Every supported validation type currently maps to
        QueryBuilder.build_count_validator().

        Raises:
            ValueError: If validation_type is not a supported type.
        """
        supported_types = [
            "Column",
            "GroupedColumn",
            "Row",
            "Schema",
            "Custom-query",
        ]
        if validation_type not in supported_types:
            msg = "Validation Builder supplied unknown type: %s" % validation_type
            raise ValueError(msg)
        return QueryBuilder.build_count_validator()

    def get_metadata(self):
        """Metadata about the run and any validations it contains.

        The validation metadata is populated as validations are added to this
        builder.

        Returns:
            Dict[str, ValidationMetadata]:
                A dictionary of validation name to ValidationMetadata.
        """
        return self._metadata

    def get_group_aliases(self):
        """Return the list of grouped field aliases."""
        return list(self.group_aliases)

    def get_primary_keys(self):
        """Return the list of primary key aliases."""
        # do we need this?
        return list(self.primary_keys)

    def get_calculated_aliases(self):
        """Return the list of calculated field aliases."""
        # Return a list (not a live dict view) to match the other getters.
        return list(self.calculated_aliases)

    def get_comparison_fields(self):
        """Return the list of comparison field aliases.

        NOTE(review): nothing in this class adds entries to
        self.comparison_fields, so this only reflects what other code stores
        there - confirm whether add_comparison_field should register them.
        """
        return list(self.comparison_fields)

    def get_grouped_alias_source_column(self, alias):
        """Return the source column configured for the grouped field alias."""
        return self.group_aliases[alias][consts.CONFIG_SOURCE_COLUMN]

    def get_grouped_alias_target_column(self, alias):
        """Return the target column configured for the grouped field alias."""
        return self.group_aliases[alias][consts.CONFIG_TARGET_COLUMN]

    def add_config_aggregates(self):
        """Add Aggregations to Query"""
        # The config value could be None, hence the `or []`.
        for aggregate_field in self.config_manager.aggregates or []:
            self.add_aggregate(aggregate_field)

    def add_config_calculated_fields(self):
        """Add calculated fields to Query"""
        # self.config_manager.calculated_fields could be None, hence the or [] below
        for calc_field in self.config_manager.calculated_fields or []:
            self.add_calc(calc_field)

    def add_primary_keys(self, primary_keys=None):
        """Add primary keys to the queries.

        Args:
            primary_keys (list): Primary key configs to add. When empty or
                None, the primary keys from the config manager are used.
        """
        primary_keys = primary_keys or self.config_manager.primary_keys or []
        for field in primary_keys:
            self.add_primary_key(field)

    def add_comparison_fields(self, comparison_fields=None):
        """Add comparison fields to the queries.

        Args:
            comparison_fields (list): Comparison field configs to add. When
                empty or None, the ones from the config manager are used.
        """
        comparison_fields = (
            comparison_fields or self.config_manager.comparison_fields or []
        )
        for field in comparison_fields:
            self.add_comparison_field(field)

    def add_config_query_groups(self, query_groups=None):
        """Add Grouped Columns to Query

        Args:
            query_groups (list): Grouped field configs to add. When empty or
                None, the query groups from the config manager are used.
        """
        grouped_fields = query_groups or self.config_manager.query_groups or []
        for grouped_field in grouped_fields:
            self.add_query_group(grouped_field)

    def add_config_filters(self):
        """Add Filters to Query"""
        # The config value could be None, hence the `or []`.
        for filter_field in self.config_manager.filters or []:
            self.add_filter(filter_field)

    def add_aggregate(self, aggregate_field):
        """Add Aggregate Field to Queries

        Args:
            aggregate_field (dict): A dictionary with source, target, and aggregation info.

        Raises:
            Exception: If the aggregation type is missing or is not an
                attribute of AggregateField.
        """
        alias = aggregate_field[consts.CONFIG_FIELD_ALIAS]
        source_field_name = aggregate_field[consts.CONFIG_SOURCE_COLUMN]
        target_field_name = aggregate_field[consts.CONFIG_TARGET_COLUMN]
        aggregate_type = aggregate_field.get(consts.CONFIG_TYPE)
        cast = aggregate_field.get(consts.CONFIG_CAST)

        # hasattr() raises TypeError for a non-string name (e.g. a missing
        # type read as None), so check the type first to report the real
        # problem.
        if not isinstance(aggregate_type, str) or not hasattr(
            AggregateField, aggregate_type
        ):
            raise Exception("Unknown Aggregation Type: {}".format(aggregate_type))

        aggregate_factory = getattr(AggregateField, aggregate_type)
        source_agg = aggregate_factory(
            field_name=source_field_name, alias=alias, cast=cast
        )
        target_agg = aggregate_factory(
            field_name=target_field_name, alias=alias, cast=cast
        )

        self.source_builder.add_aggregate_field(source_agg)
        self.target_builder.add_aggregate_field(target_agg)
        self._metadata[alias] = self._build_validation_metadata(
            aggregate_type, source_field_name, target_field_name
        )

    def pop_grouped_fields(self):
        """Return grouped fields and reset configs.

        The grouped fields of both query builders and the alias registry are
        cleared; the returned value is the config manager's query groups.
        """
        self.source_builder.grouped_fields = []
        self.target_builder.grouped_fields = []
        self.group_aliases = {}
        return self.config_manager.query_groups

    def add_query_group(self, grouped_field):
        """Add Grouped Field to Query

        Args:
            grouped_field (Dict): An object with source, target, and cast info
        """
        alias = grouped_field[consts.CONFIG_FIELD_ALIAS]
        source_field_name = grouped_field[consts.CONFIG_SOURCE_COLUMN]
        target_field_name = grouped_field[consts.CONFIG_TARGET_COLUMN]
        cast = grouped_field.get(consts.CONFIG_CAST)

        source_field = GroupedField(
            field_name=source_field_name, alias=alias, cast=cast
        )
        target_field = GroupedField(
            field_name=target_field_name, alias=alias, cast=cast
        )

        self.source_builder.add_grouped_field(source_field)
        self.target_builder.add_grouped_field(target_field)
        self.group_aliases[alias] = grouped_field

    def add_primary_key(self, primary_key):
        """Add a primary key as a ComparisonField to Queries

        Args:
            primary_key (Dict): An object with source, target, and cast info
        """
        source_field_name = primary_key[consts.CONFIG_SOURCE_COLUMN]
        target_field_name = primary_key[consts.CONFIG_TARGET_COLUMN]
        # Alias and cast are optional here (read with .get); a config without
        # an alias is registered under the key None.
        alias = primary_key.get(consts.CONFIG_FIELD_ALIAS)
        cast = primary_key.get(consts.CONFIG_CAST)
        # Trimming is decided by the config manager and applied to both sides.
        trim = self.config_manager.trim_string_pks()

        source_field = ComparisonField(
            field_name=source_field_name, alias=alias, cast=cast, trim=trim
        )
        target_field = ComparisonField(
            field_name=target_field_name, alias=alias, cast=cast, trim=trim
        )

        self.source_builder.add_comparison_field(source_field)
        self.target_builder.add_comparison_field(target_field)
        self.primary_keys[alias] = primary_key

    def add_filter(self, filter_field):
        """Add FilterField to Queries

        Args:
            filter_field (Dict): An object with source and target filter details

        Raises:
            ValueError: If the filter type is not custom, equals or isin.
        """
        filter_type = filter_field[consts.CONFIG_TYPE]

        if filter_type == consts.FILTER_TYPE_CUSTOM:
            source_filter = FilterField.custom(
                filter_field[consts.CONFIG_FILTER_SOURCE]
            )
            target_filter = FilterField.custom(
                filter_field[consts.CONFIG_FILTER_TARGET]
            )
        elif filter_type == consts.FILTER_TYPE_EQUALS:
            source_filter = FilterField.equal_to(
                filter_field[consts.CONFIG_FILTER_SOURCE_COLUMN],
                filter_field[consts.CONFIG_FILTER_SOURCE_VALUE],
            )
            target_filter = FilterField.equal_to(
                filter_field[consts.CONFIG_FILTER_TARGET_COLUMN],
                filter_field[consts.CONFIG_FILTER_TARGET_VALUE],
            )
        elif filter_type == consts.FILTER_TYPE_ISIN:
            source_filter = self._construct_isin_filter(
                self.source_client,
                filter_field[consts.CONFIG_FILTER_SOURCE_COLUMN],
                filter_field[consts.CONFIG_FILTER_SOURCE_VALUE],
            )
            target_filter = self._construct_isin_filter(
                self.target_client,
                filter_field[consts.CONFIG_FILTER_TARGET_COLUMN],
                filter_field[consts.CONFIG_FILTER_TARGET_VALUE],
            )
        else:
            # Without this branch an unrecognised type would surface later as
            # an UnboundLocalError on source_filter.
            raise ValueError("Unknown filter type: {}".format(filter_type))

        # TODO(issues/40): Add metadata around filters
        self.source_builder.add_filter_field(source_filter)
        self.target_builder.add_filter_field(target_filter)

    def add_comparison_field(self, comparison_field):
        """Add ComparisonField to Queries

        Args:
            comparison_field (Dict): An object with source, target, and cast info
        """
        source_field_name = comparison_field[consts.CONFIG_SOURCE_COLUMN]
        target_field_name = comparison_field[consts.CONFIG_TARGET_COLUMN]
        alias = comparison_field[consts.CONFIG_FIELD_ALIAS]
        cast = comparison_field.get(consts.CONFIG_CAST)

        source_field = ComparisonField(
            field_name=source_field_name, alias=alias, cast=cast
        )
        target_field = ComparisonField(
            field_name=target_field_name, alias=alias, cast=cast
        )

        self.source_builder.add_comparison_field(source_field)
        self.target_builder.add_comparison_field(target_field)
        # Comparison fields carry no aggregation, hence aggregation_type=None.
        self._metadata[alias] = self._build_validation_metadata(
            None, source_field_name, target_field_name
        )

    def _get_calc_type(
        self,
        calc_field: dict,
        column_name: str,
        client: "ibis.backends.base.BaseBackend",
        raw_data_types: dict,
    ) -> str:
        """Return the CalculatedField constructor name to use for calc_field.

        A length calculation over a column the client reports as padded char
        is switched to the padded-char length calculation; every other
        calculation keeps its configured type.

        Raises:
            Exception: If the resulting type is not an attribute of
                CalculatedField.
        """
        requested_type = calc_field[consts.CONFIG_TYPE]
        if requested_type == consts.CALC_FIELD_LENGTH and self._is_padded_char(
            client,
            raw_data_types,
            column_name,
        ):
            calc_type = consts.CALC_FIELD_PADDED_CHAR_LENGTH
        else:
            calc_type = requested_type

        # Check if valid calc field and return correct object. hasattr()
        # raises TypeError for a non-string name, so check the type first.
        if not isinstance(calc_type, str) or not hasattr(CalculatedField, calc_type):
            raise Exception("Unknown Calculation Type: {}".format(calc_type))
        return calc_type

    def add_calc(self, calc_field: dict):
        """Add CalculatedField to Queries

        The calculation type is resolved separately for source and target,
        using the first configured column of each side, so the two sides may
        end up with different CalculatedField constructors.

        Args:
            calc_field (Dict): An object with source, target, and cast info
        """
        # Prepare source and target payloads; each side gets its own copy of
        # the config.
        source_config = deepcopy(calc_field)
        source_fields = calc_field[consts.CONFIG_CALCULATED_SOURCE_COLUMNS]
        target_config = deepcopy(calc_field)
        target_fields = calc_field[consts.CONFIG_CALCULATED_TARGET_COLUMNS]
        # Grab calc field metadata
        alias = calc_field[consts.CONFIG_FIELD_ALIAS]

        source_calc_type = self._get_calc_type(
            calc_field,
            source_fields[0],
            self.source_client,
            self.config_manager.get_source_raw_data_types(),
        )
        target_calc_type = self._get_calc_type(
            calc_field,
            target_fields[0],
            self.target_client,
            self.config_manager.get_target_raw_data_types(),
        )

        source_field = getattr(CalculatedField, source_calc_type)(
            config=source_config, fields=source_fields
        )
        target_field = getattr(CalculatedField, target_calc_type)(
            config=target_config, fields=target_fields
        )

        self.source_builder.add_calculated_field(source_field)
        self.target_builder.add_calculated_field(target_field)
        # Register calc field under alias
        self.calculated_aliases[alias] = calc_field

    def get_source_query(self):
        """Return query for source validation"""
        source_config = {
            "data_client": self.source_client,
            "schema_name": self.config_manager.source_schema,
            "table_name": self.config_manager.source_table,
            "source_query": self.config_manager.source_query,
        }
        # Custom-query validations build their table from the configured
        # query instead of a schema/table pair.
        if self.validation_type == consts.CUSTOM_QUERY:
            table = self.config_manager.get_source_ibis_table_from_query()
        else:
            table = self.config_manager.get_source_ibis_table()

        query = self.source_builder.compile(self.validation_type, table)
        if self.verbose:
            logging.info(source_config)
            logging.info("-- ** Source Query ** --")
            logging.info(query.compile())

        return query

    def get_target_query(self):
        """Return query for target validation"""
        target_config = {
            "data_client": self.target_client,
            "schema_name": self.config_manager.target_schema,
            "table_name": self.config_manager.target_table,
            "target_query": self.config_manager.target_query,
        }
        # Custom-query validations build their table from the configured
        # query instead of a schema/table pair.
        if self.validation_type == consts.CUSTOM_QUERY:
            table = self.config_manager.get_target_ibis_table_from_query()
        else:
            table = self.config_manager.get_target_ibis_table()

        query = self.target_builder.compile(self.validation_type, table)
        if self.verbose:
            logging.info(target_config)
            logging.info("-- ** Target Query ** --")
            logging.info(query.compile())

        return query

    def add_query_limit(self):
        """Add a limit to the query results

        **WARNING** this can skew results and should be used carefully
        """
        limit = self.config_manager.query_limit
        self.source_builder.limit = limit
        self.target_builder.limit = limit