#  Licensed to Elasticsearch B.V. under one or more contributor
#  license agreements. See the NOTICE file distributed with
#  this work for additional information regarding copyright
#  ownership. Elasticsearch B.V. licenses this file to you under
#  the Apache License, Version 2.0 (the "License"); you may
#  not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
# 	http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an
#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#  KIND, either express or implied.  See the License for the
#  specific language governing permissions and limitations
#  under the License.

from datetime import datetime, timedelta
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Generic,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

from typing_extensions import Self

from .aggs import A, Agg
from .query import MatchAll, Nested, Query, Range, Terms
from .response import Response
from .utils import _R, AttrDict

if TYPE_CHECKING:
    from .document_base import DocumentBase
    from .response.aggs import BucketData
    from .search_base import SearchBase

FilterValueType = Union[str, int, float, bool]

__all__ = [
    "FacetedSearchBase",
    "HistogramFacet",
    "TermsFacet",
    "DateHistogramFacet",
    "RangeFacet",
    "NestedFacet",
]


class Facet(Generic[_R]):
    """
    A facet on faceted search. Wraps and aggregation and provides functionality
    to create a filter for selected values and return a list of facet values
    from the result of the aggregation.
    """

    agg_type: str = ""

    def __init__(
        self, metric: Optional[Agg[_R]] = None, metric_sort: str = "desc", **kwargs: Any
    ):
        self.filter_values = ()
        self._params = kwargs
        self._metric = metric
        if metric and metric_sort:
            self._params["order"] = {"metric": metric_sort}

    def get_aggregation(self) -> Agg[_R]:
        """
        Return the aggregation object.
        """
        agg: Agg[_R] = A(self.agg_type, **self._params)
        if self._metric:
            agg.metric("metric", self._metric)
        return agg

    def add_filter(self, filter_values: List[FilterValueType]) -> Optional[Query]:
        """
        Construct a filter.
        """
        if not filter_values:
            return None

        f = self.get_value_filter(filter_values[0])
        for v in filter_values[1:]:
            f |= self.get_value_filter(v)
        return f

    def get_value_filter(self, filter_value: FilterValueType) -> Query:  # type: ignore[empty-body]
        """
        Construct a filter for an individual value
        """
        pass

    def is_filtered(self, key: str, filter_values: List[FilterValueType]) -> bool:
        """
        Is a filter active on the given key.
        """
        return key in filter_values

    def get_value(self, bucket: "BucketData[_R]") -> Any:
        """
        return a value representing a bucket. Its key as default.
        """
        return bucket["key"]

    def get_metric(self, bucket: "BucketData[_R]") -> int:
        """
        Return a metric, by default doc_count for a bucket.
        """
        if self._metric:
            return cast(int, bucket["metric"]["value"])
        return cast(int, bucket["doc_count"])

    def get_values(
        self, data: "BucketData[_R]", filter_values: List[FilterValueType]
    ) -> List[Tuple[Any, int, bool]]:
        """
        Turn the raw bucket data into a list of tuples containing the key,
        number of documents and a flag indicating whether this value has been
        selected or not.
        """
        out = []
        for bucket in data.buckets:
            b = cast("BucketData[_R]", bucket)
            key = self.get_value(b)
            out.append((key, self.get_metric(b), self.is_filtered(key, filter_values)))
        return out


class TermsFacet(Facet[_R]):
    agg_type = "terms"

    def add_filter(self, filter_values: List[FilterValueType]) -> Optional[Query]:
        """Create a terms filter instead of bool containing term filters."""
        if filter_values:
            return Terms(self._params["field"], filter_values, _expand__to_dot=False)
        return None


