# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import itertools
import logging
import pandas
import re

from data_validation import metadata, consts, clients, exceptions


# Match a decimal data type with precision and/or scale, e.g. "decimal(10,2)"
# or "decimal(10, 2)", case-insensitively (re.I). Capture groups:
#   1. the type name including an optional leading "!" (e.g. "!decimal"),
#   2. the precision,
#   3. the scale, or None when no scale is given.
# Precision and scale may contain hyphens so that allow-list value ranges such
# as "decimal(1-3,0)" are accepted.
DECIMAL_PRECISION_SCALE_PATTERN = re.compile(
    r"([!]?decimal)\(([0-9\-]+)(?:,[ ]*([0-9\-]+))?\)", re.I
)
# Extract lower/upper bounds from a range of the format "0-2" or "12-18".
# Each bound is one or two digits (groups 1 and 2). The pattern has no end
# anchor, so the matching method chosen by the caller decides whether
# trailing text is tolerated.
DECIMAL_PRECISION_SCALE_RANGE_PATTERN = re.compile(
    r"([0-9]{2}|[0-9])(?:\-)([0-9]{2}|[0-9])"
)


class SchemaValidation(object):
    def __init__(self, config_manager, run_metadata=None, verbose=False):
        """Create a client that compares a source schema with a target schema.

        Args:
            config_manager (ConfigManager): Supplies the clients, schema/table
                names, exclusion columns and allow list for the validation.
            run_metadata (RunMetadata): Metadata stamped onto every result row.
                A new metadata.RunMetadata() is created when none is given.
            verbose (bool): Stored on the instance as ``self.verbose``.
        """
        self.verbose = verbose
        self.config_manager = config_manager
        self.run_metadata = run_metadata or metadata.RunMetadata()

    def execute(self):
        """Run the schema comparison and return it as a results DataFrame.

        Both table schemas are fetched through ``clients.get_ibis_table_schema``
        and compared with ``schema_validation_matching``. Each comparison row
        is then decorated with run metadata, table names and the placeholder
        columns required by the shared results layout.

        Returns:
            pandas.DataFrame: One row per compared column.
        """
        cfg = self.config_manager
        source_schema = clients.get_ibis_table_schema(
            cfg.source_client, cfg.source_schema, cfg.source_table
        )
        target_schema = clients.get_ibis_table_schema(
            cfg.target_client, cfg.target_schema, cfg.target_table
        )

        comparison_rows = schema_validation_matching(
            dict(source_schema.items()),
            dict(target_schema.items()),
            cfg.exclusion_columns,
            cfg.allow_list,
        )
        result_columns = [
            consts.SOURCE_COLUMN_NAME,
            consts.TARGET_COLUMN_NAME,
            consts.SOURCE_AGG_VALUE,
            consts.TARGET_AGG_VALUE,
            consts.VALIDATION_STATUS,
        ]
        df = pandas.DataFrame(comparison_rows, columns=result_columns)

        # The run is considered finished once the comparison is done.
        run = self.run_metadata
        run.end_time = datetime.datetime.now(datetime.timezone.utc)

        # Metadata columns placed in front of the comparison columns (loc 0-7).
        leading_columns = (
            (consts.CONFIG_RUN_ID, run.run_id),
            (consts.VALIDATION_NAME, "Schema"),
            (consts.VALIDATION_TYPE, "Schema"),
            (consts.CONFIG_LABELS, [run.labels] * len(df.index)),
            (consts.CONFIG_START_TIME, run.start_time),
            (consts.CONFIG_END_TIME, run.end_time),
            (consts.SOURCE_TABLE_NAME, cfg.full_source_table),
            (consts.TARGET_TABLE_NAME, cfg.full_target_table),
        )
        for position, (column_name, column_value) in enumerate(leading_columns):
            df.insert(loc=position, column=column_name, value=column_value)

        df.insert(loc=10, column=consts.AGGREGATION_TYPE, value="Schema")

        # Empty columns that exist only to keep the results schema aligned
        # with the other validation types.
        placeholder_columns = (
            consts.CONFIG_PRIMARY_KEYS,
            consts.NUM_RANDOM_ROWS,
            consts.GROUP_BY_COLUMNS,
            consts.VALIDATION_DIFFERENCE,
            consts.VALIDATION_PCT_THRESHOLD,
        )
        for position, column_name in enumerate(placeholder_columns, start=14):
            df.insert(loc=position, column=column_name, value=None)

        return df


