elasticsearch/dsl/faceted_search_base.py (286 lines of code) (raw):

#  Licensed to Elasticsearch B.V. under one or more contributor
#  license agreements. See the NOTICE file distributed with
#  this work for additional information regarding copyright
#  ownership. Elasticsearch B.V. licenses this file to you under
#  the Apache License, Version 2.0 (the "License"); you may
#  not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an
#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#  KIND, either express or implied.  See the License for the
#  specific language governing permissions and limitations
#  under the License.

from datetime import datetime, timedelta, timezone
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Generic,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

from typing_extensions import Self

from .aggs import A, Agg
from .query import MatchAll, Nested, Query, Range, Terms
from .response import Response
from .utils import _R, AttrDict

if TYPE_CHECKING:
    from .document_base import DocumentBase
    from .response.aggs import BucketData
    from .search_base import SearchBase

FilterValueType = Union[str, int, float, bool]

__all__ = [
    "FacetedSearchBase",
    "HistogramFacet",
    "TermsFacet",
    "DateHistogramFacet",
    "RangeFacet",
    "NestedFacet",
]


class Facet(Generic[_R]):
    """
    A facet on faceted search. Wraps an aggregation and provides functionality
    to create a filter for selected values and return a list of facet values
    from the result of the aggregation.
    """

    # Name of the aggregation passed to ``A()``; set by each subclass.
    agg_type: str = ""

    def __init__(
        self,
        metric: Optional[Agg[_R]] = None,
        metric_sort: str = "desc",
        **kwargs: Any,
    ):
        """
        :arg metric: optional aggregation attached to every bucket under the
            name ``"metric"``; when given, :meth:`get_metric` reads its value
            instead of ``doc_count``
        :arg metric_sort: sort direction stored as the ``order`` parameter of
            the aggregation when ``metric`` is given
        :arg kwargs: parameters passed through to the aggregation
        """
        self.filter_values = ()
        self._params = kwargs
        self._metric = metric
        if metric and metric_sort:
            self._params["order"] = {"metric": metric_sort}

    def get_aggregation(self) -> Agg[_R]:
        """
        Return the aggregation object.
        """
        agg: Agg[_R] = A(self.agg_type, **self._params)
        if self._metric:
            agg.metric("metric", self._metric)
        return agg

    def add_filter(self, filter_values: List[FilterValueType]) -> Optional[Query]:
        """
        Construct a filter.

        The per-value filters are combined with ``|``. Returns ``None`` when
        ``filter_values`` is empty.
        """
        if not filter_values:
            return None

        f = self.get_value_filter(filter_values[0])
        for v in filter_values[1:]:
            f |= self.get_value_filter(v)
        return f

    def get_value_filter(self, filter_value: FilterValueType) -> Query:  # type: ignore[empty-body]
        """
        Construct a filter for an individual value
        """
        # Intentionally empty in the base class (returns ``None``); subclasses
        # in this module override it.
        pass

    def is_filtered(self, key: str, filter_values: List[FilterValueType]) -> bool:
        """
        Is a filter active on the given key.
        """
        return key in filter_values

    def get_value(self, bucket: "BucketData[_R]") -> Any:
        """
        return a value representing a bucket. Its key as default.
        """
        return bucket["key"]

    def get_metric(self, bucket: "BucketData[_R]") -> int:
        """
        Return a metric, by default doc_count for a bucket.
        """
        if self._metric:
            return cast(int, bucket["metric"]["value"])
        return cast(int, bucket["doc_count"])

    def get_values(
        self, data: "BucketData[_R]", filter_values: List[FilterValueType]
    ) -> List[Tuple[Any, int, bool]]:
        """
        Turn the raw bucket data into a list of tuples containing the key,
        number of documents and a flag indicating whether this value has been
        selected or not.
        """
        out = []
        for bucket in data.buckets:
            b = cast("BucketData[_R]", bucket)
            key = self.get_value(b)
            out.append((key, self.get_metric(b), self.is_filtered(key, filter_values)))
        return out


class TermsFacet(Facet[_R]):
    """Facet backed by a ``terms`` aggregation."""

    agg_type = "terms"

    def add_filter(self, filter_values: List[FilterValueType]) -> Optional[Query]:
        """Create a terms filter instead of bool containing term filters."""
        if filter_values:
            return Terms(self._params["field"], filter_values, _expand__to_dot=False)
        return None


