elastic_transport/_models.py (221 lines of code) (raw):

# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import dataclasses import enum import re import ssl from dataclasses import dataclass, field from typing import ( TYPE_CHECKING, Any, Collection, Dict, Iterator, KeysView, Mapping, MutableMapping, Optional, Tuple, TypeVar, Union, ValuesView, ) if TYPE_CHECKING: from typing import Final class DefaultType(enum.Enum): """ Sentinel used as a default value when ``None`` has special meaning like timeouts. The only comparisons that are supported for this type are ``is``. """ value = 0 def __repr__(self) -> str: return "<DEFAULT>" def __str__(self) -> str: return "<DEFAULT>" DEFAULT: "Final[DefaultType]" = DefaultType.value T = TypeVar("T") _TYPE_SSL_VERSION = Union[int, ssl.TLSVersion] class HttpHeaders(MutableMapping[str, str]): """HTTP headers Behaves like a Python dictionary. Can be used like this:: headers = HttpHeaders() headers["foo"] = "bar" headers["foo"] = "baz" print(headers["foo"]) # prints "baz" """ __slots__ = ("_internal", "_frozen") def __init__( self, initial: Optional[Union[Mapping[str, str], Collection[Tuple[str, str]]]] = None, ) -> None: self._internal = {} self._frozen = False if initial: for key, val in dict(initial).items(): self._internal[self._normalize_key(key)] = (key, val) def __setitem__(self, key: str, value: str) -> None: if self._frozen: raise ValueError("Can't modify headers that have been frozen") self._internal[self._normalize_key(key)] = (key, value) def __getitem__(self, item: str) -> str: return self._internal[self._normalize_key(item)][1] def __delitem__(self, key: str) -> None: if self._frozen: raise ValueError("Can't modify headers that have been frozen") del self._internal[self._normalize_key(key)] def __eq__(self, other: object) -> bool: if not isinstance(other, Mapping): return NotImplemented if not isinstance(other, HttpHeaders): other = HttpHeaders(other) return {k: v for k, (_, v) in self._internal.items()} == { k: v for k, (_, v) in other._internal.items() } def __ne__(self, other: object) -> bool: if not isinstance(other, Mapping): return NotImplemented return not self == other def __iter__(self) -> Iterator[str]: return iter(self.keys()) def __len__(self) -> int: return len(self._internal) def __bool__(self) -> bool: return bool(self._internal) def __contains__(self, item: object) -> bool: return isinstance(item, str) and self._normalize_key(item) in self._internal def __repr__(self) -> str: return repr(self._dict_hide_auth()) def __str__(self) -> str: return str(self._dict_hide_auth()) def __hash__(self) -> int: if not self._frozen: raise ValueError("Can't calculate the hash of headers that aren't frozen") return hash(tuple((k, v) for k, (_, v) in sorted(self._internal.items()))) def get(self, key: str, default: Optional[str] = None) -> Optional[str]: # type: ignore[override] return self._internal.get(self._normalize_key(key), (None, default))[1] def keys(self) -> KeysView[str]: return self._internal.keys() def values(self) -> ValuesView[str]: return {"": v for _, v in self._internal.values()}.values() def items(self) -> Collection[Tuple[str, str]]: # type: ignore[override] return [(key, val) for _, (key, val) in self._internal.items()] def freeze(self) -> "HttpHeaders": """Freezes the current set of headers so they can be used in hashes. Returns the same instance, doesn't make a copy. """ self._frozen = True return self @property def frozen(self) -> bool: return self._frozen def copy(self) -> "HttpHeaders": return HttpHeaders(self.items()) def _normalize_key(self, key: str) -> str: try: return key.lower() except AttributeError: return key def _dict_hide_auth(self) -> Dict[str, str]: def hide_auth(val: str) -> str: # Hides only the authentication value, not the method. match = re.match(r"^(ApiKey|Basic|Bearer) ", val) if match: return f"{match.group(1)} <hidden>" return "<hidden>" return { key: hide_auth(val) if key.lower() == "authorization" else val for key, val in self.items() } @dataclass class ApiResponseMeta: """Metadata that is returned from Transport.perform_request() :ivar int status: HTTP status code :ivar str http_version: HTTP version being used :ivar HttpHeaders headers: HTTP headers :ivar float duration: Number of seconds from start of request to start of response :ivar NodeConfig node: Node which handled the request :ivar typing.Optional[str] mimetype: Mimetype to be used by the serializer to decode the raw response bytes. """ status: int http_version: str headers: HttpHeaders duration: float node: "NodeConfig" @property def mimetype(self) -> Optional[str]: try: content_type = self.headers["content-type"] return content_type.partition(";")[0] or None except KeyError: return None def _empty_frozen_http_headers() -> HttpHeaders: """Used for the 'default_factory' of the 'NodeConfig.headers'""" return HttpHeaders().freeze() @dataclass(repr=True) class NodeConfig: """Configuration options available for every node.""" #: Protocol in use to connect to the node scheme: str #: IP address or hostname to connect to host: str #: IP port to connect to port: int #: Prefix to add to the path of every request path_prefix: str = "" #: Default HTTP headers to add to every request headers: Union[HttpHeaders, Mapping[str, str]] = field( default_factory=_empty_frozen_http_headers ) #: Number of concurrent connections that are #: able to be open at one time for this node. #: Having multiple connections per node allows #: for higher concurrency of requests. connections_per_node: int = 10 #: Number of seconds to wait before a request should timeout. request_timeout: Optional[float] = 10.0 #: Set to ``True`` to enable HTTP compression #: of request and response bodies via gzip. http_compress: Optional[bool] = False #: Set to ``True`` to verify the node's TLS certificate against 'ca_certs' #: Setting to ``False`` will disable verifying the node's certificate. verify_certs: Optional[bool] = True #: Path to a CA bundle or directory containing bundles. By default #: If the ``certifi`` package is installed and ``verify_certs`` is #: set to ``True`` this value will be set to ``certifi.where()``. ca_certs: Optional[str] = None #: Path to a client certificate for TLS client authentication. client_cert: Optional[str] = None #: Path to a client private key for TLS client authentication. client_key: Optional[str] = None #: Hostname or IP address to verify on the node's certificate. #: This is useful if the certificate contains a different value #: than the one supplied in ``host``. An example of this situation #: is connecting to an IP address instead of a hostname. #: Set to ``False`` to disable certificate hostname verification. ssl_assert_hostname: Optional[str] = None #: SHA-256 fingerprint of the node's certificate. If this value is #: given then root-of-trust verification isn't done and only the #: node's certificate fingerprint is verified. #: #: On CPython 3.10+ this also verifies if any certificate in the #: chain including the Root CA matches this fingerprint. However #: because this requires using private APIs support for this is #: **experimental**. ssl_assert_fingerprint: Optional[str] = None #: Minimum TLS version to use to connect to the node. Can be either #: :class:`ssl.TLSVersion` or one of the deprecated #: ``ssl.PROTOCOL_TLSvX`` instances. ssl_version: Optional[_TYPE_SSL_VERSION] = None #: Pre-configured :class:`ssl.SSLContext` object. If this value #: is given then no other TLS options (besides ``ssl_assert_fingerprint``) #: can be set on the :class:`elastic_transport.NodeConfig`. ssl_context: Optional[ssl.SSLContext] = field(default=None, hash=False) #: Set to ``False`` to disable the :class:`elastic_transport.SecurityWarning` #: issued when using ``verify_certs=False``. ssl_show_warn: bool = True #: Extras that can be set to anything, typically used #: for annotating this node with additional information for #: future decisions like sniffing, instance roles, etc. #: Third-party keys should start with an underscore and prefix. _extras: Dict[str, Any] = field(default_factory=dict, hash=False) def replace(self, **kwargs: Any) -> "NodeConfig": if not kwargs: return self return dataclasses.replace(self, **kwargs) def __post_init__(self) -> None: if not isinstance(self.headers, HttpHeaders) or not self.headers.frozen: self.headers = HttpHeaders(self.headers).freeze() if self.scheme != self.scheme.lower(): raise ValueError("'scheme' must be lowercase") if "[" in self.host or "]" in self.host: raise ValueError("'host' must not have square braces") if self.port < 0: raise ValueError("'port' must be a positive integer") if self.connections_per_node <= 0: raise ValueError("'connections_per_node' must be a positive integer") if self.path_prefix: self.path_prefix = ( ("/" + self.path_prefix.strip("/")) if self.path_prefix else "" ) tls_options = [ "ca_certs", "client_cert", "client_key", "ssl_assert_hostname", "ssl_assert_fingerprint", "ssl_context", ] # Disallow setting TLS options on non-HTTPS connections. if self.scheme != "https": if any(getattr(self, attr) is not None for attr in tls_options): raise ValueError("TLS options require scheme to be 'https'") elif self.scheme == "https": # It's not valid to set 'ssl_context' and any other # TLS option, the SSLContext object must be configured # the way the user wants already. def tls_option_filter(attr: object) -> bool: return ( isinstance(attr, str) and attr not in ("ssl_context", "ssl_assert_fingerprint") and getattr(self, attr) is not None ) if self.ssl_context is not None and any( filter( tls_option_filter, tls_options, ) ): raise ValueError( "The 'ssl_context' option can't be combined with other TLS options" ) def __eq__(self, other: object) -> bool: if not isinstance(other, NodeConfig): return NotImplemented return ( self.scheme == other.scheme and self.host == other.host and self.port == other.port and self.path_prefix == other.path_prefix ) def __ne__(self, other: object) -> bool: if not isinstance(other, NodeConfig): return NotImplemented return not self == other def __hash__(self) -> int: return hash( ( self.scheme, self.host, self.port, self.path_prefix, ) ) @dataclass() class SniffOptions: """Options which are passed to Transport.sniff_callback""" is_initial_sniff: bool sniff_timeout: Optional[float]