def schema_validation_matching(
    source_fields, target_fields, exclusion_fields, allow_list
):
    """Compare schemas between two dictionary objects.

    Column names are compared case-insensitively: both schemas are re-keyed by
    ``str.casefold()`` of the column name, and the result rows report those
    casefolded names.

    Args:
        source_fields (dict): Source column name -> data type.
        target_fields (dict): Target column name -> data type.
        exclusion_fields: Iterable of column names to leave out of the
            comparison, or None. Matched case-insensitively, like the columns.
        allow_list (str): "source_type:target_type" pairs that count as a
            match even though the types differ (parsed by parse_allow_list).

    Returns:
        list: One row per column, each
        [source_column, target_column, source_type, target_type, status].
        "N/A" fills the side on which a column is missing. A type mismatch
        where the target has higher precision is a success with a logged
        warning.
    """
    results = []
    # Apply the casefold() function to lowercase the keys of source and target
    source_fields_casefold = {
        field_name.casefold(): field_type
        for field_name, field_type in source_fields.items()
    }
    target_fields_casefold = {
        field_name.casefold(): field_type
        for field_name, field_type in target_fields.items()
    }

    if exclusion_fields is not None:
        for field in exclusion_fields:
            # Casefold the exclusion as well: the keys above are casefolded, so
            # an exclusion given as e.g. "MyCol" would otherwise never match.
            excluded_name = field.casefold()
            source_fields_casefold.pop(excluded_name, None)
            target_fields_casefold.pop(excluded_name, None)

    # Allow list map in case of incompatible data types in source and target
    allow_list_map = parse_allow_list(allow_list)
    # Go through each source and check if target exists and matches
    for field_name, source_field_type in source_fields_casefold.items():
        if field_name not in target_fields_casefold:
            # Target field doesn't exist
            results.append(
                [
                    field_name,
                    "N/A",
                    str(source_field_type),
                    "N/A",
                    consts.VALIDATION_STATUS_FAIL,
                ]
            )
            continue

        target_field_type = target_fields_casefold[field_name]
        # Space-free spellings, used for allow-list lookups and precision parsing.
        source_type_str = string_val(source_field_type)
        target_type_str = string_val(target_field_type)

        if source_field_type == target_field_type:
            # Target data type matches
            results.append(
                [
                    field_name,
                    field_name,
                    str(source_field_type),
                    str(target_field_type),
                    consts.VALIDATION_STATUS_SUCCESS,
                ]
            )
        elif (
            source_type_str in allow_list_map
            and target_type_str in allow_list_map[source_type_str]
        ):
            # Data type pair match an allow-list pair.
            results.append(
                [
                    field_name,
                    field_name,
                    source_type_str,
                    str(target_field_type),
                    consts.VALIDATION_STATUS_SUCCESS,
                ]
            )
        else:
            # Target data type mismatch
            higher_precision, _ = parse_n_validate_datatypes(
                source_type_str, target_type_str
            )
            if higher_precision:
                # If the target precision is higher then the validation is acceptable but worth a warning.
                logging.warning(
                    "Source and target data type has precision mismatch: %s - %s",
                    source_type_str,
                    str(target_field_type),
                )
                results.append(
                    [
                        field_name,
                        field_name,
                        source_type_str,
                        str(target_field_type),
                        consts.VALIDATION_STATUS_SUCCESS,
                    ]
                )
            else:
                results.append(
                    [
                        field_name,
                        field_name,
                        str(source_field_type),
                        str(target_field_type),
                        consts.VALIDATION_STATUS_FAIL,
                    ]
                )

    # Source field doesn't exist
    for field_name, target_field_type in target_fields_casefold.items():
        if field_name not in source_fields_casefold:
            results.append(
                [
                    "N/A",
                    field_name,
                    "N/A",
                    str(target_field_type),
                    consts.VALIDATION_STATUS_FAIL,
                ]
            )
    return results