class RangeFacet(Facet[_R]):
    """Facet backed by a ``range`` aggregation over a fixed set of named ranges."""

    agg_type = "range"

    def _range_to_dict(
        self, range: Tuple[Any, Tuple[Optional[int], Optional[int]]]
    ) -> Dict[str, Any]:
        """
        Convert a ``(key, (from, to))`` pair into a range definition dict.

        A bound that is ``None`` is left out of the result.
        """
        key, _range = range
        out: Dict[str, Any] = {"key": key}
        if _range[0] is not None:
            out["from"] = _range[0]
        if _range[1] is not None:
            out["to"] = _range[1]
        return out

    def __init__(
        self,
        ranges: Sequence[Tuple[Any, Tuple[Optional[int], Optional[int]]]],
        **kwargs: Any,
    ):
        """
        :arg ranges: sequence of ``(key, (from, to))`` pairs; either bound may
            be ``None``
        :arg kwargs: forwarded to :class:`Facet`
        """
        super().__init__(**kwargs)
        self._params["ranges"] = list(map(self._range_to_dict, ranges))
        self._params["keyed"] = False
        # key -> (from, to) lookup used when building filters
        self._ranges = dict(ranges)

    def get_value_filter(self, filter_value: FilterValueType) -> Query:
        """
        Build a ``Range`` query for the range registered under ``filter_value``.

        The lower bound is inclusive (``gte``) and the upper bound exclusive
        (``lt``). Raises ``KeyError`` if ``filter_value`` is not one of the
        configured range keys.
        """
        f, t = self._ranges[filter_value]
        limits: Dict[str, Any] = {}
        if f is not None:
            limits["gte"] = f
        if t is not None:
            limits["lt"] = t

        return Range(self._params["field"], limits, _expand__to_dot=False)


class HistogramFacet(Facet[_R]):
    """Facet backed by a ``histogram`` aggregation."""

    agg_type = "histogram"

    def get_value_filter(self, filter_value: FilterValueType) -> Range:
        """
        Build a ``Range`` query covering ``[filter_value, filter_value + interval)``
        where ``interval`` is the facet's ``interval`` parameter.
        """
        return Range(
            self._params["field"],
            {
                "gte": filter_value,
                "lt": filter_value + self._params["interval"],
            },
            _expand__to_dot=False,
        )


def _date_interval_year(d: datetime) -> datetime:
    """Return ``d`` moved one year forward (Feb 29 becomes Feb 28)."""
    return d.replace(
        year=d.year + 1, day=(28 if d.month == 2 and d.day == 29 else d.day)
    )


def _date_interval_month(d: datetime) -> datetime:
    """Return the first day of the month following ``d`` (time of day kept)."""
    # Anchor to the first of the month before adding 32 days: starting from a
    # late day (e.g. Jan 30) the 32-day jump would otherwise land two months
    # ahead and skip a month entirely.
    return (d.replace(day=1) + timedelta(days=32)).replace(day=1)


def _date_interval_week(d: datetime) -> datetime:
    """Return ``d`` moved seven days forward."""
    return d + timedelta(days=7)


def _date_interval_day(d: datetime) -> datetime:
    """Return ``d`` moved one day forward."""
    return d + timedelta(days=1)


def _date_interval_hour(d: datetime) -> datetime:
    """Return ``d`` moved one hour forward."""
    return d + timedelta(hours=1)


class DateHistogramFacet(Facet[_R]):
    """Facet backed by a ``date_histogram`` aggregation."""

    agg_type = "date_histogram"

    # Maps a supported interval name to a function returning the exclusive
    # upper bound of the bucket that starts at the given datetime.
    DATE_INTERVALS = {
        "year": _date_interval_year,
        "1Y": _date_interval_year,
        "month": _date_interval_month,
        "1M": _date_interval_month,
        "week": _date_interval_week,
        "1w": _date_interval_week,
        "day": _date_interval_day,
        "1d": _date_interval_day,
        "hour": _date_interval_hour,
        "1h": _date_interval_hour,
    }

    def __init__(self, **kwargs: Any):
        """
        :arg kwargs: forwarded to :class:`Facet`; ``min_doc_count`` defaults
            to ``0`` unless given explicitly
        """
        kwargs.setdefault("min_doc_count", 0)
        super().__init__(**kwargs)

    def get_value(self, bucket: "BucketData[_R]") -> Any:
        """
        Return the bucket key as a ``datetime``.

        A key that is already a ``datetime`` is returned unchanged. Any other
        key is treated as a timestamp in milliseconds and converted to a naive
        UTC ``datetime``. A ``None`` key is replaced by ``0`` in the bucket.
        """
        if not isinstance(bucket["key"], datetime):
            # Elasticsearch returns key=None instead of 0 for date 1970-01-01,
            # so we need to set key to 0 to avoid TypeError exception
            if bucket["key"] is None:
                bucket["key"] = 0
            # Preserve milliseconds in the datetime.
            # ``datetime.utcfromtimestamp`` is deprecated since Python 3.12;
            # convert with an explicit UTC tz and drop it again so the result
            # stays a naive UTC datetime, exactly as before.
            return datetime.fromtimestamp(
                int(cast(int, bucket["key"])) / 1000.0, tz=timezone.utc
            ).replace(tzinfo=None)
        else:
            return bucket["key"]

    def get_value_filter(self, filter_value: Any) -> Range:
        """
        Build a ``Range`` query covering the bucket that starts at
        ``filter_value``.

        The interval is read from the ``calendar_interval`` parameter, then
        ``fixed_interval``, falling back to ``interval``. Raises ``KeyError``
        when that parameter is missing or its value is not a key of
        :attr:`DATE_INTERVALS`.
        """
        for interval_type in ("calendar_interval", "fixed_interval"):
            if interval_type in self._params:
                break
        else:
            interval_type = "interval"

        return Range(
            self._params["field"],
            {
                "gte": filter_value,
                "lt": self.DATE_INTERVALS[self._params[interval_type]](filter_value),
            },
            _expand__to_dot=False,
        )


