# data_validation/find_tables.py (102 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import TYPE_CHECKING
from data_validation import (
cli_tools,
clients,
consts,
jellyfish_distance,
state_manager,
)
if TYPE_CHECKING:
import ibis
def _compare_match_tables(source_table_map, target_table_map, score_cutoff=0.8) -> list:
"""Return dict config object from matching tables."""
# TODO(dhercher): evaluate if improved comparison and score cutoffs should be used.
table_configs = []
target_keys = target_table_map.keys()
for source_key in source_table_map:
target_key = jellyfish_distance.extract_closest_match(
source_key, target_keys, score_cutoff=score_cutoff
)
if target_key is None:
continue
table_config = {
consts.CONFIG_SCHEMA_NAME: source_table_map[source_key][
consts.CONFIG_SCHEMA_NAME
],
consts.CONFIG_TABLE_NAME: source_table_map[source_key][
consts.CONFIG_TABLE_NAME
],
consts.CONFIG_TARGET_SCHEMA_NAME: target_table_map[target_key][
consts.CONFIG_SCHEMA_NAME
],
consts.CONFIG_TARGET_TABLE_NAME: target_table_map[target_key][
consts.CONFIG_TABLE_NAME
],
}
table_configs.append(table_config)
return table_configs
def _get_table_map(
    client: "ibis.backends.base.BaseBackend", allowed_schemas=None, include_views=False
):
    """Build a lookup of tables visible to ``client`` for name matching.

    Args:
        client: Backend client handed to clients.get_all_tables.
        allowed_schemas: Forwarded to clients.get_all_tables unchanged.
        include_views: When False, clients.get_all_tables is called with
            tables_only=True.

    Returns:
        dict: Maps a key made of the non-empty parts of each table object
            joined with "." to a dict holding that object's first element as
            the schema name and its second element as the table name.
    """
    discovered = clients.get_all_tables(
        client, allowed_schemas=allowed_schemas, tables_only=(not include_views)
    )
    # Falsy parts (e.g. a missing schema) are dropped from the key only; the
    # stored schema/table values keep whatever the table object contained.
    return {
        ".".join(filter(None, table_obj)): {
            consts.CONFIG_SCHEMA_NAME: table_obj[0],
            consts.CONFIG_TABLE_NAME: table_obj[1],
        }
        for table_obj in discovered
    }
def get_mapped_table_configs(
    source_client: "ibis.backends.base.BaseBackend",
    target_client: "ibis.backends.base.BaseBackend",
    allowed_schemas: list = None,
    include_views: bool = False,
    score_cutoff: int = 1,
) -> list:
    """List tables on both clients and match them by name.

    Args:
        source_client: Client used to list source tables.
        target_client: Client used to list target tables.
        allowed_schemas: Schema filter applied to the source listing only;
            the target listing is not filtered by schema.
        include_views: Whether views are included in both listings.
        score_cutoff: Cutoff forwarded to the table-name matcher.

    Returns:
        list: Matched table config dicts, one per matched source table.
    """
    source_tables = _get_table_map(
        source_client,
        allowed_schemas=allowed_schemas,
        include_views=include_views,
    )
    # No allowed_schemas here: the target side is searched across all schemas.
    target_tables = _get_table_map(target_client, include_views=include_views)
    matched = _compare_match_tables(
        source_tables, target_tables, score_cutoff=score_cutoff
    )
    return matched
def find_tables_using_string_matching(args) -> str:
    """Return JSON string with matched tables for use in validations.

    Args:
        args: Parsed arguments object exposing ``score_cutoff``,
            ``source_conn``, ``target_conn``, ``allowed_schemas`` and
            ``include_views`` attributes.

    Returns:
        str: JSON-encoded list of matched table config dicts.
    """
    # Fall back to 1 only when no cutoff was supplied. An explicit 0 is
    # falsy, so the previous `args.score_cutoff or 1` silently replaced it
    # with 1.
    score_cutoff = 1 if args.score_cutoff is None else args.score_cutoff

    mgr = state_manager.StateManager()
    source_client = clients.get_data_client(mgr.get_connection_config(args.source_conn))
    target_client = clients.get_data_client(mgr.get_connection_config(args.target_conn))

    allowed_schemas = cli_tools.get_arg_list(args.allowed_schemas)
    table_configs = get_mapped_table_configs(
        source_client,
        target_client,
        allowed_schemas=allowed_schemas,
        include_views=args.include_views,
        score_cutoff=score_cutoff,
    )
    return json.dumps(table_configs)
def expand_tables_of_asterisk(
    tables_list: list,
    source_client: "ibis.backends.base.BaseBackend",
    target_client: "ibis.backends.base.BaseBackend",
) -> list:
    """Expand "schema.*" entries of a tables mapping into concrete tables.

    This is a shorthand for the "find-tables" command. Only entries of the
    exact form {"schema_name": (str), "table_name": "*"} are expanded.
    Partial wildcards, and entries that also carry a target schema or target
    table name, are passed through untouched.

    Args:
        tables_list (list[dict]): List of schema/table name dicts.
        source_client: Ibis client we can use to get a table list.
        target_client: Ibis client we can use to get a table list.

    Returns:
        list: New list in which every expandable entry is replaced by its
            matched table configs and every other entry is kept as-is.
    """
    expanded = []
    for entry in tables_list:
        # Truthiness of the entry is checked first so empty/None entries are
        # passed through without any key lookup.
        is_schema_wildcard = bool(
            entry
            and entry[consts.CONFIG_SCHEMA_NAME]
            and entry[consts.CONFIG_TABLE_NAME] == "*"
            # Looking for schema.* without a target side qualifier.
            and not (
                entry.get(consts.CONFIG_TARGET_SCHEMA_NAME)
                or entry.get(consts.CONFIG_TARGET_TABLE_NAME)
            )
        )
        if not is_schema_wildcard:
            expanded.append(entry)
            continue

        # Replace the "*" entry with every matched table in that schema.
        expanded.extend(
            get_mapped_table_configs(
                source_client,
                target_client,
                allowed_schemas=[entry[consts.CONFIG_SCHEMA_NAME]],
                include_views=False,
            )
        )
    return expanded