utils/templates/query.py.tpl (312 lines of code) (raw):
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import collections.abc
from copy import deepcopy
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
List,
Literal,
Mapping,
MutableMapping,
Optional,
Protocol,
Sequence,
TypeVar,
Union,
cast,
overload,
)
from elastic_transport.client_utils import DEFAULT
# 'SF' looks unused but the test suite assumes it's available
# from this module so others are liable to do so as well.
from .function import SF # noqa: F401
from .function import ScoreFunction
from .utils import DslBase
if TYPE_CHECKING:
from elastic_transport.client_utils import DefaultType
from . import types, wrappers
from .document_base import InstrumentedField
# Generic type variables used by the Q() signatures in this module:
#   _T - type of the object exposed through a ``_proxied`` attribute
#   _M - any mapping with string keys (the body of a serialized query)
_T = TypeVar("_T")
_M = TypeVar("_M", bound=Mapping[str, Any])


class QProxiedProtocol(Protocol[_T]):
    """Structural type for any object carrying a ``_proxied`` attribute of type ``_T``."""

    _proxied: _T
@overload
def Q(name_or_query: MutableMapping[str, _M]) -> "Query": ...


@overload
def Q(name_or_query: "Query") -> "Query": ...


@overload
def Q(name_or_query: QProxiedProtocol[_T]) -> _T: ...


@overload
def Q(name_or_query: str = "match_all", **params: Any) -> "Query": ...


def Q(
    name_or_query: Union[
        str,
        "Query",
        QProxiedProtocol[_T],
        MutableMapping[str, _M],
    ] = "match_all",
    **params: Any,
) -> Union["Query", _T]:
    """Build (or pass through) a query object from one of several input forms.

    Accepted inputs, checked in this order:

    * a mutable mapping with exactly one key, e.g.
      ``{"match": {"title": "python"}}`` - the key selects the query class and
      the value supplies its keyword arguments;
    * an existing ``Query`` instance - returned unchanged;
    * any object with a ``_proxied`` attribute - that attribute is returned;
    * a query name plus keyword arguments, e.g. ``Q("match", title="python")``.

    :arg name_or_query: one of the forms listed above, defaults to
        ``"match_all"``
    :arg params: keyword arguments for the query class; only allowed together
        with a query name
    :raises ValueError: if ``params`` are combined with a mapping or a
        ``Query`` instance, or if the mapping does not hold exactly one item
    """
    # Serialized form: {"match": {"title": "python"}}
    if isinstance(name_or_query, collections.abc.MutableMapping):
        if params:
            raise ValueError("Q() cannot accept parameters when passing in a dict.")
        if len(name_or_query) != 1:
            message = (
                'Q() can only accept dict with a single query ({"match": {...}}). '
                "Instead it got (%r)" % name_or_query
            )
            raise ValueError(message)
        # popitem() runs on a deep copy, so the caller's mapping stays intact
        query_name, query_body = deepcopy(name_or_query).popitem()
        query_cls = Query.get_dsl_class(query_name)
        return query_cls(_expand__to_dot=False, **query_body)

    # Already a query object: MatchAll()
    if isinstance(name_or_query, Query):
        if params:
            raise ValueError(
                "Q() cannot accept parameters when passing in a Query object."
            )
        return name_or_query

    # Wrapper object: s.query = Q('filtered', query=s.query)
    if hasattr(name_or_query, "_proxied"):
        wrapper = cast(QProxiedProtocol[_T], name_or_query)
        return wrapper._proxied

    # Name plus keyword arguments: "match", title="python"
    query_cls = Query.get_dsl_class(name_or_query)
    return query_cls(**params)
