# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import ibis
from data_validation import consts
from ibis.expr.types import StringScalar
from third_party.ibis.ibis_addon import api, operations


class AggregateField(object):
    def __init__(self, ibis_expr, field_name=None, alias=None, cast=None):
        """Describe a table-level or column-level aggregate built with Ibis.

        Args:
            ibis_expr (callable): The Ibis aggregation callable. compile()
                invokes it with a column when field_name is set, otherwise
                with the whole table.
            field_name (String): The column the aggregate acts on.
                Table level expressions do not have a field name.
            alias (String): The name given to the compiled aggregate
            cast (String): Optional type the compiled aggregate is
                force-cast to
        """
        self.expr, self.field_name = ibis_expr, field_name
        self.alias, self.cast = alias, cast

    @staticmethod
    def count(field_name=None, alias=None, cast=None):
        """Build a count over a column, or over the table if no column is given."""
        if not field_name:
            counter = ibis.expr.types.TableExpr.count
        else:
            counter = ibis.expr.types.ColumnExpr.count
        return AggregateField(counter, field_name=field_name, alias=alias, cast=cast)

    @staticmethod
    def min(field_name=None, alias=None, cast=None):
        """Build a minimum aggregate using ColumnExpr.min."""
        aggregator = ibis.expr.types.ColumnExpr.min
        return AggregateField(aggregator, field_name, alias, cast)

    @staticmethod
    def avg(field_name=None, alias=None, cast=None):
        """Build an average aggregate using NumericColumn.mean."""
        aggregator = ibis.expr.types.NumericColumn.mean
        return AggregateField(aggregator, field_name, alias, cast)

    @staticmethod
    def max(field_name=None, alias=None, cast=None):
        """Build a maximum aggregate using ColumnExpr.max."""
        aggregator = ibis.expr.types.ColumnExpr.max
        return AggregateField(aggregator, field_name, alias, cast)

    @staticmethod
    def sum(field_name=None, alias=None, cast=None):
        """Build a sum aggregate using IntegerColumn.sum."""
        aggregator = ibis.expr.api.IntegerColumn.sum
        return AggregateField(aggregator, field_name, alias, cast)

    @staticmethod
    def bit_xor(field_name=None, alias=None, cast=None):
        """Build a bitwise XOR aggregate using IntegerColumn.bit_xor."""
        aggregator = ibis.expr.api.IntegerColumn.bit_xor
        return AggregateField(aggregator, field_name, alias, cast)

    @staticmethod
    def std(field_name=None, alias=None, cast=None):
        """Build a standard deviation aggregate using NumericColumn.std."""
        aggregator = ibis.expr.types.NumericColumn.std
        return AggregateField(aggregator, field_name, alias, cast)

    def compile(self, ibis_table):
        """Apply the aggregate to a table and return the resulting expression.

        Args:
            ibis_table (IbisTable): The table the aggregate is evaluated on.

        Returns:
            The aggregate expression, force-cast when a cast is set and
            named when an alias is set.
        """
        # Column aggregates receive the column; table aggregates the table.
        target = ibis_table[self.field_name] if self.field_name else ibis_table
        compiled = self.expr(target)

        if self.cast:
            compiled = compiled.force_cast(self.cast)

        return compiled.name(self.alias) if self.alias else compiled