def split_allow_list_str(allow_list_str: str) -> list:
    """Split an allow-list string into [source_type, target_type] pairs.

    Each pair is returned as a two-element list with all spaces removed, e.g.
    "decimal(10, 2):int64, !string:string" gives
    [["decimal(10,2)", "int64"], ["!string", "string"]]. Text that does not
    look like a "type:type" pair is skipped.
    """
    # The pattern is assembled from named pieces for readability instead of
    # being a module-level compiled constant; it is only needed once per command.
    nullable_pattern = r"!?"
    precision_scale_pattern = r"(?:\((?:[0-9 ,\-]+|'UTC')\))?"
    data_type_pattern = f"{nullable_pattern}[a-z0-9 ]+{precision_scale_pattern}"
    pair_pattern = f"{data_type_pattern}:{data_type_pattern}"

    matched_pairs = re.findall(pair_pattern, allow_list_str, re.I)
    data_type_pairs = [text.replace(" ", "").split(":") for text in matched_pairs]

    malformed = [pair for pair in data_type_pairs if len(pair) != 2]
    if malformed:
        raise exceptions.SchemaValidationException(
            f"Invalid data type pairs: {malformed}"
        )
    return data_type_pairs


def expand_precision_range(s: str) -> list:
    """Expand an integer range (e.g. "0-3") to a list (e.g. ["0", "1", "2", "3"]).

    Args:
        s: A precision or scale value from an allow-list data type. It is
            either a single value such as "10" or a range "<lower>-<upper>"
            whose bounds have one or two digits.

    Returns:
        list: The value as a one-element list, or every integer of the
        inclusive range as strings.

    Raises:
        SchemaValidationException: If s looks like a range but is malformed
            (e.g. "1-234" or "5-"), or if the lower bound is not below the
            upper bound.
    """
    # fullmatch, not match: match() only anchors at the start, so "1-234" was
    # previously read as the range 1-23 without any error.
    m_range = DECIMAL_PRECISION_SCALE_RANGE_PATTERN.fullmatch(s)
    if not m_range:
        # A hyphen after any leading sign means a range was intended but could
        # not be parsed. A plain value, including a negative one such as "-2",
        # is passed through unchanged.
        if "-" in s.lstrip("-"):
            raise exceptions.SchemaValidationException(
                f"Invalid allow list data type precision/scale: {s}"
            )
        return [s]
    # Both bounds are matched as ASCII digits, so int() cannot fail here.
    p_lower = int(m_range.group(1))
    p_upper = int(m_range.group(2))
    if p_lower >= p_upper:
        raise exceptions.SchemaValidationException(
            f"Invalid allow list data type precision/scale: Lower value {p_lower} >= upper value {p_upper}"
        )
    return [str(value) for value in range(p_lower, p_upper + 1)]


def expand_precision_or_scale_range(data_type: str) -> list:
    """Expand any precision/scale range in a decimal data type.

    For example "decimal(1-3,0)" becomes:
      ["decimal(1,0)", "decimal(2,0)", "decimal(3,0)"]
    A data type that does not match DECIMAL_PRECISION_SCALE_PATTERN is
    returned unchanged as a one-element list."""

    match = DECIMAL_PRECISION_SCALE_PATTERN.match(data_type.replace(" ", ""))
    if match is None:
        return [data_type]

    groups = match.groups()
    if len(groups) != 3:
        raise exceptions.SchemaValidationException(
            f"Badly formatted data type: {data_type}"
        )

    type_name, precision, scale = groups
    precisions = expand_precision_range(precision)
    if not scale:
        # No scale given: expand the precision only.
        return [f"{type_name}({value})" for value in precisions]

    scales = expand_precision_range(scale)
    return [
        f"{type_name}({prec},{sc})"
        for prec, sc in itertools.product(precisions, scales)
    ]