class NestedFacet(Facet[_R]):
    """Facet wrapping another facet inside a ``nested`` aggregation."""

    agg_type = "nested"

    def __init__(self, path: str, nested_facet: Facet[_R]):
        """
        :arg path: nested path used for both the aggregation and the filter
        :arg nested_facet: facet whose aggregation is placed under the
            sub-aggregation named ``"inner"``
        """
        self._path = path
        self._inner = nested_facet
        super().__init__(path=path, aggs={"inner": nested_facet.get_aggregation()})

    def get_values(
        self, data: "BucketData[_R]", filter_values: List[FilterValueType]
    ) -> List[Tuple[Any, int, bool]]:
        """Delegate to the wrapped facet using the ``inner`` sub-aggregation data."""
        return self._inner.get_values(data.inner, filter_values)

    def add_filter(self, filter_values: List[FilterValueType]) -> Optional[Query]:
        """
        Wrap the inner facet's filter in a ``Nested`` query on this facet's
        path. Returns ``None`` when the inner facet produces no filter.
        """
        inner_q = self._inner.add_filter(filter_values)
        if inner_q:
            return Nested(path=self._path, query=inner_q)
        return None


class FacetedResponse(Response[_R]):
    """Response exposing the facet values of the originating faceted search."""

    if TYPE_CHECKING:
        _faceted_search: "FacetedSearchBase[_R]"
        _facets: Dict[str, List[Tuple[Any, int, bool]]]

    @property
    def query_string(self) -> Optional[Union[str, Query]]:
        """The query the faceted search was created with."""
        return self._faceted_search._query

    @property
    def facets(self) -> Dict[str, List[Tuple[Any, int, bool]]]:
        """
        Facet values keyed by facet name, computed on first access and cached
        in ``_facets``.
        """
        if not hasattr(self, "_facets"):
            # Set the cache attribute via the class after ``AttrDict`` in the
            # MRO, skipping ``AttrDict.__setattr__``.
            super(AttrDict, self).__setattr__("_facets", AttrDict({}))
            for name, facet in self._faceted_search.facets.items():
                self._facets[name] = facet.get_values(
                    getattr(getattr(self.aggregations, "_filter_" + name), name),
                    self._faceted_search.filter_values.get(name, []),
                )
        return self._facets