class FilterField(object):
    def __init__(
        self, ibis_expr, left=None, right=None, left_field=None, right_field=None
    ):
        """Describe a single filter that is applied while building a query.

            Each side of the comparison is given either as a literal value
            (left / right) or as a column name (left_field / right_field).

        Args:
            ibis_expr (ColumnExpr): The comparison callable applied to both sides (None represents a custom filter).
            left (Object): A literal value for the left side of the expression
            left_field (String): A column name to use as the left side
            right (Object): A literal value for the right side of the expression
            right_field (String): A column name to use as the right side

        """
        self.expr = ibis_expr
        self.left, self.right = left, right
        self.left_field, self.right_field = left_field, right_field

    @staticmethod
    def greater_than(field_name, value):
        """Filter rows where the column is greater than the value."""
        comparator = ibis.expr.types.ColumnExpr.__gt__
        return FilterField(comparator, left_field=field_name, right=value)

    @staticmethod
    def less_than(field_name, value):
        """Filter rows where the column is less than the value."""
        comparator = ibis.expr.types.ColumnExpr.__lt__
        return FilterField(comparator, left_field=field_name, right=value)

    @staticmethod
    def equal_to(field_name, value):
        """Filter rows where the column equals the value."""
        comparator = ibis.expr.types.ColumnExpr.__eq__
        return FilterField(comparator, left_field=field_name, right=value)

    @staticmethod
    def isin(field_name, values):
        """Filter rows where the column is one of the given values."""
        comparator = ibis.expr.types.ColumnExpr.isin
        return FilterField(comparator, left_field=field_name, right=values)

    @staticmethod
    def custom(expr):
        """Build a FilterField from a custom SQL expression using a supported operator.

        Args:
            expr (Str): The custom SQL expression used to filter a query.
        """
        # A None expression marks the filter as raw SQL; the text rides in `left`.
        return FilterField(None, left=expr)

    @staticmethod
    def or_(field_list: list):
        """Combine several FilterField instances with ibis.or_."""
        return FilterField(ibis.or_, left=field_list)

    def compile(self, ibis_table):
        """Turn the filter into an expression bound to the given table.

        Note that resolving left_field / right_field overwrites self.left /
        self.right with the column taken from ibis_table.

        Args:
            ibis_table (IbisTable): The table the filter is applied to.
        """
        operator = self.expr
        if operator is None:
            # Custom filter: `left` holds the raw SQL text.
            return operations.compile_raw_sql(ibis_table, self.left)

        if self.left_field:
            self.left = ibis_table[self.left_field]

        if self.right_field:
            self.right = ibis_table[self.right_field]

        if operator == ibis.or_:
            # `left` is the list of sub-filters; compile each, then combine.
            compiled_parts = [part.compile(ibis_table) for part in self.left]
            return operator(*compiled_parts)

        return operator(self.left, self.right)


class ComparisonField(object):
    def __init__(
        self, field_name: str, alias: str = None, cast: str = None, trim: bool = None
    ):
        """Describe a column that is compared between tables in a query.

        Args:
            field_name (String): The column to act on in the table
            alias (String): The output name; defaults to field_name
            cast (String): A type to force-cast the column to, if required
            trim (Boolean): When set and no cast is given, string columns
                have trailing whitespace stripped
        """
        self.field_name, self.alias = field_name, alias
        self.cast, self.trim = cast, trim

    def compile(self, ibis_table):
        """Return the named (and optionally cast or trimmed) column expression.

        Args:
            ibis_table (IbisTable): The table holding the column.
        """
        # Fields are supplied on compile or on build
        column = ibis_table[self.field_name]

        # A cast takes precedence: trimming only happens when no cast is set.
        if self.cast:
            column = column.force_cast(self.cast)
        elif self.trim and column.type().is_string():
            column = column.rstrip()

        return column.name(self.alias or self.field_name)


class GroupedField(object):
    def __init__(self, field_name, alias=None, cast=None):
        """Describe a column that a query is grouped by.

        Args:
            field_name (String): The column to act on in the table
            alias (String): The output name for the group; defaults to
                field_name
            cast (String): A type to cast the column to, if required
        """
        self.field_name, self.alias, self.cast = field_name, alias, cast

    def compile(self, ibis_table):
        """Return the named group-by expression for the given table.

        An explicit cast wins; otherwise Timestamp columns are cast to
        "date"; any other column is left as-is and a warning is logged.

        Args:
            ibis_table (IbisTable): The table holding the column.
        """
        # Fields are supplied on compile or on build
        column = ibis_table[self.field_name]

        # TODO: generate cast for known types not specified
        target_type = self.cast
        if not target_type and isinstance(
            column.type(), ibis.expr.datatypes.Timestamp
        ):
            target_type = "date"

        if target_type:
            column = column.cast(target_type)
        else:
            # TODO: need to build Truncation Int support
            # TODO: should be using a logger
            logging.warning("Unknown cast types can cause memory errors")

        # The Casts require we also supply a name.
        return column.name(self.alias or self.field_name)


class ColumnReference(object):
    def __init__(self, column_name):
        """Hold the name of a column so it can be looked up on a table later.

        Args:
            column_name (String): The column name used in a complex expr
        """
        self.column_name = column_name

    def compile(self, ibis_table):
        """Look the stored column name up on the table and return the result.

        Args:
            ibis_table (IbisTable): The table obj reference
        """
        referenced_column = ibis_table[self.column_name]
        return referenced_column