class RangeFacet(Facet[_R]):
    agg_type = "range"

    def _range_to_dict(
        self, range: Tuple[Any, Tuple[Optional[int], Optional[int]]]
    ) -> Dict[str, Any]:
        key, _range = range
        out: Dict[str, Any] = {"key": key}
        if _range[0] is not None:
            out["from"] = _range[0]
        if _range[1] is not None:
            out["to"] = _range[1]
        return out

    def __init__(
        self,
        ranges: Sequence[Tuple[Any, Tuple[Optional[int], Optional[int]]]],
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self._params["ranges"] = list(map(self._range_to_dict, ranges))
        self._params["keyed"] = False
        self._ranges = dict(ranges)

    def get_value_filter(self, filter_value: FilterValueType) -> Query:
        f, t = self._ranges[filter_value]
        limits: Dict[str, Any] = {}
        if f is not None:
            limits["gte"] = f
        if t is not None:
            limits["lt"] = t

        return Range(self._params["field"], limits, _expand__to_dot=False)


class HistogramFacet(Facet[_R]):
    agg_type = "histogram"

    def get_value_filter(self, filter_value: FilterValueType) -> Range:
        return Range(
            self._params["field"],
            {
                "gte": filter_value,
                "lt": filter_value + self._params["interval"],
            },
            _expand__to_dot=False,
        )


def _date_interval_year(d: datetime) -> datetime:
    return d.replace(
        year=d.year + 1, day=(28 if d.month == 2 and d.day == 29 else d.day)
    )


def _date_interval_month(d: datetime) -> datetime:
    return (d + timedelta(days=32)).replace(day=1)


def _date_interval_week(d: datetime) -> datetime:
    return d + timedelta(days=7)


def _date_interval_day(d: datetime) -> datetime:
    return d + timedelta(days=1)


def _date_interval_hour(d: datetime) -> datetime:
    return d + timedelta(hours=1)


class DateHistogramFacet(Facet[_R]):
    agg_type = "date_histogram"

    DATE_INTERVALS = {
        "year": _date_interval_year,
        "1Y": _date_interval_year,
        "month": _date_interval_month,
        "1M": _date_interval_month,
        "week": _date_interval_week,
        "1w": _date_interval_week,
        "day": _date_interval_day,
        "1d": _date_interval_day,
        "hour": _date_interval_hour,
        "1h": _date_interval_hour,
    }

    def __init__(self, **kwargs: Any):
        kwargs.setdefault("min_doc_count", 0)
        super().__init__(**kwargs)

    def get_value(self, bucket: "BucketData[_R]") -> Any:
        if not isinstance(bucket["key"], datetime):
            # Elasticsearch returns key=None instead of 0 for date 1970-01-01,
            # so we need to set key to 0 to avoid TypeError exception
            if bucket["key"] is None:
                bucket["key"] = 0
            # Preserve milliseconds in the datetime
            return datetime.utcfromtimestamp(int(cast(int, bucket["key"])) / 1000.0)
        else:
            return bucket["key"]

    def get_value_filter(self, filter_value: Any) -> Range:
        for interval_type in ("calendar_interval", "fixed_interval"):
            if interval_type in self._params:
                break
        else:
            interval_type = "interval"

        return Range(
            self._params["field"],
            {
                "gte": filter_value,
                "lt": self.DATE_INTERVALS[self._params[interval_type]](filter_value),
            },
            _expand__to_dot=False,
        )


class NestedFacet(Facet[_R]):
    agg_type = "nested"

    def __init__(self, path: str, nested_facet: Facet[_R]):
        self._path = path
        self._inner = nested_facet
        super().__init__(path=path, aggs={"inner": nested_facet.get_aggregation()})

    def get_values(
        self, data: "BucketData[_R]", filter_values: List[FilterValueType]
    ) -> List[Tuple[Any, int, bool]]:
        return self._inner.get_values(data.inner, filter_values)

    def add_filter(self, filter_values: List[FilterValueType]) -> Optional[Query]:
        inner_q = self._inner.add_filter(filter_values)
        if inner_q:
            return Nested(path=self._path, query=inner_q)
        return None


class FacetedResponse(Response[_R]):
    if TYPE_CHECKING:
        _faceted_search: "FacetedSearchBase[_R]"
        _facets: Dict[str, List[Tuple[Any, int, bool]]]

    @property
    def query_string(self) -> Optional[Union[str, Query]]:
        return self._faceted_search._query

    @property
    def facets(self) -> Dict[str, List[Tuple[Any, int, bool]]]:
        if not hasattr(self, "_facets"):
            super(AttrDict, self).__setattr__("_facets", AttrDict({}))
            for name, facet in self._faceted_search.facets.items():
                self._facets[name] = facet.get_values(
                    getattr(getattr(self.aggregations, "_filter_" + name), name),
                    self._faceted_search.filter_values.get(name, []),
                )
        return self._facets


class FacetedSearchBase(Generic[_R]):
    """
    Abstraction for creating faceted navigation searches that takes care of
    composing the queries, aggregations and filters as needed as well as
    presenting the results in an easy-to-consume fashion::

        class BlogSearch(FacetedSearch):
            index = 'blogs'
            doc_types = [Blog, Post]
            fields = ['title^5', 'category', 'description', 'body']

            facets = {
                'type': TermsFacet(field='_type'),
                'category': TermsFacet(field='category'),
                'weekly_posts': DateHistogramFacet(field='published_from', interval='week')
            }

            def search(self):
                ' Override search to add your own filters '
                s = super(BlogSearch, self).search()
                return s.filter('term', published=True)

        # when using:
        blog_search = BlogSearch("web framework", filters={"category": "python"})

        # supports pagination
        blog_search[10:20]

        response = blog_search.execute()

        # easy access to aggregation results:
        for category, hit_count, is_selected in response.facets.category:
            print(
                "Category %s has %d hits%s." % (
                    category,
                    hit_count,
                    ' and is chosen' if is_selected else ''
                )
            )

    """

    index: Optional[str] = None
    doc_types: Optional[List[Union[str, Type["DocumentBase"]]]] = None
    fields: Sequence[str] = []
    facets: Dict[str, Facet[_R]] = {}
    using = "default"

    if TYPE_CHECKING:

        def search(self) -> "SearchBase[_R]": ...

    def __init__(
        self,
        query: Optional[Union[str, Query]] = None,
        filters: Dict[str, FilterValueType] = {},
        sort: Sequence[str] = [],
    ):
        """
        :arg query: the text to search for
        :arg filters: facet values to filter
        :arg sort: sort information to be passed to :class:`~elasticsearch.dsl.Search`
        """
        self._query = query
        self._filters: Dict[str, Query] = {}
        self._sort = sort
        self.filter_values: Dict[str, List[FilterValueType]] = {}
        for name, value in filters.items():
            self.add_filter(name, value)

        self._s = self.build_search()

    def __getitem__(self, k: Union[int, slice]) -> Self:
        self._s = self._s[k]
        return self

    def add_filter(
        self, name: str, filter_values: Union[FilterValueType, List[FilterValueType]]
    ) -> None:
        """
        Add a filter for a facet.
        """
        # normalize the value into a list
        if not isinstance(filter_values, (tuple, list)):
            if filter_values is None:
                return
            filter_values = [
                filter_values,
            ]

        # remember the filter values for use in FacetedResponse
        self.filter_values[name] = filter_values

        # get the filter from the facet
        f = self.facets[name].add_filter(filter_values)
        if f is None:
            return

        self._filters[name] = f

    def query(
        self, search: "SearchBase[_R]", query: Union[str, Query]
    ) -> "SearchBase[_R]":
        """
        Add query part to ``search``.

        Override this if you wish to customize the query used.
        """
        if query:
            if self.fields:
                return search.query("multi_match", fields=self.fields, query=query)
            else:
                return search.query("multi_match", query=query)
        return search

    def aggregate(self, search: "SearchBase[_R]") -> None:
        """
        Add aggregations representing the facets selected, including potential
        filters.
        """
        for f, facet in self.facets.items():
            agg = facet.get_aggregation()
            agg_filter: Query = MatchAll()
            for field, filter in self._filters.items():
                if f == field:
                    continue
                agg_filter &= filter
            search.aggs.bucket("_filter_" + f, "filter", filter=agg_filter).bucket(
                f, agg
            )

    def filter(self, search: "SearchBase[_R]") -> "SearchBase[_R]":
        """
        Add a ``post_filter`` to the search request narrowing the results based
        on the facet filters.
        """
        if not self._filters:
            return search

        post_filter: Query = MatchAll()
        for f in self._filters.values():
            post_filter &= f
        return search.post_filter(post_filter)

    def highlight(self, search: "SearchBase[_R]") -> "SearchBase[_R]":
        """
        Add highlighting for all the fields
        """
        return search.highlight(
            *(f if "^" not in f else f.split("^", 1)[0] for f in self.fields)
        )

    def sort(self, search: "SearchBase[_R]") -> "SearchBase[_R]":
        """
        Add sorting information to the request.
        """
        if self._sort:
            search = search.sort(*self._sort)
        return search

    def params(self, **kwargs: Any) -> None:
        """
        Specify query params to be used when executing the search. All the
        keyword arguments will override the current values. See
        https://elasticsearch-py.readthedocs.io/en/latest/api/elasticsearch.html#elasticsearch.Elasticsearch.search
        for all available parameters.
        """
        self._s = self._s.params(**kwargs)

    def build_search(self) -> "SearchBase[_R]":
        """
        Construct the ``Search`` object.
        """
        s = self.search()
        if self._query is not None:
            s = self.query(s, self._query)
        s = self.filter(s)
        if self.fields:
            s = self.highlight(s)
        s = self.sort(s)
        self.aggregate(s)
        return s