def parse_allow_list(st: str) -> dict:
    """Convert allow-list data type pairs into a dictionary like {key[value1, value2, etc], }

    Precision/scale ranges on either side of a pair are expanded first. For
    example "decimal(1-2,0):int64" gives
    {"decimal(1,0)": ["int64"], "decimal(2,0)": ["int64"]}.

    Args:
        st: Allow-list string of "source_type:target_type" pairs. An empty
            string or None means no allow list.

    Returns:
        dict: Source data type -> list of target data types it may map to.
    """

    def expand_allow_list_ranges(data_type_pairs: list) -> list:
        """Expand ranges on both sides of each pair into every (source, target) combination."""
        expanded_pairs = []
        for dt1, dt2 in data_type_pairs:
            expanded_pairs.extend(
                itertools.product(
                    expand_precision_or_scale_range(dt1),
                    expand_precision_or_scale_range(dt2),
                )
            )
        return expanded_pairs

    def convert_pairs_to_dict(expanded_pairs: list) -> dict:
        """Take the list data type tuples and convert them into a dictionary keyed on source data type.
        For example:
            [('decimal(2,0)', 'int64'), ('decimal(2,0)', 'int32')]
        becomes:
            {'decimal(2,0)': ['int64', 'int32']}
        """
        return_pairs = {}
        for dt1, dt2 in expanded_pairs:
            return_pairs.setdefault(dt1, []).append(dt2)
        return return_pairs

    if not st:
        # Nothing to parse. This also guards against None, which would
        # otherwise fail inside re.findall in split_allow_list_str.
        return {}

    data_type_pairs = split_allow_list_str(st)
    expanded_pairs = expand_allow_list_ranges(data_type_pairs)
    return convert_pairs_to_dict(expanded_pairs)


# typea data types: int8,int16
def get_typea_numeric_sustr(st):
    """Return the integer formed by all digits in a "type A" data type string.

    For example "int64" -> 64 and "!int8" -> 8. Returns -1 when st contains
    "(" or contains no digits at all.
    """
    if "(" in st:
        return -1
    digits = "".join(ch for ch in st if ch.isdigit())
    return int(digits) if digits else -1


# typeb data types: Decimal(10,2)
def get_typeb_numeric_sustr(st: str) -> tuple:
    """Return (precision, scale) as ints parsed from a decimal data type string.

    Spaces are ignored. A missing scale is reported as 0, and (-1, -1) is
    returned when st does not match DECIMAL_PRECISION_SCALE_PATTERN.
    """
    found = DECIMAL_PRECISION_SCALE_PATTERN.match(st.replace(" ", ""))
    if found is None:
        return -1, -1
    precision, scale = found.group(2), found.group(3)
    return int(precision), (0 if scale is None else int(scale))


def string_val(st):
    """Return str(st) with every space removed, e.g. "decimal(10, 2)" -> "decimal(10,2)"."""
    return "".join(str(st).split(" "))


def validate_typeb_vals(source, target):
    """Compare the (precision, scale) pairs of a source and target data type.

    Returns:
        tuple: (target_is_higher, target_is_lower). The result is
        (False, True) if either source component exceeds the target's,
        (False, False) if both are equal, and (True, False) otherwise.
    """
    if source[0] > target[0] or source[1] > target[1]:
        # At least one of precision/scale is smaller in the target.
        return False, True
    identical = source[0] == target[0] and source[1] == target[1]
    return not identical, False


def strip_null(st):
    """Return st with every "!" character removed."""
    return st.translate(str.maketrans("", "", "!"))


def parse_n_validate_datatypes(source, target) -> tuple:
    """Compare the precision of a source and a target data type string.

    When both strings contain "(", their (precision, scale) pairs are compared
    via validate_typeb_vals. Otherwise the digits in each type name (e.g.
    64 in "int64") are compared.

    Args:
        source: Source table datatype string
        target: Target table datatype string
    Returns:
        bool: target has higher precision value
        bool: target has lower precision value
        Both are False when the types are equal (ignoring "!" on the source),
        when the bit counts are equal, or when either side has no bit count.
    """
    if strip_null(source) == target:
        return False, False
    if "(" in source and "(" in target:
        return validate_typeb_vals(
            get_typeb_numeric_sustr(source), get_typeb_numeric_sustr(target)
        )
    source_num = get_typea_numeric_sustr(source)
    target_num = get_typea_numeric_sustr(target)
    # In case of no bits specified, we will not match for precisions
    if source_num == -1 or target_num == -1 or source_num == target_num:
        return False, False
    if source_num > target_num:
        return False, True
    # The target has more bits than the source (e.g. int32 -> int64). Report
    # it as higher precision, consistent with validate_typeb_vals.
    return True, False