class CalculatedField(object):
    def __init__(self, ibis_expr, config, fields, cast=None, **kwargs):
        """A representation of a calculated field used to build a query.

        Args:
            ibis_expr (callable): The Ibis function applied to the compiled
                fields when compile() is called
            config (dict): Configuration object explaining calc field details
            fields (list): List of fields to transform into a single column
            cast (String): Optional type every referenced column is cast to
                before ibis_expr is applied
            **kwargs: Extra keyword arguments forwarded to ibis_expr
        """
        self.expr = ibis_expr
        self.config = config
        self.fields = fields
        # NOTE: on instances this attribute shadows the static ``cast``
        # factory defined below; use ``CalculatedField.cast(...)`` to reach
        # the factory.
        self.cast = cast
        self.kwargs = kwargs

    def __repr__(self):
        return (
            f"CalculatedField(fields={self.fields}, expr={self.expr}, cast={self.cast})"
        )

    @staticmethod
    def concat(config, fields):
        """Join the fields into one string column using StringValue.join.

        When config has no "default_concat_separator" (or it is None), an
        empty-string literal is stored into config under that key, so the
        caller's dict is modified. Columns are cast to "string" first.
        """
        if config.get("default_concat_separator") is None:
            config["default_concat_separator"] = ibis.literal("")
        # join() takes the separator first, then the list of values.
        fields = [config["default_concat_separator"], fields]
        cast = "string"
        return CalculatedField(
            ibis.expr.types.StringValue.join,
            config,
            fields,
            cast=cast,
        )

    @staticmethod
    def hash(config, fields):
        """Hash the field.

        Uses StringValue.hashbytes with how="sha256" when config has no
        "default_hash_function" (or it is None), otherwise Value.hash with
        how="farm_fingerprint".
        """
        # NOTE(review): only the presence of "default_hash_function" is
        # checked, its value is never inspected - confirm this is intended.
        if config.get("default_hash_function") is None:
            how = "sha256"
            return CalculatedField(
                ibis.expr.types.StringValue.hashbytes,
                config,
                fields,
                how=how,
            )
        else:
            how = "farm_fingerprint"
            return CalculatedField(
                ibis.expr.types.Value.hash,
                config,
                fields,
                how=how,
            )

    @staticmethod
    def to_char(config, fields):
        """Format the field with NumericValue.to_char.

        The format comes from config["default_to_char_fmt"] and falls back
        to "FM90.099".
        """
        fmt = ibis.literal(config.get("default_to_char_fmt", "FM90.099"))
        return CalculatedField(
            ibis.expr.api.NumericValue.to_char, config, fields, fmt=fmt
        )

    @staticmethod
    def ifnull(config, fields):
        """Replace nulls in the first field using Value.fillna.

        The replacement comes from config["default_null_string"] and falls
        back to "DEFAULT_REPLACEMENT_STRING". Only fields[0] is used.
        """
        default_null_string = ibis.literal(
            config.get("default_null_string", "DEFAULT_REPLACEMENT_STRING")
        )
        fields = [fields[0], default_null_string]
        return CalculatedField(
            ibis.expr.types.Value.fillna,
            config,
            fields,
        )

    @staticmethod
    def length(config, fields):
        """Apply StringValue.length to the field."""
        return CalculatedField(
            ibis.expr.types.StringValue.length,
            config,
            fields,
        )

    @staticmethod
    def padded_char_length(config, fields):
        """Apply StringValue.padded_char_length to the field."""
        return CalculatedField(
            ibis.expr.types.StringValue.padded_char_length,
            config,
            fields,
        )

    @staticmethod
    def byte_length(config, fields):
        """Apply BinaryValue.byte_length to the field."""
        return CalculatedField(
            ibis.expr.types.BinaryValue.byte_length,
            config,
            fields,
        )

    @staticmethod
    def rstrip(config, fields):
        """Apply StringValue.rstrip to the field."""
        return CalculatedField(
            ibis.expr.types.StringValue.rstrip,
            config,
            fields,
        )

    @staticmethod
    def upper(config, fields):
        """Apply StringValue.upper to the field."""
        return CalculatedField(
            ibis.expr.types.StringValue.upper,
            config,
            fields,
        )

    @staticmethod
    def epoch_seconds(config, fields):
        """Apply TimestampValue.epoch_seconds to the field."""
        return CalculatedField(
            ibis.expr.types.TimestampValue.epoch_seconds,
            config,
            fields,
        )

    @staticmethod
    def cast(config, fields):
        """Cast the field with api.cast.

        The target type comes from config[consts.CONFIG_DEFAULT_CAST] and
        falls back to "string".
        """
        target_type = config.get(consts.CONFIG_DEFAULT_CAST, "string")
        return CalculatedField(
            api.cast,
            config,
            fields,
            target_type=target_type,
        )

    @staticmethod
    def custom(config, fields):
        """Returns a CalculatedField instance built for any custom ibis expression
        e.g. 'ibis.expr.api.StringValue.replace'. For a list of supported functions,
        see https://github.com/ibis-project/ibis/blob/1.4.0/ibis/expr/api.py

        Args:
            config (dict): Must provide the expression string under
                consts.CONFIG_CUSTOM_IBIS_EXPR; may provide a list of
                single-entry dicts under consts.CONFIG_CUSTOM_PARAMS, which
                are merged and forwarded as keyword arguments.
            fields (list): Fields the custom expression is applied to
        """
        ibis_expr = config.get(consts.CONFIG_CUSTOM_IBIS_EXPR)
        expr_params = config.get(consts.CONFIG_CUSTOM_PARAMS, [])
        params = {k: v for d in expr_params for k, v in d.items()}
        # SECURITY: eval() executes arbitrary Python taken from the config.
        # Only use this with configuration from a trusted source.
        return CalculatedField(eval(ibis_expr), config, fields, **params)

    def _compile_fields(self, ibis_table, fields):
        """Resolve field entries against the table, preserving nesting.

        Nested lists are compiled recursively, string literals are passed
        through unchanged and everything else is treated as a column name
        (cast to self.cast when one is set).
        """
        compiled_fields = []

        for field in fields:
            if isinstance(field, list):
                compiled_fields.append(self._compile_fields(ibis_table, field))
            elif isinstance(field, StringScalar):
                # Literals (e.g. a concat separator) are already expressions.
                compiled_fields.append(field)
            else:
                column = ibis_table[field]
                compiled_fields.append(
                    column.cast(self.cast) if self.cast else column
                )
        return compiled_fields

    def compile(self, ibis_table):
        """Apply the expression to the compiled fields and name the result.

        Args:
            ibis_table (IbisTable): The table the fields are read from.

        Returns:
            The calculated expression, named with config["field_alias"] when
            that entry is present and truthy.
        """
        compiled_fields = self._compile_fields(ibis_table, self.fields)
        calc_field = self.expr(*compiled_fields, **self.kwargs)

        # The alias is optional: a missing key is treated like an empty one.
        alias = self.config.get("field_alias")
        if alias:
            calc_field = calc_field.name(alias)

        return calc_field


