# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import itertools
import logging
import pandas
import re
from data_validation import metadata, consts, clients, exceptions
# Check for decimal data type with precision and/or scale. Permits hyphen in p/s for value ranges.
DECIMAL_PRECISION_SCALE_PATTERN = re.compile(
r"([!]?decimal)\(([0-9\-]+)(?:,[ ]*([0-9\-]+))?\)", re.I
)
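# Illustrative matches (doctest-style; indicative values, not executed here):
#   >>> DECIMAL_PRECISION_SCALE_PATTERN.match("decimal(10,2)").groups()
#   ('decimal', '10', '2')
#   >>> DECIMAL_PRECISION_SCALE_PATTERN.match("!decimal(38)").groups()
#   ('!decimal', '38', None)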
# Extract lower/upper from a range of the format "0-2" or "12-18".
DECIMAL_PRECISION_SCALE_RANGE_PATTERN = re.compile(
r"([0-9]{2}|[0-9])(?:\-)([0-9]{2}|[0-9])"
)
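# Illustrative match (doctest-style; indicative values, not executed here):
#   >>> DECIMAL_PRECISION_SCALE_RANGE_PATTERN.match("12-18").groups()
#   ('12', '18')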
class SchemaValidation(object):
def __init__(self, config_manager, run_metadata=None, verbose=False):
"""Initialize a SchemaValidation client
Args:
config_manager (ConfigManager): The ConfigManager for the validation.
run_metadata (RunMetadata): The RunMetadata for the validation.
verbose (bool): If verbose, the Data Validation client will print the queries run
"""
self.verbose = verbose
self.config_manager = config_manager
self.run_metadata = run_metadata or metadata.RunMetadata()
def execute(self):
"""Performs a validation between source and a target schema"""
ibis_source_schema = clients.get_ibis_table_schema(
self.config_manager.source_client,
self.config_manager.source_schema,
self.config_manager.source_table,
)
ibis_target_schema = clients.get_ibis_table_schema(
self.config_manager.target_client,
self.config_manager.target_schema,
self.config_manager.target_table,
)
        source_fields = dict(ibis_source_schema.items())
        target_fields = dict(ibis_target_schema.items())
results = schema_validation_matching(
source_fields,
target_fields,
self.config_manager.exclusion_columns,
self.config_manager.allow_list,
)
df = pandas.DataFrame(
results,
columns=[
consts.SOURCE_COLUMN_NAME,
consts.TARGET_COLUMN_NAME,
consts.SOURCE_AGG_VALUE,
consts.TARGET_AGG_VALUE,
consts.VALIDATION_STATUS,
],
)
# Update and Assign Metadata Values
self.run_metadata.end_time = datetime.datetime.now(datetime.timezone.utc)
df.insert(loc=0, column=consts.CONFIG_RUN_ID, value=self.run_metadata.run_id)
df.insert(loc=1, column=consts.VALIDATION_NAME, value="Schema")
df.insert(loc=2, column=consts.VALIDATION_TYPE, value="Schema")
df.insert(
loc=3,
column=consts.CONFIG_LABELS,
value=[self.run_metadata.labels for _ in range(len(df.index))],
)
df.insert(
loc=4, column=consts.CONFIG_START_TIME, value=self.run_metadata.start_time
)
df.insert(
loc=5, column=consts.CONFIG_END_TIME, value=self.run_metadata.end_time
)
df.insert(
loc=6,
column=consts.SOURCE_TABLE_NAME,
value=self.config_manager.full_source_table,
)
df.insert(
loc=7,
column=consts.TARGET_TABLE_NAME,
value=self.config_manager.full_target_table,
)
df.insert(loc=10, column=consts.AGGREGATION_TYPE, value="Schema")
        # Empty columns added due to later changes to the results schema
df.insert(loc=14, column=consts.CONFIG_PRIMARY_KEYS, value=None)
df.insert(loc=15, column=consts.NUM_RANDOM_ROWS, value=None)
df.insert(loc=16, column=consts.GROUP_BY_COLUMNS, value=None)
df.insert(loc=17, column=consts.VALIDATION_DIFFERENCE, value=None)
df.insert(loc=18, column=consts.VALIDATION_PCT_THRESHOLD, value=None)
return df
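# Typical use of SchemaValidation (sketch; assumes a ConfigManager prepared by
# the normal validation flow, which lives outside this module):
#   validator = SchemaValidation(config_manager, verbose=True)
#   results_df = validator.execute()
#   failures = results_df[
#       results_df[consts.VALIDATION_STATUS] == consts.VALIDATION_STATUS_FAIL
#   ]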
def schema_validation_matching(
source_fields, target_fields, exclusion_fields, allow_list
):
"""Compare schemas between two dictionary objects"""
results = []
    # Casefold the keys of source and target so that matching is case-insensitive
source_fields_casefold = {
source_field_name.casefold(): source_field_type
for source_field_name, source_field_type in source_fields.items()
}
target_fields_casefold = {
target_field_name.casefold(): target_field_type
for target_field_name, target_field_type in target_fields.items()
}
if exclusion_fields is not None:
for field in exclusion_fields:
source_fields_casefold.pop(field, None)
target_fields_casefold.pop(field, None)
    # Map of allow-listed data type pairs, for source/target types that differ but are acceptable
allow_list_map = parse_allow_list(allow_list)
# Go through each source and check if target exists and matches
for source_field_name, source_field_type in source_fields_casefold.items():
if source_field_name not in target_fields_casefold:
# Target field doesn't exist
results.append(
[
source_field_name,
"N/A",
str(source_field_type),
"N/A",
consts.VALIDATION_STATUS_FAIL,
]
)
continue
target_field_type = target_fields_casefold[source_field_name]
if source_field_type == target_field_type:
# Target data type matches
results.append(
[
source_field_name,
source_field_name,
str(source_field_type),
str(target_field_type),
consts.VALIDATION_STATUS_SUCCESS,
]
)
elif (
string_val(source_field_type) in allow_list_map
and string_val(target_field_type)
in allow_list_map[string_val(source_field_type)]
):
            # Data type pair matches an allow-list pair.
results.append(
[
source_field_name,
source_field_name,
string_val(source_field_type),
str(target_field_type),
consts.VALIDATION_STATUS_SUCCESS,
]
)
else:
# Target data type mismatch
(higher_precision, lower_precision,) = parse_n_validate_datatypes(
string_val(source_field_type), string_val(target_field_type)
)
if higher_precision:
                # If the target precision is higher, the validation passes but merits a warning.
                logging.warning(
                    "Source and target data types have a precision mismatch: %s - %s",
                    string_val(source_field_type),
                    str(target_field_type),
                )
results.append(
[
source_field_name,
source_field_name,
string_val(source_field_type),
str(target_field_type),
consts.VALIDATION_STATUS_SUCCESS,
]
)
else:
results.append(
[
source_field_name,
source_field_name,
str(source_field_type),
str(target_field_type),
consts.VALIDATION_STATUS_FAIL,
]
)
# Source field doesn't exist
for target_field_name, target_field_type in target_fields_casefold.items():
if target_field_name not in source_fields_casefold:
results.append(
[
"N/A",
target_field_name,
"N/A",
str(target_field_type),
consts.VALIDATION_STATUS_FAIL,
]
)
return results
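# Illustrative result (hypothetical inputs): with source {"id": "int64"},
# target {"ID": "int64", "extra": "string"}, exclusion_fields=None and
# allow_list="", the returned rows are:
#   ['id', 'id', 'int64', 'int64', consts.VALIDATION_STATUS_SUCCESS]
#   ['N/A', 'extra', 'N/A', 'string', consts.VALIDATION_STATUS_FAIL]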
def split_allow_list_str(allow_list_str: str) -> list:
"""Split the allow list string into a list of datatype:datatype tuples."""
    # I've not moved this pattern to a compiled constant because it is only
    # used once per command and I felt splitting the pattern into variables
    # aided readability.
nullable_pattern = r"!?"
precision_scale_pattern = r"(?:\((?:[0-9 ,\-]+|'UTC')\))?"
data_type_pattern = nullable_pattern + r"[a-z0-9 ]+" + precision_scale_pattern
csv_split_pattern = data_type_pattern + r":" + data_type_pattern
data_type_pairs = [
_.replace(" ", "").split(":")
for _ in re.findall(csv_split_pattern, allow_list_str, re.I)
]
invalid_pairs = [_ for _ in data_type_pairs if len(_) != 2]
if invalid_pairs:
raise exceptions.SchemaValidationException(
f"Invalid data type pairs: {invalid_pairs}"
)
return data_type_pairs
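# Illustrative split (doctest-style; indicative values, not executed here):
#   >>> split_allow_list_str("decimal(38,0):int64,int8:int16")
#   [['decimal(38,0)', 'int64'], ['int8', 'int16']]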
def expand_precision_range(s: str) -> list:
"""Expand an integer range (e.g. "0-3") to a list (e.g. ["0", "1", "2", "3"])."""
m_range = DECIMAL_PRECISION_SCALE_RANGE_PATTERN.match(s)
if not m_range:
return [s]
try:
p_lower = int(m_range.group(1))
p_upper = int(m_range.group(2))
if p_lower >= p_upper:
raise exceptions.SchemaValidationException(
f"Invalid allow list data type precision/scale: Lower value {p_lower} >= upper value {p_upper}"
)
return [str(_) for _ in range(p_lower, p_upper + 1)]
except ValueError as e:
raise exceptions.SchemaValidationException(
f"Invalid allow list data type precision/scale: {s}"
) from e
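# Illustrative expansions (doctest-style; indicative values, not executed here):
#   >>> expand_precision_range("0-3")
#   ['0', '1', '2', '3']
#   >>> expand_precision_range("5")
#   ['5']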
def expand_precision_or_scale_range(data_type: str) -> list:
"""Take a data type and example any precision/scale range.
For example "decimal(1-3,0)" becomes:
["decimal(1,0)", "decimal(2,0)", "decimal(3,0)"]"""
m = DECIMAL_PRECISION_SCALE_PATTERN.match(data_type.replace(" ", ""))
if not m:
return [data_type]
if len(m.groups()) != 3:
raise exceptions.SchemaValidationException(
f"Badly formatted data type: {data_type}"
)
type_name, p, s = m.groups()
p_list = expand_precision_range(p)
if s:
s_list = expand_precision_range(s)
return_list = [
f"{type_name}({p},{s})" for p, s in itertools.product(p_list, s_list)
]
else:
return_list = [f"{type_name}({_})" for _ in p_list]
return return_list
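# Illustrative expansion (doctest-style; indicative values, not executed here):
#   >>> expand_precision_or_scale_range("decimal(1-2,0)")
#   ['decimal(1,0)', 'decimal(2,0)']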
def parse_allow_list(st: str) -> dict:
"""Convert allow-list data type pairs into a dictionary like {key[value1, value2, etc], }"""
def expand_allow_list_ranges(data_type_pairs: list) -> list:
expanded_pairs = []
for dt1, dt2 in data_type_pairs:
dt1_list = expand_precision_or_scale_range(dt1)
dt2_list = expand_precision_or_scale_range(dt2)
expanded_pairs.extend(
[(_[0], _[1]) for _ in itertools.product(dt1_list, dt2_list)]
)
return expanded_pairs
def convert_pairs_to_dict(expanded_pairs: list) -> dict:
"""Take the list data type tuples and convert them into a dictionary keyed on source data type.
For example:
[('decimal(2,0)', 'int64'), ('decimal(2,0)', 'int32')]
becomes:
{'decimal(2,0)': ['int64', 'int32']}
"""
        return_pairs = {}
        for dt1, dt2 in expanded_pairs:
            return_pairs.setdefault(dt1, []).append(dt2)
        return return_pairs
data_type_pairs = split_allow_list_str(st)
expanded_pairs = expand_allow_list_ranges(data_type_pairs)
return_pairs = convert_pairs_to_dict(expanded_pairs)
return return_pairs
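# Illustrative parse (doctest-style; indicative values, not executed here):
#   >>> parse_allow_list("decimal(2-3,0):int64,int32:int64")
#   {'decimal(2,0)': ['int64'], 'decimal(3,0)': ['int64'], 'int32': ['int64']}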
# "typea" data types carry precision as a numeric suffix, e.g. int8, int16.
def get_typea_numeric_sustr(st):
    """Return the numeric suffix of a type name (e.g. "int16" -> 16), or -1
    if the type is parameterized or contains no digits."""
    if "(" in st:
        return -1
    num = "".join(c for c in st if c.isdigit())
    return int(num) if num else -1
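# Illustrative extractions (doctest-style; indicative values, not executed here):
#   >>> get_typea_numeric_sustr("int16")
#   16
#   >>> get_typea_numeric_sustr("decimal(10,2)")
#   -1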
# "typeb" data types carry precision/scale in parentheses, e.g. decimal(10,2).
def get_typeb_numeric_sustr(st: str) -> tuple:
    """Return (precision, scale) of a parameterized type, e.g. "decimal(10,2)"
    gives (10, 2). Scale defaults to 0; non-matching types give (-1, -1)."""
    m = DECIMAL_PRECISION_SCALE_PATTERN.match(st.replace(" ", ""))
    if not m:
        return -1, -1
    _, p, s = m.groups()
    if s is None:
        s = 0
    return int(p), int(s)
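# Illustrative extractions (doctest-style; indicative values, not executed here):
#   >>> get_typeb_numeric_sustr("decimal(10,2)")
#   (10, 2)
#   >>> get_typeb_numeric_sustr("int64")
#   (-1, -1)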
def string_val(st):
    """Return the data type as a string with any spaces removed."""
    return str(st).replace(" ", "")
def validate_typeb_vals(source, target):
    """Compare (precision, scale) tuples and return (higher_precision,
    lower_precision) booleans for the target relative to the source."""
    if source[0] > target[0] or source[1] > target[1]:
        return False, True
    elif source[0] == target[0] and source[1] == target[1]:
        return False, False
    return True, False
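# Illustrative comparison (doctest-style; indicative values, not executed here):
# target (12, 2) has higher precision than source (10, 2):
#   >>> validate_typeb_vals((10, 2), (12, 2))
#   (True, False)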
def strip_null(st):
    """Strip the not-null marker ("!") from a data type string."""
    return st.replace("!", "")
def parse_n_validate_datatypes(source, target) -> tuple:
    """
    Args:
        source: Source table data type string
        target: Target table data type string
    Returns:
        bool: target has a higher precision value
        bool: target has a lower precision value
    """
if strip_null(source) == target:
return False, False
if "(" in source and "(" in target:
typeb_source = get_typeb_numeric_sustr(source)
typeb_target = get_typeb_numeric_sustr(target)
higher_precision, lower_precision = validate_typeb_vals(
typeb_source, typeb_target
)
return higher_precision, lower_precision
source_num = get_typea_numeric_sustr(source)
target_num = get_typea_numeric_sustr(target)
    # If no bit width is specified we do not attempt to match precisions.
    if source_num == -1 or target_num == -1:
        return False, False
    if source_num == target_num:
        return False, False
    elif source_num > target_num:
        return False, True
    return True, False
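# Illustrative checks (doctest-style; indicative values, not executed here):
# a widening int8 -> int16 reports higher target precision, while a narrowing
# decimal(12,2) -> decimal(10,2) reports lower target precision:
#   >>> parse_n_validate_datatypes("int8", "int16")
#   (True, False)
#   >>> parse_n_validate_datatypes("decimal(12,2)", "decimal(10,2)")
#   (False, True)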