class FacetedSearchBase(Generic[_R]):
    """
    Abstraction for creating faceted navigation searches that takes care of
    composing the queries, aggregations and filters as needed as well as
    presenting the results in an easy-to-consume fashion::

        class BlogSearch(FacetedSearch):
            index = 'blogs'
            doc_types = [Blog, Post]
            fields = ['title^5', 'category', 'description', 'body']

            facets = {
                'type': TermsFacet(field='_type'),
                'category': TermsFacet(field='category'),
                'weekly_posts': DateHistogramFacet(field='published_from', interval='week')
            }

            def search(self):
                ' Override search to add your own filters '
                s = super(BlogSearch, self).search()
                return s.filter('term', published=True)

        # when using:
        blog_search = BlogSearch("web framework", filters={"category": "python"})

        # supports pagination
        blog_search[10:20]

        response = blog_search.execute()

        # easy access to aggregation results:
        for category, hit_count, is_selected in response.facets.category:
            print(
                "Category %s has %d hits%s." % (
                    category,
                    hit_count,
                    ' and is chosen' if is_selected else ''
                )
            )

    """

    index: Optional[str] = None
    doc_types: Optional[List[Union[str, Type["DocumentBase"]]]] = None
    fields: Sequence[str] = []
    facets: Dict[str, Facet[_R]] = {}
    using = "default"

    if TYPE_CHECKING:
        # Declared for type checkers only. ``build_search`` calls it, so a
        # runtime implementation must be supplied elsewhere (presumably by
        # concrete subclasses -- confirm against them).
        def search(self) -> "SearchBase[_R]": ...

    def __init__(
        self,
        query: Optional[Union[str, Query]] = None,
        filters: Optional[Dict[str, FilterValueType]] = None,
        sort: Optional[Sequence[str]] = None,
    ):
        """
        :arg query: the text to search for
        :arg filters: facet values to filter
        :arg sort: sort information to be passed to :class:`~elasticsearch.dsl.Search`
        """
        self._query = query
        self._filters: Dict[str, Query] = {}
        # ``None`` defaults (instead of shared mutable ``{}`` / ``[]``) give
        # every instance its own empty containers.
        self._sort: Sequence[str] = [] if sort is None else sort
        self.filter_values: Dict[str, List[FilterValueType]] = {}
        for name, value in ({} if filters is None else filters).items():
            self.add_filter(name, value)

        self._s = self.build_search()

    def __getitem__(self, k: Union[int, slice]) -> Self:
        """Apply ``k`` to the underlying search (pagination) and return ``self``."""
        self._s = self._s[k]
        return self

    def add_filter(
        self, name: str, filter_values: Union[FilterValueType, List[FilterValueType]]
    ) -> None:
        """
        Add a filter for a facet.

        A single value is wrapped in a list; ``None`` is ignored. Raises
        ``KeyError`` if ``name`` is not a key of :attr:`facets`.
        """
        # normalize the value into a list
        if not isinstance(filter_values, (tuple, list)):
            if filter_values is None:
                return
            filter_values = [
                filter_values,
            ]

        # remember the filter values for use in FacetedResponse
        self.filter_values[name] = filter_values

        # get the filter from the facet
        f = self.facets[name].add_filter(filter_values)
        if f is None:
            return

        self._filters[name] = f

    def query(
        self, search: "SearchBase[_R]", query: Union[str, Query]
    ) -> "SearchBase[_R]":
        """
        Add query part to ``search``.

        Override this if you wish to customize the query used.
        """
        if query:
            if self.fields:
                return search.query("multi_match", fields=self.fields, query=query)
            else:
                return search.query("multi_match", query=query)
        return search

    def aggregate(self, search: "SearchBase[_R]") -> None:
        """
        Add aggregations representing the facets selected, including potential
        filters.
        """
        for f, facet in self.facets.items():
            agg = facet.get_aggregation()
            agg_filter: Query = MatchAll()
            for field, facet_filter in self._filters.items():
                # A facet's own filter is not applied to its aggregation, only
                # the filters of the other facets are.
                if f == field:
                    continue
                agg_filter &= facet_filter
            search.aggs.bucket("_filter_" + f, "filter", filter=agg_filter).bucket(
                f, agg
            )

    def filter(self, search: "SearchBase[_R]") -> "SearchBase[_R]":
        """
        Add a ``post_filter`` to the search request narrowing the results based
        on the facet filters.
        """
        if not self._filters:
            return search

        post_filter: Query = MatchAll()
        for f in self._filters.values():
            post_filter &= f
        return search.post_filter(post_filter)

    def highlight(self, search: "SearchBase[_R]") -> "SearchBase[_R]":
        """
        Add highlighting for all the fields
        """
        # Strip a ``^boost`` suffix (e.g. ``title^5``) from each field name.
        return search.highlight(
            *(f if "^" not in f else f.split("^", 1)[0] for f in self.fields)
        )

    def sort(self, search: "SearchBase[_R]") -> "SearchBase[_R]":
        """
        Add sorting information to the request.
        """
        if self._sort:
            search = search.sort(*self._sort)
        return search

    def params(self, **kwargs: Any) -> None:
        """
        Specify query params to be used when executing the search. All the
        keyword arguments will override the current values. See
        https://elasticsearch-py.readthedocs.io/en/latest/api/elasticsearch.html#elasticsearch.Elasticsearch.search
        for all available parameters.
        """
        self._s = self._s.params(**kwargs)

    def build_search(self) -> "SearchBase[_R]":
        """
        Construct the ``Search`` object.
        """
        s = self.search()
        if self._query is not None:
            s = self.query(s, self._query)
        s = self.filter(s)
        if self.fields:
            s = self.highlight(s)
        s = self.sort(s)
        self.aggregate(s)
        return s