# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import random
import string
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Tuple

import ibis.expr.datatypes as dt
import yaml

from data_validation import clients, consts, gcs_helper, state_manager
from data_validation.result_handlers.factory import build_result_handler
from data_validation.validation_builder import ValidationBuilder

if TYPE_CHECKING:
    import ibis.expr.types.Table


class ConfigManager(object):
    _config: dict = None
    _source_conn = None
    _target_conn = None
    _state_manager = None
    source_client = None
    target_client = None

    def __init__(self, config, source_client=None, target_client=None, verbose=False):
        """Initialize a ConfigManager client which supplies the
            source and target queries to run.

        Args:
            config (Dict): The Validation config supplied
            source_client (IbisClient): The Ibis client for the source DB
            target_client (IbisClient): The Ibis client for the target DB
            verbose (Bool): If verbose, the Data Validation client will print queries run
            google_credentials (google.auth.credentials.Credentials):
                Explicit credentials to use in case default credentials
                aren't working properly.
        """
        self._state_manager = state_manager.StateManager()
        self._config = config

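        # Prefer explicitly supplied Ibis clients; otherwise build them from the
        # connection config (inline, or looked up by name via the state manager).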
        self.source_client = source_client or clients.get_data_client(
            self.get_source_connection()
        )
        self.target_client = target_client or clients.get_data_client(
            self.get_target_connection()
        )

        self.verbose = verbose
        if self.validation_type not in consts.CONFIG_TYPES:
            raise ValueError(f"Unknown Configuration Type: {self.validation_type}")
        self._comparison_max_col_length = None
        # For some engines we need to know the actual raw data type rather than the Ibis canonical type.
        self._source_raw_data_types = None
        self._target_raw_data_types = None

    @property
    def config(self):
        """Return config object."""
        return self._config

    def get_source_connection(self):
        """Return source connection object."""
        if not self._source_conn:
            if self._config.get(consts.CONFIG_SOURCE_CONN):
                self._source_conn = self._config.get(consts.CONFIG_SOURCE_CONN)
            else:
                conn_name = self._config.get(consts.CONFIG_SOURCE_CONN_NAME)
                self._source_conn = self._state_manager.get_connection_config(conn_name)

        return self._source_conn

    def get_target_connection(self):
        """Return target connection object."""
        if not self._target_conn:
            if self._config.get(consts.CONFIG_TARGET_CONN):
                self._target_conn = self._config.get(consts.CONFIG_TARGET_CONN)
            else:
                conn_name = self._config.get(consts.CONFIG_TARGET_CONN_NAME)
                self._target_conn = self._state_manager.get_connection_config(conn_name)

        return self._target_conn

    def get_source_raw_data_types(self) -> Dict[str, Tuple]:
        """Return raw data type information from source system.

        The raw data type is the source/target engine type, for example it might
        be "NCLOB" or "char" when the Ibis type simply states "string".
        The data is cached in state when fetched for the first time.
        The retuen value is keyed on the casefolded column name and the tuple is
        the remaining 6 elements of the DB API cursor description specification."""
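        # Illustrative shape of the cached value (hypothetical Oracle column; the
        # 6-tuple mirrors PEP 249 cursor.description minus the column name):
        #   {"last_name": ("NCLOB", None, 4000, None, None, 1), ...}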
        if self._source_raw_data_types is None:
            if hasattr(self.source_client, "raw_column_metadata"):
                raw_data_types = self.source_client.raw_column_metadata(
                    database=self.source_schema,
                    table=self.source_table,
                    query=self.source_query,
                )
                self._source_raw_data_types = {
                    _[0].casefold(): _[1:] for _ in raw_data_types
                }
            else:
                self._source_raw_data_types = {}
        return self._source_raw_data_types

    def get_target_raw_data_types(self) -> Dict[str, Tuple]:
        """Return raw data type information from target system.

        The raw data type is the source/target engine type, for example it might
        be "NCLOB" or "char" when the Ibis type simply states "string".
        The data is cached in state when fetched for the first time.
        The retuen value is keyed on the casefolded column name and the tuple is
        the remaining 6 elements of the DB API cursor description specification."""
        if self._target_raw_data_types is None:
            if hasattr(self.target_client, "raw_column_metadata"):
                raw_data_types = self.target_client.raw_column_metadata(
                    database=self.target_schema,
                    table=self.target_table,
                    query=self.target_query,
                )
                self._target_raw_data_types = {
                    _[0].casefold(): _[1:] for _ in raw_data_types
                }
            else:
                self._target_raw_data_types = {}
        return self._target_raw_data_types

    def close_client_connections(self):
        """Attempt to clean up any source/target connections, based on the client types.

        Not all clients are covered here; at a minimum we handle Oracle and PostgreSQL, for
        which we have seen connections accumulating.
        https://github.com/GoogleCloudPlatform/professional-services-data-validator/issues/1195
        """
        try:
            if self.source_client and self.source_client.name in ("oracle", "postgres"):
                self.source_client.con.dispose()
            if self.target_client and self.target_client.name in ("oracle", "postgres"):
                self.target_client.con.dispose()
        except Exception as exc:
            # No need to reraise, we can silently fail if exiting throws up an issue.
            logging.warning("Exception closing connections: %s", str(exc))

    @property
    def validation_type(self):
        """Return string validation type (Column|Schema)."""
        return self._config[consts.CONFIG_TYPE]

    def use_random_rows(self):
        """Return if the validation should use a random row filter."""
        return self._config.get(consts.CONFIG_USE_RANDOM_ROWS) or False

    def random_row_batch_size(self):
        """Return batch size for random row filter."""
        return int(
            self._config.get(consts.CONFIG_RANDOM_ROW_BATCH_SIZE)
            or consts.DEFAULT_NUM_RANDOM_ROWS
        )

    def get_random_row_batch_size(self):
        """Return number of random rows or None."""
        return self.random_row_batch_size() if self.use_random_rows() else None

    def trim_string_pks(self):
        """Return if the validation should trim string primary keys."""
        return self._config.get(consts.CONFIG_TRIM_STRING_PKS) or False

    def case_insensitive_match(self):
        """Return if the validation should perform a case insensitive match."""
        return self._config.get(consts.CONFIG_CASE_INSENSITIVE_MATCH) or False

    @property
    def max_recursive_query_size(self):
        """Return Aggregates from Config"""
        return self._config.get(consts.CONFIG_MAX_RECURSIVE_QUERY_SIZE, 50000)

    @property
    def aggregates(self):
        """Return Aggregates from Config"""
        return self._config.get(consts.CONFIG_AGGREGATES, [])

    def append_aggregates(self, aggregate_configs):
        """Append aggregate configs to existing config."""
        self._config[consts.CONFIG_AGGREGATES] = self.aggregates + aggregate_configs

    @property
    def calculated_fields(self):
        return self._config.get(consts.CONFIG_CALCULATED_FIELDS, [])

    def append_calculated_fields(self, calculated_configs):
        self._config[consts.CONFIG_CALCULATED_FIELDS] = (
            self.calculated_fields + calculated_configs
        )

    @property
    def query_groups(self):
        """Return Query Groups from Config"""
        return self._config.get(consts.CONFIG_GROUPED_COLUMNS, [])

    def append_query_groups(self, grouped_column_configs):
        """Append grouped configs to existing config."""
        self._config[consts.CONFIG_GROUPED_COLUMNS] = (
            self.query_groups + grouped_column_configs
        )

    @property
    def custom_query_type(self):
        """Return custom query type from config"""
        return self._config.get(consts.CONFIG_CUSTOM_QUERY_TYPE, "")

    def append_custom_query_type(self, custom_query_type):
        """Append custom query type config to existing config."""
        self._config[consts.CONFIG_CUSTOM_QUERY_TYPE] = (
            self.custom_query_type + custom_query_type
        )

    @property
    def source_query_file(self):
        """Return SQL Query File from Config"""
        return self._config.get(consts.CONFIG_SOURCE_QUERY_FILE, [])

    def append_source_query_file(self, query_file_configs):
        """Append grouped configs to existing config."""
        self._config[consts.CONFIG_SOURCE_QUERY_FILE] = (
            self.source_query_file + query_file_configs
        )

    @property
    def target_query_file(self):
        """Return SQL Query File from Config"""
        return self._config.get(consts.CONFIG_TARGET_QUERY_FILE, [])

    def append_target_query_file(self, query_file_configs):
        """Append grouped configs to existing config."""
        self._config[consts.CONFIG_TARGET_QUERY_FILE] = (
            self.target_query_file + query_file_configs
        )

    @property
    def primary_keys(self):
        """Return Primary keys from Config"""
        return self._config.get(consts.CONFIG_PRIMARY_KEYS, [])

    def append_primary_keys(self, primary_key_configs):
        """Append primary key configs to existing config."""
        self._config[consts.CONFIG_PRIMARY_KEYS] = (
            self.primary_keys + primary_key_configs
        )

    def get_primary_keys_list(self):
        """Return list of primary key column names"""
        return [key[consts.CONFIG_SOURCE_COLUMN] for key in self.primary_keys]

    @property
    def comparison_fields(self):
        """Return fields from Config"""
        return self._config.get(consts.CONFIG_COMPARISON_FIELDS, [])

    def append_comparison_fields(self, field_configs):
        """Append field configs to existing config."""
        self._config[consts.CONFIG_COMPARISON_FIELDS] = (
            self.comparison_fields + field_configs
        )

    @property
    def concat(self):
        """Return field from Config"""
        return self._config.get(consts.CONFIG_ROW_CONCAT, [])

    @property
    def hash(self):
        """Return field from Config"""
        return self._config.get(consts.CONFIG_ROW_HASH, [])

    @property
    def run_id(self):
        """Return field from Config"""
        return self._config.get(consts.CONFIG_RUN_ID, None)

    @property
    def filters(self):
        """Return Filters from Config"""
        return self._config.get(consts.CONFIG_FILTERS, [])

    @property
    def source_schema(self):
        """Return string value of source schema."""
        if self.source_client._source_type == "FileSystem":
            return None
        return self._config.get(consts.CONFIG_SCHEMA_NAME, None)

    @property
    def source_table(self):
        """Return string value of source table."""
        return self._config[consts.CONFIG_TABLE_NAME]

    @property
    def target_schema(self):
        """Return string value of target schema."""
        if self.target_client._source_type == "FileSystem":
            return None
        return self._config.get(consts.CONFIG_TARGET_SCHEMA_NAME, self.source_schema)

    @property
    def target_table(self):
        """Return string value of target table."""
        return self._config.get(
            consts.CONFIG_TARGET_TABLE_NAME, self._config[consts.CONFIG_TABLE_NAME]
        )

    @property
    def full_target_table(self):
        """Return string value of fully qualified target table."""
        if self.target_schema:
            return self.target_schema + "." + self.target_table
        else:
            return self.target_table

    @property
    def full_source_table(self):
        """Return string value of target table."""
        if self.source_table and self.source_schema:
            return self.source_schema + "." + self.source_table
        elif self.source_table:
            return self.source_table
        else:
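            # Custom-query validations have no source table; generate a short random
            # placeholder name so downstream code still has a table label to report.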
            return f"custom.{''.join(random.choice(string.ascii_lowercase) for _ in range(5))}"

    @property
    def labels(self):
        """Return labels."""
        return self._config.get(consts.CONFIG_LABELS, [])

    @property
    def result_handler_config(self):
        """Return int limit for query executions."""
        return self._config.get(consts.CONFIG_RESULT_HANDLER) or {}

    @property
    def query_limit(self):
        """Return int limit for query executions."""
        return self._config.get(consts.CONFIG_LIMIT)

    @property
    def threshold(self):
        """Return threshold from Config"""
        return self._config.get(consts.CONFIG_THRESHOLD, 0.0)

    @property
    def source_query(self):
        return self._config.get(consts.CONFIG_SOURCE_QUERY, None)

    def append_source_query(self, source_query):
        self._config["source_query"] = source_query

    @property
    def target_query(self):
        return self._config.get(consts.CONFIG_TARGET_QUERY, None)

    def append_target_query(self, target_query):
        self._config["target_query"] = target_query

    @property
    def exclusion_columns(self):
        """Return the exclusion columns from Config"""
        return self._config.get(consts.CONFIG_EXCLUSION_COLUMNS, [])

    @property
    def allow_list(self):
        """Return the allow_list from Config"""
        return self._config.get(consts.CONFIG_ALLOW_LIST, "")

    @property
    def filter_status(self):
        """Return filter status list from Config"""
        return self._config.get(consts.CONFIG_FILTER_STATUS, None)

    def append_exclusion_columns(self, column_configs):
        """Append exclusion columns to existing config."""
        self._config[consts.CONFIG_EXCLUSION_COLUMNS] = (
            self.exclusion_columns + column_configs
        )

    def append_allow_list(
        self, allow_list: Union[str, None], allow_list_file: Union[str, None]
    ):
        """Append datatype allow_list to existing config."""
        full_allow_list = []
        if allow_list:
            allow_list = allow_list.replace(" ", "")
            full_allow_list.append(allow_list)
        if allow_list_file:
            try:
                allow_list_yaml = gcs_helper.read_file(allow_list_file)
            except FileNotFoundError as e:
                raise ValueError(
                    "Cannot locate --allow-list-file: {allow_list_file}"
                ) from e
            allow_list_dict = yaml.safe_load(allow_list_yaml)
            full_allow_list.append(
                ",".join(f"{k}:{v}" for k, v in allow_list_dict.items())
            )

        self._config[consts.CONFIG_ALLOW_LIST] = ",".join(full_allow_list)

    def get_source_ibis_table(self):
        """Return IbisTable from source."""
        if not hasattr(self, "_source_ibis_table"):
            self._source_ibis_table = clients.get_ibis_table(
                self.source_client, self.source_schema, self.source_table
            )
        return self._source_ibis_table

    def get_source_ibis_table_from_query(self):
        """Return IbisTable from source."""
        if not hasattr(self, "_source_ibis_table"):
            self._source_ibis_table = clients.get_ibis_query(
                self.source_client, self.source_query
            )
        return self._source_ibis_table

    def get_source_ibis_calculated_table(self, depth=None):
        """Return mutated IbisTable from source
        depth: Int the depth of subquery requested"""
        if self.validation_type == consts.CUSTOM_QUERY:
            table = self.get_source_ibis_table_from_query()
        else:
            table = self.get_source_ibis_table()
        vb = ValidationBuilder(self)
        calculated_table = table.mutate(
            vb.source_builder.compile_calculated_fields(table, n=depth)
        )

        return calculated_table

    def get_target_ibis_table(self):
        """Return IbisTable from target."""
        if not hasattr(self, "_target_ibis_table"):
            self._target_ibis_table = clients.get_ibis_table(
                self.target_client, self.target_schema, self.target_table
            )
        return self._target_ibis_table

    def get_target_ibis_table_from_query(self):
        """Return IbisTable from source."""
        if not hasattr(self, "_target_ibis_table"):
            self._target_ibis_table = clients.get_ibis_query(
                self.target_client, self.target_query
            )
        return self._target_ibis_table

    def get_target_ibis_calculated_table(self, depth=None):
        """Return mutated IbisTable from target
        n: Int the depth of subquery requested"""
        if self.validation_type == consts.CUSTOM_QUERY:
            table = self.get_target_ibis_table_from_query()
        else:
            table = self.get_target_ibis_table()
        vb = ValidationBuilder(self)
        calculated_table = table.mutate(
            vb.target_builder.compile_calculated_fields(table, n=depth)
        )

        return calculated_table

    def get_yaml_validation_block(self):
        """Return Dict object formatted for a Yaml file."""
        config = copy.deepcopy(self.config)

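        # Connection details and result handler settings are supplied at run time,
        # so strip them from the block persisted to YAML.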
        config.pop(consts.CONFIG_SOURCE_CONN, None)
        config.pop(consts.CONFIG_TARGET_CONN, None)

        config.pop(consts.CONFIG_SOURCE_CONN_NAME, None)
        config.pop(consts.CONFIG_TARGET_CONN_NAME, None)

        config.pop(consts.CONFIG_RESULT_HANDLER, None)

        return config

    def get_result_handler(self):
        """Return ResultHandler instance from supplied config."""
        return build_result_handler(
            self.result_handler_config,
            self.config[consts.CONFIG_TYPE],
            self.filter_status,
            text_format=self._config.get(
                consts.CONFIG_FORMAT, consts.FORMAT_TYPE_TABLE
            ),
        )

    @staticmethod
    def build_config_manager(
        config_type,
        source_conn_name,
        target_conn_name,
        table_obj,
        labels,
        threshold,
        format,
        use_random_rows=None,
        random_row_batch_size=None,
        source_client=None,
        target_client=None,
        result_handler_config=None,
        filter_config=None,
        filter_status=None,
        trim_string_pks=None,
        case_insensitive_match=None,
        concat=None,
        hash=None,
        run_id=None,
        verbose=False,
    ):
        """Return a ConfigManager instance with available config."""
        if isinstance(filter_config, dict):
            filter_config = [filter_config]

        config = {
            consts.CONFIG_TYPE: config_type,
            consts.CONFIG_SOURCE_CONN_NAME: source_conn_name,
            consts.CONFIG_TARGET_CONN_NAME: target_conn_name,
            consts.CONFIG_TABLE_NAME: table_obj.get(consts.CONFIG_TABLE_NAME, None),
            consts.CONFIG_SCHEMA_NAME: table_obj.get(consts.CONFIG_SCHEMA_NAME, None),
            consts.CONFIG_TARGET_SCHEMA_NAME: table_obj.get(
                consts.CONFIG_TARGET_SCHEMA_NAME,
                table_obj.get(consts.CONFIG_SCHEMA_NAME, None),
            ),
            consts.CONFIG_TARGET_TABLE_NAME: table_obj.get(
                consts.CONFIG_TARGET_TABLE_NAME,
                table_obj.get(consts.CONFIG_TABLE_NAME, None),
            ),
            consts.CONFIG_LABELS: labels,
            consts.CONFIG_THRESHOLD: threshold,
            consts.CONFIG_FORMAT: format,
            consts.CONFIG_RESULT_HANDLER: result_handler_config,
            consts.CONFIG_FILTERS: filter_config,
            consts.CONFIG_USE_RANDOM_ROWS: use_random_rows,
            consts.CONFIG_RANDOM_ROW_BATCH_SIZE: random_row_batch_size,
            consts.CONFIG_FILTER_STATUS: filter_status,
            consts.CONFIG_TRIM_STRING_PKS: trim_string_pks,
            consts.CONFIG_CASE_INSENSITIVE_MATCH: case_insensitive_match,
            consts.CONFIG_ROW_CONCAT: concat,
            consts.CONFIG_ROW_HASH: hash,
            consts.CONFIG_RUN_ID: run_id,
        }

        return ConfigManager(
            config,
            source_client=source_client,
            target_client=target_client,
            verbose=verbose,
        )
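
    # Illustrative usage of build_config_manager (argument values are hypothetical):
    #   config_manager = ConfigManager.build_config_manager(
    #       "Column", "my_source_conn", "my_target_conn",
    #       {consts.CONFIG_TABLE_NAME: "orders", consts.CONFIG_SCHEMA_NAME: "sales"},
    #       labels=[], threshold=0.0, format="table",
    #   )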

    def add_rstrip_to_comp_fields(self, comparison_fields: List[str]) -> List[str]:
        """As per #1190, add an rstrip calculated field for Teradata string comparison fields.

        Parameters:
            comparison_fields: List[str] of comparison field columns
        Returns:
            comp_fields_with_aliases: List[str] of comparison field columns with rstrip aliases
        """
        source_table = self.get_source_ibis_calculated_table()
        target_table = self.get_target_ibis_calculated_table()
        source_table_schema = {k: v for k, v in source_table.schema().items()}
        target_table_schema = {k: v for k, v in target_table.schema().items()}
        casefold_source_columns = {x.casefold(): str(x) for x in source_table.columns}
        casefold_target_columns = {x.casefold(): str(x) for x in target_table.columns}

        comp_fields_with_aliases = []
        calculated_configs = []
        for field in comparison_fields:
            if field.casefold() not in casefold_source_columns:
                raise ValueError(f"Column DNE in source: {field}")
            if field.casefold() not in casefold_target_columns:
                raise ValueError(f"Column DNE in target: {field}")

            source_ibis_type = source_table[
                casefold_source_columns[field.casefold()]
            ].type()
            target_ibis_type = target_table[
                casefold_target_columns[field.casefold()]
            ].type()

            # Do not add rstrip if the column is a bool or UUID hiding in a string.
            if (
                source_ibis_type.is_string() or target_ibis_type.is_string()
            ) and not self._comp_field_cast(
                source_table_schema,
                target_table_schema,
                field,
            ):
                logging.info(
                    f"Adding rtrim() to string comparison field `{field.casefold()}` due to Teradata CHAR padding."
                )
                alias = f"rstrip__{field.casefold()}"
                calculated_configs.append(
                    self.build_config_calculated_fields(
                        [casefold_source_columns[field.casefold()]],
                        [casefold_target_columns[field.casefold()]],
                        consts.CALC_FIELD_RSTRIP,
                        alias,
                        0,
                    )
                )
                comp_fields_with_aliases.append(alias)
            else:
                comp_fields_with_aliases.append(field)

        self.append_calculated_fields(calculated_configs)
        return comp_fields_with_aliases

    def _comp_field_cast(
        self, source_table_schema: dict, target_table_schema: dict, field: str
    ) -> Optional[str]:
        # We check below if the field exists because sometimes it is a computed name
        # like "concat__all" which is not in the real table.
        source_type = source_table_schema.get(field)
        target_type = target_table_schema.get(field)
        if self._is_bool(source_type, target_type):
            return "bool"
        elif self._is_uuid(source_type, target_type):
            return consts.CONFIG_CAST_UUID_STRING
        return None

    def _is_bool(
        self, source_type: Union[str, dt.DataType], target_type: Union[str, dt.DataType]
    ) -> bool:
        """Returns whether column is BOOLEAN based on either source or target data type.

        We do this because some engines don't have a BOOLEAN type, therefore BOOLEAN on one side
        means both sides need to be BOOLEAN aware."""
        if isinstance(source_type, str):
            return any(_ in ["bool", "!bool"] for _ in [source_type, target_type])
        else:
            return bool(
                isinstance(source_type, dt.Boolean)
                or isinstance(target_type, dt.Boolean)
            )

    def _is_uuid(
        self, source_type: Union[str, dt.DataType], target_type: Union[str, dt.DataType]
    ) -> bool:
        """Returns whether column is UUID based on either source or target data type.

        We do this because some engines don't have a UUID type, therefore UUID on one side
        means both sides are treated as UUID, i.e. we use any() not all()."""
        if isinstance(source_type, str):
            return any(_ in ["uuid", "!uuid"] for _ in [source_type, target_type])
        else:
            return bool(
                isinstance(source_type, dt.UUID) or isinstance(target_type, dt.UUID)
            )

    def build_config_comparison_fields(self, fields, depth=None):
        """Return list of field config objects."""
        field_configs = []
        source_table = self.get_source_ibis_calculated_table()
        target_table = self.get_target_ibis_calculated_table()
        source_table_schema = {k: v for k, v in source_table.schema().items()}
        target_table_schema = {k: v for k, v in target_table.schema().items()}
        casefold_source_columns = {x.casefold(): str(x) for x in source_table.columns}
        casefold_target_columns = {x.casefold(): str(x) for x in target_table.columns}

        for field in fields:
            cast_type = self._comp_field_cast(
                source_table_schema, target_table_schema, field
            )
            column_config = {
                consts.CONFIG_SOURCE_COLUMN: casefold_source_columns.get(
                    field.casefold(), field
                ),
                consts.CONFIG_TARGET_COLUMN: casefold_target_columns.get(
                    field.casefold(), field
                ),
                consts.CONFIG_FIELD_ALIAS: field,
                consts.CONFIG_CAST: cast_type,
            }
            field_configs.append(column_config)
        return field_configs

    def build_column_configs(self, columns):
        """Return list of column config objects."""
        column_configs = []
        source_table = self.get_source_ibis_calculated_table()
        target_table = self.get_target_ibis_calculated_table()
        casefold_source_columns = {x.casefold(): str(x) for x in source_table.columns}
        casefold_target_columns = {x.casefold(): str(x) for x in target_table.columns}

        for column in columns:
            if column.casefold() not in casefold_source_columns:
                raise ValueError(f"Column DNE in source: {column}")
            if column.casefold() not in casefold_target_columns:
                raise ValueError(f"Column DNE in target: {column}")

            source_ibis_type = source_table[
                casefold_source_columns[column.casefold()]
            ].type()
            target_ibis_type = target_table[
                casefold_target_columns[column.casefold()]
            ].type()
            cast_type = self._key_column_needs_casting_to_string(
                source_ibis_type, target_ibis_type
            )

            column_config = {
                consts.CONFIG_SOURCE_COLUMN: casefold_source_columns[column.casefold()],
                consts.CONFIG_TARGET_COLUMN: casefold_target_columns[column.casefold()],
                consts.CONFIG_FIELD_ALIAS: column,
                consts.CONFIG_CAST: cast_type,
            }
            column_configs.append(column_config)

        return column_configs

    def build_config_count_aggregate(self):
        """Return dict aggregate for COUNT(*)."""
        aggregate_config = {
            consts.CONFIG_SOURCE_COLUMN: None,
            consts.CONFIG_TARGET_COLUMN: None,
            consts.CONFIG_FIELD_ALIAS: "count",
            consts.CONFIG_TYPE: "count",
        }

        return aggregate_config

    def _prefix_calc_col_name(
        self, column_name: str, prefix: str, column_number: int
    ) -> str:
        """Prefix a column name but protect final string from overflowing SQL engine identifier length limit."""
        new_name = f"{prefix}__{column_name}"
        if len(new_name) > self._get_comparison_max_col_length():
            # Use an abstract name for the calculated column to avoid composing invalid SQL.
            new_name = f"{prefix}__dvt_calc_col_{column_number}"
        return new_name

    def build_and_append_pre_agg_calc_config(
        self,
        source_column,
        target_column,
        calc_func,
        column_position,
        cast_type: str = None,
        depth: int = 0,
    ):
        """Create calculated field config used as a pre-aggregation step. Appends to calculated fields if does not already exist and returns created config."""
        calculated_config = {
            consts.CONFIG_CALCULATED_SOURCE_COLUMNS: [source_column],
            consts.CONFIG_CALCULATED_TARGET_COLUMNS: [target_column],
            consts.CONFIG_FIELD_ALIAS: self._prefix_calc_col_name(
                source_column, calc_func, column_position
            ),
            consts.CONFIG_TYPE: calc_func,
            consts.CONFIG_DEPTH: depth,
        }

        if calc_func == consts.CONFIG_CAST and cast_type is not None:
            calculated_config[consts.CONFIG_DEFAULT_CAST] = cast_type
            calculated_config[consts.CONFIG_FIELD_ALIAS] = self._prefix_calc_col_name(
                source_column, f"{calc_func}_{cast_type}", column_position
            )

        existing_calc_fields = [
            config[consts.CONFIG_FIELD_ALIAS] for config in self.calculated_fields
        ]

        if calculated_config[consts.CONFIG_FIELD_ALIAS] not in existing_calc_fields:
            self.append_calculated_fields([calculated_config])
        return calculated_config

    def append_pre_agg_calc_field(
        self,
        source_column: str,
        target_column: str,
        agg_type: str,
        column_type: str,
        target_column_type: str,
        column_position: int,
    ) -> dict:
        """Append calculated field for length() or epoch_seconds(timestamp) for preprocessing before column validation aggregation."""
        depth = 0
        cast_type = None
        final_cast_type = None
        if any(_ in ["json", "!json"] for _ in [column_type, target_column_type]):
            # JSON data which needs casting to string before we apply a length function.
            pre_calculated_config = self.build_and_append_pre_agg_calc_config(
                source_column,
                target_column,
                consts.CONFIG_CAST,
                column_position,
                cast_type="string",
                depth=depth,
            )
            source_column = target_column = pre_calculated_config[
                consts.CONFIG_FIELD_ALIAS
            ]
            depth = 1
            calc_func = consts.CALC_FIELD_LENGTH
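        # Plain string columns are likewise aggregated on their lengths.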
        elif column_type in ["string", "!string"]:
            calc_func = consts.CALC_FIELD_LENGTH

        elif self._is_uuid(column_type, target_column_type):
            calc_func = consts.CONFIG_CAST
            cast_type = consts.CONFIG_CAST_UUID_STRING

        elif column_type in ["binary", "!binary"]:
            calc_func = consts.CALC_FIELD_BYTE_LENGTH

        elif column_type in ["timestamp", "!timestamp", "date", "!date"]:
            if (
                self.source_client.name == "bigquery"
                or self.target_client.name == "bigquery"
            ):
                pre_calculated_config = self.build_and_append_pre_agg_calc_config(
                    source_column,
                    target_column,
                    consts.CONFIG_CAST,
                    column_position,
                    cast_type="timestamp",
                    depth=depth,
                )
                source_column = target_column = pre_calculated_config[
                    consts.CONFIG_FIELD_ALIAS
                ]
                depth = 1

            calc_func = consts.CALC_FIELD_EPOCH_SECONDS
            if agg_type == consts.CONFIG_TYPE_SUM:
                # It is possible to exceed int64 when summing epoch_seconds therefore cast to string.
                # See issue 1391 for details.
                final_cast_type = "string"

        elif column_type == "int32" or column_type == "!int32":
            calc_func = consts.CONFIG_CAST
            cast_type = "int64"

        else:
            raise ValueError(f"Unsupported column type: {column_type}")

        calculated_config = self.build_and_append_pre_agg_calc_config(
            source_column,
            target_column,
            calc_func,
            column_position,
            cast_type=cast_type,
            depth=depth,
        )

        aggregate_config = {
            consts.CONFIG_SOURCE_COLUMN: f"{calculated_config[consts.CONFIG_FIELD_ALIAS]}",
            consts.CONFIG_TARGET_COLUMN: f"{calculated_config[consts.CONFIG_FIELD_ALIAS]}",
            consts.CONFIG_FIELD_ALIAS: self._prefix_calc_col_name(
                calculated_config[consts.CONFIG_FIELD_ALIAS],
                f"{agg_type}",
                column_position,
            ),
            consts.CONFIG_TYPE: agg_type,
        }
        if final_cast_type:
            # Adding to dict this way to avoid adding a lot of empty cast attributes.
            aggregate_config[consts.CONFIG_CAST] = final_cast_type

        return aggregate_config

    def _decimal_column_too_big_for_pandas(
        self,
        source_column_ibis_type: dt.DataType,
        target_column_ibis_type: dt.DataType,
        margin: int = 0,
    ) -> bool:
        """Identifies numeric columns that will cause problems in a Pandas Dataframe.

        i.e. are of greater precision than a 64bit int/real can hold.

        margin: Allows us to lower the precision threshold. This is helpful when summing column
                values that are okay by themselves but cumulativaly could overflow a 64bit value.
        """
        return bool(
            (
                (isinstance(source_column_ibis_type, dt.Int64) and margin > 0)
                or (
                    isinstance(source_column_ibis_type, dt.Decimal)
                    and (
                        source_column_ibis_type.precision is None
                        or source_column_ibis_type.precision > (18 - margin)
                    )
                )
            )
            and (
                (isinstance(target_column_ibis_type, dt.Int64) and margin > 0)
                or (
                    isinstance(target_column_ibis_type, dt.Decimal)
                    and (
                        target_column_ibis_type.precision is None
                        or target_column_ibis_type.precision > (18 - margin)
                    )
                )
            )
        )

    def _key_column_needs_casting_to_string(
        self,
        source_column_ibis_type: dt.DataType,
        target_column_ibis_type: dt.DataType,
    ) -> Optional[str]:
        """Return a string cast if the datatype combination requires it, otherwise None."""
        if self._is_uuid(source_column_ibis_type, target_column_ibis_type):
            # This needs to come before binary check because Oracle
            # stores UUIDs (GUID) in binary columns.
            return consts.CONFIG_CAST_UUID_STRING
        elif (
            self._decimal_column_too_big_for_pandas(
                source_column_ibis_type, target_column_ibis_type
            )
            or isinstance(source_column_ibis_type, dt.Binary)
            or isinstance(target_column_ibis_type, dt.Binary)
        ):
            return "string"
        else:
            return None

    def _type_is_supported_for_agg_validation(
        self, source_type: str, target_type: str, supported_types: list
    ) -> bool:
        if self._is_uuid(source_type, target_type):
            return bool("uuid" in supported_types)
        return bool(source_type in supported_types and target_type in supported_types)

    def build_config_column_aggregates(
        self, agg_type, arg_value, exclude_cols, supported_types, cast_to_bigint=False
    ):
        """Return list of aggregate objects of given agg_type."""

        def require_pre_agg_calc_field(
            column_type: str,
            target_column_type: str,
            agg_type: str,
            cast_to_bigint: bool,
        ) -> bool:
            if all(
                _ in ["string", "!string", "json", "!json"]
                for _ in [column_type, target_column_type]
            ):
                # These data types are aggregated using their lengths.
                return True
            elif self._is_uuid(column_type, target_column_type):
                return True
            elif column_type in ["binary", "!binary"]:
                if agg_type == "count":
                    # Oracle BLOB is invalid for use with SQL COUNT function.
                    # The expression below returns True if client is Oracle which
                    # has the effect of triggering use of byte_length transformation.
                    return bool(
                        self.source_client.name == "oracle"
                        or self.target_client.name == "oracle"
                    )
                else:
                    # Convert to length for any min/max/sum on binary columns.
                    return True
            elif cast_to_bigint and column_type in ["int32", "!int32"]:
                return True
            elif column_type in [
                "timestamp",
                "!timestamp",
                "date",
                "!date",
            ] and agg_type in (
                "sum",
                "avg",
                "bit_xor",
            ):
                # For timestamps: do not convert to epoch seconds for min/max
                return True
            return False

        aggregate_configs = []
        source_table = self.get_source_ibis_calculated_table()
        target_table = self.get_target_ibis_calculated_table()

        casefold_source_columns = {x.casefold(): str(x) for x in source_table.columns}
        casefold_target_columns = {x.casefold(): str(x) for x in target_table.columns}

        if arg_value:
            arg_value = [x.casefold() for x in arg_value]
            if exclude_cols:
                included_cols = [
                    column
                    for column in casefold_source_columns
                    if column not in arg_value
                ]
                arg_value = included_cols

            if supported_types:
                # Rebind supported_types to a new local list (so the caller's list is
                # not mutated) while adding more values.
                supported_types = supported_types + [
                    "string",
                    "!string",
                    "timestamp",
                    "!timestamp",
                    "date",
                    "!date",
                    "binary",
                    "!binary",
                ]
        else:
            if exclude_cols:
                raise ValueError(
                    "Exclude columns flag cannot be present with '*' column aggregation"
                )

        allowlist_columns = arg_value or casefold_source_columns
        for column_position, column in enumerate(casefold_source_columns):
            # Get column type and remove precision/scale attributes
            source_column_ibis_type = source_table[
                casefold_source_columns[column]
            ].type()
            column_type = str(source_column_ibis_type).split("(")[0]
            target_column_ibis_type = target_table[
                casefold_target_columns[column]
            ].type()
            target_column_type = str(target_column_ibis_type).split("(")[0]

            if column not in allowlist_columns:
                continue
            elif column not in casefold_target_columns:
                logging.warning(
                    f"Skipping {agg_type} on {column} as column is not present in target table"
                )
                continue
            elif supported_types and not self._type_is_supported_for_agg_validation(
                column_type, target_column_type, supported_types
            ):
                if self.verbose:
                    logging.info(
                        f"Skipping {agg_type} on {column} due to data type: {column_type}"
                    )
                continue

            if require_pre_agg_calc_field(
                column_type, target_column_type, agg_type, cast_to_bigint
            ):
                aggregate_config = self.append_pre_agg_calc_field(
                    casefold_source_columns[column],
                    casefold_target_columns[column],
                    agg_type,
                    column_type,
                    target_column_type,
                    column_position,
                )
            else:
                aggregate_config = {
                    consts.CONFIG_SOURCE_COLUMN: casefold_source_columns[column],
                    consts.CONFIG_TARGET_COLUMN: casefold_target_columns[column],
                    consts.CONFIG_FIELD_ALIAS: self._prefix_calc_col_name(
                        column, f"{agg_type}", column_position
                    ),
                    consts.CONFIG_TYPE: agg_type,
                }
                if self._decimal_column_too_big_for_pandas(
                    source_column_ibis_type,
                    target_column_ibis_type,
                    margin=(2 if agg_type == consts.CONFIG_TYPE_SUM else 0),
                ):
                    aggregate_config[consts.CONFIG_CAST] = "string"

            aggregate_configs.append(aggregate_config)

        return aggregate_configs

    def build_config_calculated_fields(
        self,
        source_reference: list,
        target_reference: list,
        calc_type: str,
        alias: str,
        depth: int,
        custom_params: Optional[dict] = None,
    ) -> dict:
        """Returns list of calculated fields"""
        calculated_config = {
            consts.CONFIG_CALCULATED_SOURCE_COLUMNS: source_reference,
            consts.CONFIG_CALCULATED_TARGET_COLUMNS: target_reference,
            consts.CONFIG_FIELD_ALIAS: alias,
            consts.CONFIG_TYPE: calc_type,
            consts.CONFIG_DEPTH: depth,
        }

        if calc_type == consts.CONFIG_CUSTOM and custom_params:
            calculated_config.update(custom_params)
        elif calc_type == consts.CONFIG_CAST and custom_params:
            calculated_config[consts.CONFIG_DEFAULT_CAST] = custom_params

        return calculated_config

    def _get_comparison_max_col_length(self) -> int:
        if not self._comparison_max_col_length:
            self._comparison_max_col_length = min(
                [
                    clients.get_max_column_length(self.source_client),
                    clients.get_max_column_length(self.target_client),
                ]
            )
        return self._comparison_max_col_length

    def _strftime_format(
        self, column_type: Union[dt.Date, dt.Timestamp], client
    ) -> str:
        if column_type.is_timestamp():
            return "%Y-%m-%d %H:%M:%S"
        if clients.is_oracle_client(client):
            # Oracle DATE is a DateTime
            return "%Y-%m-%d %H:%M:%S"
        return "%Y-%m-%d"

    def _apply_base_cast_overrides(
        self,
        source_column: str,
        target_column: str,
        col_config: dict,
        source_table: "ibis.expr.types.Table",
        target_table: "ibis.expr.types.Table",
    ) -> dict:
        """Mutates col_config to contain any overrides. Also returns col_config for convenience."""
        if col_config["calc_type"] != consts.CALC_FIELD_CAST:
            return col_config

        source_table_schema = {k: v for k, v in source_table.schema().items()}
        target_table_schema = {k: v for k, v in target_table.schema().items()}

        if isinstance(
            source_table_schema[source_column], (dt.Date, dt.Timestamp)
        ) and isinstance(target_table_schema[target_column], (dt.Date, dt.Timestamp)):
            # Use strftime rather than cast for temporal comparisons.
            # Pick the most permissive format across the two engines.
            # For example Date -> Timestamp should format both source and target as Date.
            fmt = min(
                [
                    self._strftime_format(
                        source_table_schema[source_column], self.source_client
                    ),
                    self._strftime_format(
                        target_table_schema[target_column], self.target_client
                    ),
                ],
                key=len,
            )
            col_config["calc_type"] = consts.CONFIG_CUSTOM
            custom_params = {
                "calc_params": {
                    consts.CONFIG_CUSTOM_IBIS_EXPR: "ibis.expr.types.TemporalValue.strftime",
                    consts.CONFIG_CUSTOM_PARAMS: [
                        {consts.CONFIG_CUSTOM_PARAM_FORMAT_STR: fmt}
                    ],
                }
            }
            col_config.update(custom_params)
        elif isinstance(source_table_schema[source_column], dt.Boolean) or isinstance(
            target_table_schema[target_column], dt.Boolean
        ):
            custom_params = {"calc_params": consts.CONFIG_CAST_BOOL_STRING}
            col_config.update(custom_params)
        elif self._is_uuid(
            source_table_schema[source_column], target_table_schema[target_column]
        ):
            custom_params = {"calc_params": consts.CONFIG_CAST_UUID_STRING}
            col_config.update(custom_params)

        return col_config

    def _get_order_of_operations(self, calc_type: str) -> List[str]:
        """Return order of operations for row validation."""
        order_of_operations = [
            consts.CALC_FIELD_CAST,
            consts.CALC_FIELD_IFNULL,
            consts.CALC_FIELD_RSTRIP,
        ]
        if self.case_insensitive_match():
            order_of_operations.append(consts.CALC_FIELD_UPPER)
        if calc_type == consts.CALC_FIELD_HASH:
            order_of_operations.extend(
                [consts.CALC_FIELD_CONCAT, consts.CALC_FIELD_HASH]
            )
        elif calc_type == consts.CALC_FIELD_CONCAT:
            order_of_operations.append(consts.CALC_FIELD_CONCAT)

        return order_of_operations

    def _filter_columns_by_column_list(
        self, casefold_columns: list, col_list: list, exclude_cols: bool
    ) -> list:
        if col_list:
            filter_list = [_.casefold() for _ in col_list]
            if exclude_cols:
                # Exclude columns based on col_list if provided
                casefold_columns = {
                    k: v for (k, v) in casefold_columns.items() if k not in filter_list
                }
            else:
                # Include columns based on col_list if provided
                casefold_columns = {
                    k: v for (k, v) in casefold_columns.items() if k in filter_list
                }
        elif exclude_cols:
            raise ValueError(
                "Exclude columns flag cannot be present with column list '*'"
            )
        return casefold_columns

    def build_dependent_aliases(
        self, calc_type: str, col_list=None, exclude_cols=False
    ) -> List[Dict]:
        """This is a utility function for determining the required depth of all fields"""
        source_table = self.get_source_ibis_calculated_table()
        target_table = self.get_target_ibis_calculated_table()

        casefold_source_columns = {x.casefold(): str(x) for x in source_table.columns}
        casefold_target_columns = {x.casefold(): str(x) for x in target_table.columns}

        casefold_source_columns = self._filter_columns_by_column_list(
            casefold_source_columns, col_list, exclude_cols
        )
        casefold_target_columns = self._filter_columns_by_column_list(
            casefold_target_columns, col_list, exclude_cols
        )

        column_aliases = {}
        col_names = []
        for i, calc in enumerate(self._get_order_of_operations(calc_type)):
            if i == 0:
                previous_level = [x for x in casefold_source_columns.keys()]
            else:
                previous_level = [k for k, v in column_aliases.items() if v == i - 1]
            if calc in [consts.CALC_FIELD_CONCAT, consts.CALC_FIELD_HASH]:
                col = {}
                col["source_reference"] = previous_level
                col["target_reference"] = previous_level
                col["name"] = f"{calc}__all"
                col["calc_type"] = calc
                col["depth"] = i
                name = col["name"]
                # Capture all aliases from the previous level under a single column, e.g. concat__all.
                column_aliases[name] = i
                col_names.append(col)
            else:
                # Apply this calculation to each column from the previous level.
                for j, column in enumerate(previous_level):
                    col = {}
                    col["source_reference"] = [column]
                    col["target_reference"] = [column]
                    col["name"] = self._prefix_calc_col_name(column, calc, j)
                    col["calc_type"] = calc
                    col["depth"] = i

                    if i == 0:
                        # If depth 0, get raw column name with correct casing
                        source_column = casefold_source_columns[column]
                        target_column = casefold_target_columns[column]
                        col["source_reference"] = [source_column]
                        col["target_reference"] = [target_column]
                        # If we are casting the base column (i == 0) then apply any
                        # datatype specific overrides.
                        col = self._apply_base_cast_overrides(
                            source_column,
                            target_column,
                            col,
                            source_table,
                            target_table,
                        )

                    name = col["name"]
                    column_aliases[name] = i
                    col_names.append(col)
        return col_names

    def build_comp_fields(self, col_list: list, exclude_cols: bool = False) -> dict:
        """This is a utility function processing comp-fields values like we do for hash/concat."""
        source_table = self.get_source_ibis_calculated_table()
        casefold_source_columns = {_.casefold(): str(_) for _ in source_table.columns}

        casefold_source_columns = self._filter_columns_by_column_list(
            casefold_source_columns, col_list, exclude_cols
        )

        return casefold_source_columns

    def auto_list_primary_keys(self) -> list:
        """Returns a list of primary key columns based on the source/target table.

        If neither source nor target systems have a primary key defined then [] is returned.
        """
        assert (
            self.validation_type != consts.CUSTOM_QUERY
        ), "Custom query validations should not be able to reach this method"
        primary_keys = self.source_client.list_primary_key_columns(
            self.source_schema, self.source_table
        )
        if not primary_keys:
            primary_keys = self.target_client.list_primary_key_columns(
                self.target_schema, self.target_table
            )
        return primary_keys or []
