# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import warnings
from concurrent.futures import ThreadPoolExecutor
import ibis.backends.pandas
import pandas
import uuid

from data_validation import combiner, consts, metadata, util
from data_validation.config_manager import ConfigManager
from data_validation.query_builder.random_row_builder import RandomRowBuilder
from data_validation.schema_validation import SchemaValidation
from data_validation.validation_builder import ValidationBuilder

""" The DataValidation class is where the code becomes source/target aware

    The class builds specific source and target clients and is likely where someone would go to
    customize their validation process.

    data_validator = DataValidation(config, validation_builder=None, schema_validator=None, result_handler=None, verbose=False)
"""


class DataValidation(object):
    def __init__(
        self,
        config,
        validation_builder=None,
        schema_validator=None,
        result_handler=None,
        verbose=False,
        source_client: ibis.backends.base.BaseBackend = None,
        target_client: ibis.backends.base.BaseBackend = None,
    ):
        """Set up a DataValidation run from a validation config.

        Args:
            config (dict): The validation config used for the comparison.
            validation_builder (ValidationBuilder): Optional builder; a
                ValidationBuilder over the config manager is created when
                none is supplied.
            schema_validator (SchemaValidation): Optional schema validator; a
                SchemaValidation is created when none is supplied.
            result_handler (ResultHandler): Optional result handler; falls
                back to ``config_manager.get_result_handler()``.
            verbose (bool): Verbosity flag, forwarded to the ConfigManager,
                the SchemaValidation and the report generation.
            source_client: Optional client to avoid unnecessary connections.
            target_client: Optional client to avoid unnecessary connections.
        """
        self.verbose = verbose
        self.config = config

        # Connections are only treated as caller-owned when BOTH clients were
        # passed in; otherwise __exit__ closes them.
        both_clients_supplied = bool(source_client and target_client)
        self._fresh_connections = not both_clients_supplied

        # Data Client Management
        self.config_manager = ConfigManager(
            config,
            source_client=source_client,
            target_client=target_client,
            verbose=verbose,
        )

        self.run_metadata = run_metadata = metadata.RunMetadata()
        run_metadata.labels = self.config_manager.labels
        # Fall back to a generated uuid when the config carries no run_id.
        run_metadata.run_id = self.config_manager.run_id or str(uuid.uuid4())

        # Build the default collaborators for anything not supplied.
        if not validation_builder:
            validation_builder = ValidationBuilder(self.config_manager)
        self.validation_builder = validation_builder

        if not schema_validator:
            schema_validator = SchemaValidation(
                self.config_manager,
                run_metadata=self.run_metadata,
                verbose=self.verbose,
            )
        self.schema_validator = schema_validator

        if not result_handler:
            result_handler = self.config_manager.get_result_handler()
        self.result_handler = result_handler

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if self._fresh_connections and hasattr(self, "config_manager"):
            self.config_manager.close_client_connections()

    # TODO(dhercher) we planned on shifting this to use an Execution Handler.
    # Leaving it to swast to design how this should look.
    def execute(self):
        """Run the configured validation and pass the result to the result handler.

        Returns:
            Whatever ``self.result_handler.execute`` returns for the result
            produced by the selected validation type.
        """
        manager = self.config_manager

        # The random row filter has to be applied before validations run.
        if manager.use_random_rows():
            util.timed_call("Random row filter", self._add_random_row_filter)

        validation_type = manager.validation_type
        if validation_type == consts.ROW_VALIDATION:
            # Row validations recurse through the grouped fields.
            builder = self.validation_builder
            remaining_groups = builder.pop_grouped_fields()
            outcome = self.execute_recursive_validation(builder, remaining_groups)
        elif validation_type == consts.SCHEMA_VALIDATION:
            # Perform only schema validation
            outcome = util.timed_call(
                "Schema validation", self.schema_validator.execute
            )
        else:
            outcome = self._execute_validation(self.validation_builder)

        # The result handler decides what happens with the results.
        return self.result_handler.execute(outcome)

    def _add_random_row_filter(self):
        """Restrict the validation to a random sample of primary key values.

        A random batch of values of the first primary key is read from the
        source client and registered on the validation builder as an ISIN
        filter, using the same values for the source and the target column.
        No filter is added when the sampling query returns no rows.

        Raises:
            ValueError: If the config declares no primary keys.
        """
        manager = self.config_manager
        if not manager.primary_keys:
            raise ValueError("Primary Keys are required for Random Row Filters")

        # Only the first primary key is sampled (multi-pk filter not supported).
        source_pk_column = manager.primary_keys[0][consts.CONFIG_SOURCE_COLUMN]
        target_pk_column = manager.primary_keys[0][consts.CONFIG_TARGET_COLUMN]

        row_builder = RandomRowBuilder(
            [source_pk_column],
            manager.random_row_batch_size(),
        )

        is_custom_row_query = (
            manager.validation_type == consts.CUSTOM_QUERY
            and manager.custom_query_type == consts.ROW_VALIDATION.lower()
        )
        if is_custom_row_query:
            query = row_builder.compile_custom_query(
                manager.source_client,
                manager.source_query,
            )
        else:
            query = row_builder.compile(
                manager.source_client,
                manager.source_schema,
                manager.source_table,
                self.validation_builder.source_builder,
            )

        # A BINARY primary key is cast to string for the sampling query; the
        # sampled values are cast back to binary below.
        binary_conversion_required = query[source_pk_column].type().is_binary()
        if binary_conversion_required:
            as_string = query[source_pk_column].cast("string")
            query = query.mutate(**{source_pk_column: as_string})

        if manager.trim_string_pks():
            trimmed = query[source_pk_column].rstrip()
            query = query.mutate(**{source_pk_column: trimmed})

        sampled_rows = manager.source_client.execute(query)
        if len(sampled_rows) == 0:
            return

        sampled_values = list(sampled_rows[source_pk_column])
        if binary_conversion_required:
            # The IN list holds string values at this point; each one needs to
            # be cast back to binary.
            sampled_values = [
                ibis.literal(value).cast("binary") for value in sampled_values
            ]

        self.validation_builder.add_filter(
            {
                consts.CONFIG_TYPE: consts.FILTER_TYPE_ISIN,
                consts.CONFIG_FILTER_SOURCE_COLUMN: source_pk_column,
                consts.CONFIG_FILTER_SOURCE_VALUE: sampled_values,
                consts.CONFIG_FILTER_TARGET_COLUMN: target_pk_column,
                consts.CONFIG_FILTER_TARGET_VALUE: sampled_values,
            }
        )

    def query_too_large(self, rows_df, grouped_fields):
        """Return bool to dictate if another level of recursion
        would create a too large result set.

        Rules to define too large are:
            - If any grouped fields remain, return False.
                (assumes user added logical sized groups)
            - Else, if next group size is larger
                than the limit, return True.
            - Finally return False if no covered case occured.
        """
        if len(grouped_fields) > 1:
            return False

        try:
            count_df = rows_df[
                rows_df[consts.AGGREGATION_TYPE] == consts.CONFIG_TYPE_COUNT
            ]
            for row in count_df.to_dict(orient="row"):
                recursive_query_size = max(
                    float(row[consts.SOURCE_AGG_VALUE]),
                    float(row[consts.TARGET_AGG_VALUE]),
                )
                if recursive_query_size > self.config_manager.max_recursive_query_size:
                    logging.warning("Query result is too large for recursion: %s", row)
                    return True
        except Exception:
            logging.warning("Recursive values could not be cast to float.")
            return False

        return False

    def execute_recursive_validation(self, validation_builder, grouped_fields):
        """Recursive execution for Row validations.

        This method executes aggregate queries, such as sum-of-hashes, on the
        source and target tables. Where they differ, add to the GROUP BY
        clause recursively until the individual row differences can be
        identified.
        """
        past_results = []
        if len(grouped_fields) > 0:
            validation_builder.add_query_group(grouped_fields[0])
            result_df = self._execute_validation(validation_builder)

            for grouped_key in result_df[consts.GROUP_BY_COLUMNS].unique():
                # Validations are viewed separtely, but queried together.
                # We must treat them as a single item which failed or succeeded.
                group_suceeded = True
                grouped_key_df = result_df[
                    result_df[consts.GROUP_BY_COLUMNS] == grouped_key
                ]

                if self.query_too_large(grouped_key_df, grouped_fields):
                    past_results.append(grouped_key_df)
                    continue

                for row in grouped_key_df.to_dict(orient="row"):
                    if row[consts.SOURCE_AGG_VALUE] == row[consts.TARGET_AGG_VALUE]:
                        continue
                    else:
                        group_suceeded = False
                        break

                if group_suceeded:
                    past_results.append(grouped_key_df)
                else:
                    recursive_validation_builder = validation_builder.clone()
                    self._add_recursive_validation_filter(
                        recursive_validation_builder, row
                    )
                    past_results.append(
                        self.execute_recursive_validation(
                            recursive_validation_builder, grouped_fields[1:]
                        )
                    )
        elif self.config_manager.primary_keys and len(grouped_fields) == 0:
            past_results.append(self._execute_validation(validation_builder))

        # elif self.config_manager.primary_keys:
        #     validation_builder.add_config_query_groups(self.config_manager.primary_keys)
        #     validation_builder.add_config_query_groups(grouped_fields)

        else:
            warnings.warn(
                "WARNING: No Primary Keys Suppplied in Row Validation", UserWarning
            )
            return None

        return pandas.concat(past_results)

    def _add_recursive_validation_filter(self, validation_builder, row):
        """Narrow ``validation_builder`` to the group described by ``row``.

        The GROUP_BY_COLUMNS entry of ``row`` is parsed as JSON (alias ->
        value). For every alias an EQUALS filter with that value is added on
        the matching source and target columns. The builder is modified in
        place; nothing is returned.
        """
        group_values = json.loads(row[consts.GROUP_BY_COLUMNS])
        for alias, value in group_values.items():
            source_column = validation_builder.get_grouped_alias_source_column(alias)
            target_column = validation_builder.get_grouped_alias_target_column(alias)
            validation_builder.add_filter(
                {
                    consts.CONFIG_TYPE: consts.FILTER_TYPE_EQUALS,
                    consts.CONFIG_FILTER_SOURCE_COLUMN: source_column,
                    consts.CONFIG_FILTER_SOURCE_VALUE: value,
                    consts.CONFIG_FILTER_TARGET_COLUMN: target_column,
                    consts.CONFIG_FILTER_TARGET_VALUE: value,
                }
            )

    def _execute_validation(self, validation_builder):
        """Run the source and target queries of a builder and build the report.

        The two queries are submitted to a thread pool so they run
        concurrently, then both results are passed to
        ``combiner.generate_report``. Row validations and row-type custom
        queries join on the builder's primary keys and compare values; every
        other validation joins on the builder's group aliases.

        Args:
            validation_builder (ValidationBuilder): Builder providing the
                metadata, the queries and the join fields.

        Returns:
            The report produced by ``combiner.generate_report``.
        """
        self.run_metadata.validations = validation_builder.get_metadata()

        source_query = validation_builder.get_source_query()
        target_query = validation_builder.get_target_query()

        manager = self.config_manager
        is_row_comparison = manager.validation_type == consts.ROW_VALIDATION or (
            manager.validation_type == consts.CUSTOM_QUERY
            and manager.custom_query_type == "row"
        )

        if is_row_comparison:
            join_on_fields = set(validation_builder.get_primary_keys())
        else:
            join_on_fields = set(validation_builder.get_group_aliases())

        # Row-style validations compare the source and target values directly.
        is_value_comparison = is_row_comparison

        with ThreadPoolExecutor() as executor:
            # Both network calls are submitted before either result is awaited.
            source_future = executor.submit(
                util.timed_call,
                "Source query",
                manager.source_client.execute,
                source_query,
            )
            target_future = executor.submit(
                util.timed_call,
                "Target query",
                manager.target_client.execute,
                target_query,
            )
            source_df = source_future.result()
            target_df = target_future.result()

        try:
            result_df = util.timed_call(
                "Generate report",
                combiner.generate_report,
                self.run_metadata,
                source_df,
                target_df,
                join_on_fields=join_on_fields,
                is_value_comparison=is_value_comparison,
                verbose=self.verbose,
            )
        except Exception as err:
            # In verbose mode dump both inputs to help diagnose the failure.
            if self.verbose:
                for banner, frame in (
                    ("-- ** Logging Source DF ** --", source_df),
                    ("-- ** Logging Target DF ** --", target_df),
                ):
                    logging.error(banner)
                    logging.error(frame.dtypes)
                    logging.error(frame)
            raise err

        return result_df

    def combine_data(self, source_df, target_df, join_on_fields):
        """Outer-join the source and target frames into one frame.

        With ``join_on_fields`` the frames are merged on those columns;
        without them they are joined on their index. In both cases the
        configured input/output suffixes are passed for the left (source) and
        right (target) side.

        TODO: Return List of Dictionaries
        """
        if not join_on_fields:
            return source_df.join(
                target_df,
                how="outer",
                lsuffix=consts.INPUT_SUFFIX,
                rsuffix=consts.OUTPUT_SUFFIX,
            )

        return source_df.merge(
            target_df,
            how="outer",
            on=join_on_fields,
            suffixes=(consts.INPUT_SUFFIX, consts.OUTPUT_SUFFIX),
        )