class QueryBuilder(object):
    def __init__(
        self,
        aggregate_fields,
        calculated_fields,
        filters,
        grouped_fields,
        comparison_fields,
        limit=None,
    ):
        """Build a QueryBuilder object which can be used to build queries easily

        Args:
            aggregate_fields (list[AggregateField]): AggregateField instances with Ibis expressions
            calculated_fields (list[CalculatedField]): A list of CalculatedField instances
            filters (list[FilterField]): A list of FilterField instances
            grouped_fields (list[GroupedField]): A list of GroupedField instances
            comparison_fields (list[ComparisonField]): A list of ComparisonField instances
            limit (int): A limit value for the number of records to pull (used for testing)
        """
        self.aggregate_fields = aggregate_fields
        self.calculated_fields = calculated_fields
        self.filters = filters
        self.grouped_fields = grouped_fields
        self.comparison_fields = comparison_fields
        self.limit = limit

    @staticmethod
    def build_count_validator(limit=None):
        """Return a basic template builder for most validations

        Args:
            limit (int): Optional limit on the number of records to pull

        Returns:
            QueryBuilder: A builder whose field lists all start empty.
        """
        # Fresh lists per call so builders never share state.
        return QueryBuilder(
            aggregate_fields=[],
            calculated_fields=[],
            filters=[],
            grouped_fields=[],
            comparison_fields=[],
            limit=limit,
        )

    def compile_aggregate_fields(self, table):
        """Compile every aggregate field against the table."""
        return [field.compile(table) for field in self.aggregate_fields]

    def compile_filter_fields(self, table):
        """Compile every filter against the table."""
        return [field.compile(table) for field in self.filters]

    def compile_group_fields(self, table):
        """Compile every grouped field against the table."""
        return [field.compile(table) for field in self.grouped_fields]

    def compile_comparison_fields(self, table):
        """Compile every comparison field against the table."""
        return [field.compile(table) for field in self.comparison_fields]

    def compile_calculated_fields(self, table, n=0):
        """Compile the calculated fields whose configured depth equals n.

        A field whose config has no consts.CONFIG_DEPTH entry is treated as
        depth 0, matching how compile() computes the maximum depth.

        Args:
            table (IbisTable): The table the fields are compiled against.
            n (int): The depth level to compile.
        """
        return [
            field.compile(table)
            for field in self.calculated_fields
            if field.config.get(consts.CONFIG_DEPTH, 0) == n
        ]

    def compile(self, validation_type, table):
        """Return an Ibis query object

        Args:
            validation_type (String): The kind of validation being built.
                For consts.ROW_VALIDATION and consts.CUSTOM_QUERY the
                comparison fields are projected; for any other value they
                are added to the table with mutate.
            table (IbisTable): The Ibis Table expression.
        """

        # Build Query Expressions
        compiled_filters = self.compile_filter_fields(table)
        filtered_table = table.filter(compiled_filters) if compiled_filters else table

        if self.calculated_fields:
            depth_limit = max(
                field.config.get(consts.CONFIG_DEPTH, 0)
                for field in self.calculated_fields
            )
            # Add calculated fields one depth level at a time so that deeper
            # fields can reference columns created at shallower levels.
            for n in range(depth_limit + 1):
                filtered_table = filtered_table.mutate(
                    self.compile_calculated_fields(filtered_table, n)
                )

        if self.comparison_fields:
            compiled_comparisons = self.compile_comparison_fields(filtered_table)
            if validation_type in (consts.ROW_VALIDATION, consts.CUSTOM_QUERY):
                filtered_table = filtered_table.projection(compiled_comparisons)
            else:
                filtered_table = filtered_table.mutate(compiled_comparisons)

        compiled_groups = self.compile_group_fields(filtered_table)
        grouped_table = (
            filtered_table.group_by(compiled_groups)
            if compiled_groups
            else filtered_table
        )

        if self.aggregate_fields:
            query = grouped_table.aggregate(
                self.compile_aggregate_fields(filtered_table)
            )
        else:
            query = grouped_table

        if self.limit:
            query = query.limit(self.limit)

        return query

    def add_aggregate_field(self, aggregate_field):
        """Add an AggregateField instance to the query which
            will be used when compiling your query (ie. SUM(a))

        Args:
            aggregate_field (AggregateField): An AggregateField instance
        """
        self.aggregate_fields.append(aggregate_field)

    def add_comparison_field(self, comparison_field):
        """Add a ComparisonField instance to the query which
            will be compiled into a named column of your query.

        Args:
            comparison_field (ComparisonField): A ComparisonField instance
        """
        self.comparison_fields.append(comparison_field)

    def add_grouped_field(self, grouped_field):
        """Add a GroupedField instance to the query which
            represents adding a column to group by in the
            query being built.
        Args:
            grouped_field (GroupedField): A GroupedField instance
        """
        self.grouped_fields.append(grouped_field)

    def add_filter_field(self, filter_obj):
        """Add a FilterField instance to your query which
            will add the desired filter to your compiled
            query (ie. WHERE query_filter=True)
        Args:
            filter_obj (FilterField): A FilterField instance
        """
        self.filters.append(filter_obj)

    def add_calculated_field(self, calculated_field):
        """Add a CalculatedField instance to your query which
            will add the desired scalar function to your compiled
            query (ie. CONCAT(field_a, field_b))
        Args:
            calculated_field (CalculatedField): A CalculatedField instance
        """
        self.calculated_fields.append(calculated_field)