class Query(DslBase):
    """Base class for query objects; supplies the ``+``, ``|``, ``&`` and ``~`` operators.

    Each binary operator first defers to the right-hand operand's reflected
    method (``__radd__``, ``__ror__``, ``__rand__``) when it has one, and only
    otherwise wraps both operands in a new ``Bool`` query.
    """

    _type_name = "query"
    _type_shortcut = staticmethod(Q)
    name: ClassVar[Optional[str]] = None

    # Type annotations only: these reflected operators are not defined on
    # every subclass, so no implementation is provided here.
    __ror__: ClassVar[Callable[["Query", "Query"], "Query"]]
    __radd__: ClassVar[Callable[["Query", "Query"], "Query"]]
    __rand__: ClassVar[Callable[["Query", "Query"], "Query"]]

    def __add__(self, other: "Query") -> "Query":
        """Return ``self + other``: both queries placed in ``must``."""
        if not hasattr(other, "__radd__"):
            return Bool(must=[self, other])
        # the other query knows how to combine itself, let it decide
        return other.__radd__(self)

    def __invert__(self) -> "Query":
        """Return ``~self``: this query placed in ``must_not``."""
        return Bool(must_not=[self])

    def __or__(self, other: "Query") -> "Query":
        """Return ``self | other``: both queries placed in ``should``."""
        if not hasattr(other, "__ror__"):
            return Bool(should=[self, other])
        # the other query knows how to combine itself, let it decide
        return other.__ror__(self)

    def __and__(self, other: "Query") -> "Query":
        """Return ``self & other``: both queries placed in ``must``."""
        if not hasattr(other, "__rand__"):
            return Bool(must=[self, other])
        # the other query knows how to combine itself, let it decide
        return other.__rand__(self)
{% for k in classes %}
{# One class is emitted per entry in `classes`, deriving from `parent`. Its
   docstring is made of the lines in k.docstring followed by the doc lines of
   every entry in k.args. #}
class {{ k.name }}({{ parent }}):
"""
{% for line in k.docstring %}
{{ line }}
{% endfor %}
{% if k.args %}
{% if k.docstring %}
{% endif %}
{% for kwarg in k.args %}
{% for line in kwarg.doc %}
{{ line }}
{% endfor %}
{% endfor %}
{% endif %}
"""
name = "{{ k.property_name }}"
{# `_param_defs` is only emitted for classes that have entries in k.params. #}
{% if k.params %}
_param_defs = {
{% for param in k.params %}
"{{ param.name }}": {{ param.param }},
{% endfor %}
{% if k.name == "FunctionScore" %}
{# The FunctionScore class implements a custom solution for the `functions`
shortcut property. Until the code generator can support shortcut
properties directly that solution is added here #}
"filter": {"type": "query"},
{% endif %}
}
{% endif %}
{# Constructor signature: the positional args come first, then a bare `*`
   (only when the last entry of k.args is not positional) followed by the
   remaining args as keyword-only parameters. Every arg defaults to DEFAULT. #}
def __init__(
self,
{% for arg in k.args %}
{% if arg.positional %}
{{ arg.name }}: {{ arg.type }} = DEFAULT,
{% endif %}
{% endfor %}
{% if k.args and not k.args[-1].positional %}
*,
{% endif %}
{% for arg in k.args %}
{% if not arg.positional %}
{{ arg.name }}: {{ arg.type }} = DEFAULT,
{% endif %}
{% endfor %}
**kwargs: Any
):
{# Class-specific constructor prologue; at most one of the branches below is
   emitted for a given class. #}
{% if k.name == "FunctionScore" %}
{# continuation of the FunctionScore shortcut property support from above #}
# when no explicit `functions` list is given, build one from every keyword
# argument whose name appears in ScoreFunction._classes
if functions is DEFAULT:
functions = []
for name in ScoreFunction._classes:
if name in kwargs:
functions.append({name: kwargs.pop(name)}) # type: ignore[arg-type]
{% elif k.is_single_field %}
{# the (_field, _value) pair is stored in kwargs under the key str(_field) #}
if _field is not DEFAULT:
kwargs[str(_field)] = _value
{% elif k.is_multi_field %}
{# every item of the _fields mapping is copied into kwargs, keyed by str(field) #}
if _fields is not DEFAULT:
for field, value in _fields.items():
kwargs[str(field)] = value
{% endif %}
{# Only the non-positional args are forwarded by name to the parent
   constructor, together with whatever is in kwargs at this point. #}
super().__init__(
{% for arg in k.args %}
{% if not arg.positional %}
{{ arg.name }}={{ arg.name }},
{% endif %}
{% endfor %}
**kwargs
)
{# what follows is a set of Pythonic enhancements to some of the query classes
which are outside the scope of the code generator #}
{% if k.name == "MatchAll" %}
# combining MatchAll with another query via + or & (either operand order)
# gives back a clone of that other query
def __add__(self, other: "Query") -> "Query":
return other._clone()
__and__ = __rand__ = __radd__ = __add__
# MatchAll | anything (either operand order) is still this MatchAll
def __or__(self, other: "Query") -> "MatchAll":
return self
__ror__ = __or__
# the negation of MatchAll is a new MatchNone
def __invert__(self) -> "MatchNone":
return MatchNone()
{# NOTE(review): indentation is not visible here; presumably this assignment is
   emitted at module level, after the MatchAll class body - confirm. #}
EMPTY_QUERY = MatchAll()
{% elif k.name == "MatchNone" %}
# combining MatchNone with another query via + or & (either operand order)
# is still this MatchNone
def __add__(self, other: "Query") -> "MatchNone":
return self
__and__ = __rand__ = __radd__ = __add__
# MatchNone | other (either operand order) gives back a clone of other
def __or__(self, other: "Query") -> "Query":
return other._clone()
__ror__ = __or__
# the negation of MatchNone is a new MatchAll
def __invert__(self) -> MatchAll:
return MatchAll()
{% elif k.name == "Bool" %}
# Bool + other: work on a clone of self. Another Bool has its must, should,
# must_not and filter lists concatenated onto the clone; any other query is
# appended to the clone's must list.
def __add__(self, other: Query) -> "Bool":
q = self._clone()
if isinstance(other, Bool):
q.must += other.must
q.should += other.should
q.must_not += other.must_not
q.filter += other.filter
else:
q.must.append(other)
return q
__radd__ = __add__
# Bool | other: look for an operand that is a Bool with nothing but should
# clauses (no must, must_not, filter or minimum_should_match); the other
# operand can then be merged into a clone of that Bool's should list.
def __or__(self, other: Query) -> Query:
for q in (self, other):
if isinstance(q, Bool) and not any(
(q.must, q.must_not, q.filter, getattr(q, "minimum_should_match", None))
):
# from here on `other` names whichever operand is not q
other = self if q is other else other
q = q._clone()
# when the remaining operand is also a should-only Bool its clauses are
# flattened into q.should, otherwise it is added as a single clause
if isinstance(other, Bool) and not any(
(
other.must,
other.must_not,
other.filter,
getattr(other, "minimum_should_match", None),
)
):
q.should.extend(other.should)
else:
q.should.append(other)
return q
# neither operand is a should-only Bool: wrap both in a new Bool
return Bool(should=[self, other])
__ror__ = __or__
# Effective minimum_should_match: the value of the `minimum_should_match`
# attribute when getattr finds one; otherwise 0 when there are no should
# clauses or when must/filter clauses are present, and 1 in the remaining case.
@property
def _min_should_match(self) -> int:
return getattr(
self,
"minimum_should_match",
0 if not self.should or (self.must or self.filter) else 1,
)
def __invert__(self) -> Query:
# Because an empty Bool query is treated like
# MatchAll the inverse should be MatchNone
if not any(chain(self.must, self.filter, self.should, self.must_not)):
return MatchNone()
# collect one alternative per way this Bool can fail to match
negations: List[Query] = []
# a must or filter clause that does not match
for q in chain(self.must, self.filter):
negations.append(~q)
# a must_not clause that does match
for q in self.must_not:
negations.append(q)
# should clauses only count when at least one of them is required
if self.should and self._min_should_match:
negations.append(Bool(must_not=self.should[:]))
# a single alternative is returned as is, several are OR-ed together
if len(negations) == 1:
return negations[0]
return Bool(should=negations)
# Bool & other: work on a clone of self. For another Bool the must, must_not
# and filter lists are concatenated and the should clauses of both operands
# are re-distributed according to each operand's effective
# minimum_should_match; any other query is appended to must.
def __and__(self, other: Query) -> Query:
q = self._clone()
if isinstance(other, Bool):
q.must += other.must
q.must_not += other.must_not
q.filter += other.filter
q.should = []
# reset minimum_should_match as it will get calculated below
if "minimum_should_match" in q._params:
del q._params["minimum_should_match"]
for qx in (self, other):
min_should_match = qx._min_should_match
# TODO: percentages or negative numbers will fail here
# for now we report an error
if not isinstance(min_should_match, int) or min_should_match < 0:
raise ValueError(
"Can only combine queries with positive integer values for minimum_should_match"
)
# all subqueries are required
if len(qx.should) <= min_should_match:
q.must.extend(qx.should)
# not all of them are required, use it and remember min_should_match
elif not q.should:
q.minimum_should_match = min_should_match
q.should = qx.should
# all queries are optional, just extend should
elif q._min_should_match == 0 and min_should_match == 0:
q.should.extend(qx.should)
# not all are required, add a should list to the must with proper min_should_match
else:
q.must.append(
Bool(should=qx.should, minimum_should_match=min_should_match)
)
else:
# With only should clauses the effective minimum_should_match is 1, but it
# would drop to 0 once a must clause is added below; pin it to 1 (unless a
# value is already stored) before appending.
if not (q.must or q.filter) and q.should:
q._params.setdefault("minimum_should_match", 1)
q.must.append(other)
return q
__rand__ = __and__
{% elif k.name == "Terms" %}
def _setattr(self, name: str, value: Any) -> None:
# here we convert any iterables that are not strings to lists
# (str, list and dict values are passed on unchanged)
if hasattr(value, "__iter__") and not isinstance(value, (str, list, dict)):
value = list(value)
super()._setattr(name, value)
{% endif %}
{% endfor %}