utils/templates/field.py.tpl (391 lines of code) (raw):

# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import base64 import collections.abc import ipaddress from copy import deepcopy from datetime import date, datetime from typing import ( TYPE_CHECKING, Any, Dict, Iterable, Iterator, Literal, Mapping, Optional, Sequence, Tuple, Type, Union, cast, ) from dateutil import parser, tz from elastic_transport.client_utils import DEFAULT, DefaultType from .exceptions import ValidationException from .query import Q from .utils import AttrDict, AttrList, DslBase from .wrappers import Range if TYPE_CHECKING: from datetime import tzinfo from ipaddress import IPv4Address, IPv6Address from _operator import _SupportsComparison from .document import InnerDoc from .document_base import InstrumentedField from .mapping_base import MappingBase from .query import Query from . import types unicode = str def construct_field( name_or_field: Union[ str, "Field", Dict[str, Any], ], **params: Any, ) -> "Field": # {"type": "text", "analyzer": "snowball"} if isinstance(name_or_field, collections.abc.Mapping): if params: raise ValueError( "construct_field() cannot accept parameters when passing in a dict." ) params = deepcopy(name_or_field) if "type" not in params: # inner object can be implicitly defined if "properties" in params: name = "object" else: raise ValueError('construct_field() needs to have a "type" key.') else: name = params.pop("type") return Field.get_dsl_class(name)(**params) # Text() if isinstance(name_or_field, Field): if params: raise ValueError( "construct_field() cannot accept parameters " "when passing in a construct_field object." ) return name_or_field # "text", analyzer="snowball" return Field.get_dsl_class(name_or_field)(**params) class Field(DslBase): _type_name = "field" _type_shortcut = staticmethod(construct_field) # all fields can be multifields _param_defs = {"fields": {"type": "field", "hash": True}} name = "" _coerce = False def __init__( self, multi: bool = False, required: bool = False, *args: Any, **kwargs: Any ): """ :arg bool multi: specifies whether field can contain array of values :arg bool required: specifies whether field is required """ self._multi = multi self._required = required super().__init__(*args, **kwargs) def __getitem__(self, subfield: str) -> "Field": return cast(Field, self._params.get("fields", {})[subfield]) def _serialize(self, data: Any) -> Any: return data def _deserialize(self, data: Any) -> Any: return data def _empty(self) -> Optional[Any]: return None def empty(self) -> Optional[Any]: if self._multi: return AttrList([]) return self._empty() def serialize(self, data: Any) -> Any: if isinstance(data, (list, AttrList, tuple)): return list(map(self._serialize, cast(Iterable[Any], data))) return self._serialize(data) def deserialize(self, data: Any) -> Any: if isinstance(data, (list, AttrList, tuple)): data = [ None if d is None else self._deserialize(d) for d in cast(Iterable[Any], data) ] return data if data is None: return None return self._deserialize(data) def clean(self, data: Any) -> Any: if data is not None: data = self.deserialize(data) if data in (None, [], {}) and self._required: raise ValidationException("Value required for this field.") return data def to_dict(self) -> Dict[str, Any]: d = super().to_dict() name, value = cast(Tuple[str, Dict[str, Any]], d.popitem()) value["type"] = name return value class CustomField(Field): name = "custom" _coerce = True def to_dict(self) -> Dict[str, Any]: if isinstance(self.builtin_type, Field): return self.builtin_type.to_dict() d = super().to_dict() d["type"] = self.builtin_type return d class RangeField(Field): _coerce = True _core_field: Optional[Field] = None def _deserialize(self, data: Any) -> Range["_SupportsComparison"]: if isinstance(data, Range): return data data = {k: self._core_field.deserialize(v) for k, v in data.items()} # type: ignore[union-attr] return Range(data) def _serialize(self, data: Any) -> Optional[Dict[str, Any]]: if data is None: return None if not isinstance(data, collections.abc.Mapping): data = data.to_dict() return {k: self._core_field.serialize(v) for k, v in data.items()} # type: ignore[union-attr] {% for k in classes %} class {{ k.name }}({{ k.parent }}): """ {% for line in k.docstring %} {{ line }} {% endfor %} {% if k.args %} {% if k.docstring %} {% endif %} {% for kwarg in k.args %} {% for line in kwarg.doc %} {{ line }} {% endfor %} {% endfor %} {% endif %} """ name = "{{ k.field }}" {% if k.coerced %} _coerce = True {% endif %} {% if k.name.endswith('Range') %} _core_field = {{ k.name[:-5] }}() {% endif %} {% if k.params %} _param_defs = { {% for param in k.params %} "{{ param.name }}": {{ param.param }}, {% endfor %} } {% endif %} def __init__( self, {% for arg in k.args %} {% if arg.positional %} {{ arg.name }}: {{ arg.type }} = DEFAULT, {% endif %} {% endfor %} *args: Any, {% for arg in k.args %} {% if not arg.positional %} {{ arg.name }}: {{ arg.type }} = DEFAULT, {% endif %} {% endfor %} **kwargs: Any ): {% for arg in k.args %} {% if not arg.positional %} if {{ arg.name }} is not DEFAULT: {% if "InstrumentedField" in arg.type %} kwargs["{{ arg.name }}"] = str({{ arg.name }}) {% else %} kwargs["{{ arg.name }}"] = {{ arg.name }} {% endif %} {% endif %} {% endfor %} {% if k.field == 'object' %} if doc_class is not DEFAULT and (properties is not DEFAULT or dynamic is not DEFAULT): raise ValidationException( "doc_class and properties/dynamic should not be provided together" ) if doc_class is not DEFAULT: self._doc_class: Type["InnerDoc"] = doc_class else: # FIXME import from .document import InnerDoc # no InnerDoc subclass, creating one instead... self._doc_class = type("InnerDoc", (InnerDoc,), {}) for name, field in (properties if properties is not DEFAULT else {}).items(): self._doc_class._doc_type.mapping.field(name, field) if "properties" in kwargs: del kwargs["properties"] if dynamic is not DEFAULT: self._doc_class._doc_type.mapping.meta("dynamic", dynamic) self._mapping: "MappingBase" = deepcopy(self._doc_class._doc_type.mapping) super().__init__(**kwargs) def __getitem__(self, name: str) -> Field: return self._mapping[name] def __contains__(self, name: str) -> bool: return name in self._mapping def _empty(self) -> "InnerDoc": return self._wrap({}) def _wrap(self, data: Dict[str, Any]) -> "InnerDoc": return self._doc_class.from_es(data, data_only=True) def empty(self) -> Union["InnerDoc", AttrList[Any]]: if self._multi: return AttrList[Any]([], self._wrap) return self._empty() def to_dict(self) -> Dict[str, Any]: d = self._mapping.to_dict() d.update(super().to_dict()) return d def _collect_fields(self) -> Iterator[Field]: return self._mapping.properties._collect_fields() def _deserialize(self, data: Any) -> "InnerDoc": # don't wrap already wrapped data if isinstance(data, self._doc_class): return data if isinstance(data, AttrDict): data = data._d_ return self._wrap(data) def _serialize( self, data: Optional[Union[Dict[str, Any], "InnerDoc"]] ) -> Optional[Dict[str, Any]]: if data is None: return None # somebody assigned raw dict to the field, we should tolerate that if isinstance(data, collections.abc.Mapping): return data return data.to_dict() def clean(self, data: Any) -> Any: data = super().clean(data) if data is None: return None if isinstance(data, (list, AttrList)): for d in cast(Iterator["InnerDoc"], data): d.full_clean() else: data.full_clean() return data def update(self, other: Any, update_only: bool = False) -> None: if not isinstance(other, Object): # not an inner/nested object, no merge possible return self._mapping.update(other._mapping, update_only) {% elif k.field == "nested" %} kwargs.setdefault("multi", True) super().__init__(*args, **kwargs) {% elif k.field == "date" %} if default_timezone is DEFAULT: self._default_timezone = None elif isinstance(default_timezone, str): self._default_timezone = tz.gettz(default_timezone) else: self._default_timezone = default_timezone super().__init__(*args, **kwargs) def _deserialize(self, data: Any) -> Union[datetime, date]: if isinstance(data, str): try: data = parser.parse(data) except Exception as e: raise ValidationException( f"Could not parse date from the value ({data!r})", e ) # we treat the yyyy-MM-dd format as a special case if hasattr(self, "format") and self.format == "yyyy-MM-dd": data = data.date() if isinstance(data, datetime): if self._default_timezone and data.tzinfo is None: data = data.replace(tzinfo=self._default_timezone) return data if isinstance(data, date): return data if isinstance(data, int): # Divide by a float to preserve milliseconds on the datetime. return datetime.utcfromtimestamp(data / 1000.0) raise ValidationException(f"Could not parse date from the value ({data!r})") {% elif k.field == "boolean" %} super().__init__(*args, **kwargs) def _deserialize(self, data: Any) -> bool: if data == "false": return False return bool(data) def clean(self, data: Any) -> Optional[bool]: if data is not None: data = self.deserialize(data) if data is None and self._required: raise ValidationException("Value required for this field.") return data # type: ignore[no-any-return] {% elif k.field == "float" %} super().__init__(*args, **kwargs) def _deserialize(self, data: Any) -> float: return float(data) {% elif k.field == "dense_vector" %} self._element_type = kwargs.get("element_type", "float") if self._element_type in ["float", "byte"]: kwargs["multi"] = True super().__init__(*args, **kwargs) def _deserialize(self, data: Any) -> Any: if self._element_type == "float": return float(data) elif self._element_type == "byte": return int(data) return data {% elif k.field == "scaled_float" %} if 'scaling_factor' not in kwargs: if len(args) > 0: kwargs['scaling_factor'] = args[0] args = args[1:] else: raise TypeError("missing required argument: 'scaling_factor'") super().__init__(*args, **kwargs) {% elif k.field == "integer" %} super().__init__(*args, **kwargs) def _deserialize(self, data: Any) -> int: return int(data) {% elif k.field == "ip" %} super().__init__(*args, **kwargs) def _deserialize(self, data: Any) -> Union["IPv4Address", "IPv6Address"]: # the ipaddress library for pypy only accepts unicode. return ipaddress.ip_address(unicode(data)) def _serialize(self, data: Any) -> Optional[str]: if data is None: return None return str(data) {% elif k.field == "binary" %} super().__init__(*args, **kwargs) def clean(self, data: str) -> str: # Binary fields are opaque, so there's not much cleaning # that can be done. return data def _deserialize(self, data: Any) -> bytes: return base64.b64decode(data) def _serialize(self, data: Any) -> Optional[str]: if data is None: return None return base64.b64encode(data).decode() {% elif k.field == "percolator" %} super().__init__(*args, **kwargs) def _deserialize(self, data: Any) -> "Query": return Q(data) # type: ignore[no-any-return] def _serialize(self, data: Any) -> Optional[Dict[str, Any]]: if data is None: return None return data.to_dict() # type: ignore[no-any-return] {% else %} super().__init__(*args, **kwargs) {% endif %} {% endfor %}