elastic_transport/_node/_base.py (245 lines of code) (raw):
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import asyncio
import logging
import os
import ssl
from typing import Any, ClassVar, List, NamedTuple, Optional, Tuple, Union
from .._models import ApiResponseMeta, HttpHeaders, NodeConfig
from .._utils import is_ipaddress
from .._version import __version__
from ..client_utils import DEFAULT, DefaultType
_logger = logging.getLogger("elastic_transport.node")
_logger.propagate = False # This logger is very verbose so disable propogation.
DEFAULT_CA_CERTS: Optional[str] = None
DEFAULT_USER_AGENT = f"elastic-transport-python/{__version__}"
RERAISE_EXCEPTIONS = (RecursionError, asyncio.CancelledError)
BUILTIN_EXCEPTIONS = (
ValueError,
KeyError,
NameError,
AttributeError,
LookupError,
AssertionError,
IndexError,
MemoryError,
RuntimeError,
SystemError,
TypeError,
)
HTTP_STATUS_REASONS = {
200: "OK",
201: "Created",
202: "Accepted",
204: "No Content",
205: "Reset Content",
206: "Partial Content",
400: "Bad Request",
401: "Unauthorized",
402: "Payment Required",
403: "Forbidden",
404: "Not Found",
405: "Method Not Allowed",
406: "Not Acceptable",
407: "Proxy Authentication Required",
408: "Request Timeout",
409: "Conflict",
410: "Gone",
411: "Length Required",
412: "Precondition Failed",
413: "Content Too Large",
414: "URI Too Long",
415: "Unsupported Media Type",
429: "Too Many Requests",
500: "Internal Server Error",
501: "Not Implemented",
502: "Bad Gateway",
503: "Service Unavailable",
504: "Gateway Timeout",
}
try:
import certifi
DEFAULT_CA_CERTS = certifi.where()
except ImportError: # pragma: nocover
pass
class NodeApiResponse(NamedTuple):
meta: ApiResponseMeta
body: bytes
class BaseNode:
"""
Class responsible for maintaining a connection to a node. It
holds persistent node pool to it and it's main interface
(``perform_request``) is thread-safe.
:arg config: :class:`~elastic_transport.NodeConfig` instance
"""
_CLIENT_META_HTTP_CLIENT: ClassVar[Tuple[str, str]]
def __init__(self, config: NodeConfig):
self._config = config
self._headers: HttpHeaders = self.config.headers.copy() # type: ignore[attr-defined]
self.headers.setdefault("connection", "keep-alive")
self.headers.setdefault("user-agent", DEFAULT_USER_AGENT)
self._http_compress = bool(config.http_compress or False)
if config.http_compress:
self.headers["accept-encoding"] = "gzip"
self._scheme = config.scheme
self._host = config.host
self._port = config.port
self._path_prefix = (
("/" + config.path_prefix.strip("/")) if config.path_prefix else ""
)
@property
def config(self) -> NodeConfig:
return self._config
@property
def headers(self) -> HttpHeaders:
return self._headers
@property
def scheme(self) -> str:
return self._scheme
@property
def host(self) -> str:
return self._host
@property
def port(self) -> int:
return self._port
@property
def path_prefix(self) -> str:
return self._path_prefix
def __repr__(self) -> str:
return f"<{self.__class__.__name__}({self.base_url})>"
def __lt__(self, other: object) -> bool:
if not isinstance(other, BaseNode):
return NotImplemented
return id(self) < id(other)
def __eq__(self, other: object) -> bool:
if not isinstance(other, BaseNode):
return NotImplemented
return self.__hash__() == other.__hash__()
def __ne__(self, other: object) -> bool:
if not isinstance(other, BaseNode):
return NotImplemented
return not self == other
def __hash__(self) -> int:
return hash((str(type(self).__name__), self.config))
@property
def base_url(self) -> str:
return "".join(
[
self.scheme,
"://",
# IPv6 must be wrapped by [...]
"[%s]" % self.host if ":" in self.host else self.host,
":%s" % self.port if self.port is not None else "",
self.path_prefix,
]
)
def perform_request(
self,
method: str,
target: str,
body: Optional[bytes] = None,
headers: Optional[HttpHeaders] = None,
request_timeout: Union[DefaultType, Optional[float]] = DEFAULT,
) -> NodeApiResponse: # pragma: nocover
"""Constructs and sends an HTTP request and parses the HTTP response.
:param method: HTTP method
:param target: HTTP request target, typically path+query
:param body: Optional HTTP request body encoded as bytes
:param headers: Optional HTTP headers to send in addition to
the headers already configured.
:param request_timeout: Amount of time to wait for the first
response bytes to arrive before raising a
:class:`elastic_transport.ConnectionTimeout` error.
:raises:
:class:`elastic_transport.ConnectionError`,
:class:`elastic_transport.ConnectionTimeout`,
:class:`elastic_transport.TlsError`
:rtype: Tuple[ApiResponseMeta, bytes]
:returns: Metadata about the request+response and the raw
decompressed bytes from the HTTP response body.
"""
raise NotImplementedError()
def close(self) -> None: # pragma: nocover
pass
def _log_request(
self,
method: str,
target: str,
headers: Optional[HttpHeaders],
body: Optional[bytes],
meta: Optional[ApiResponseMeta] = None,
response: Optional[bytes] = None,
exception: Optional[Exception] = None,
) -> None:
if _logger.hasHandlers():
http_version = meta.http_version if meta else "?.?"
lines = ["> %s %s HTTP/%s"]
log_args: List[Any] = [method, target, http_version]
if headers:
for header, value in sorted(headers._dict_hide_auth().items()):
lines.append(f"> {header.title()}: {value}")
if body is not None:
try:
body_encoded = body.decode("utf-8", "surrogatepass")
except UnicodeError:
body_encoded = repr(body)
log_args.append(body_encoded)
lines.append("> %s")
if meta is not None:
reason = HTTP_STATUS_REASONS.get(meta.status, None)
if reason:
lines.append("< HTTP/%s %d %s")
log_args.extend((http_version, meta.status, reason))
else:
lines.append("< HTTP/%s %d")
log_args.extend((http_version, meta.status))
if meta.headers:
for header, value in sorted(meta.headers.items()):
lines.append(f"< {header.title()}: {value}")
if response:
try:
response_decoded = response.decode("utf-8", "surrogatepass")
except UnicodeError:
response_decoded = repr(response)
log_args.append(response_decoded)
lines.append("< %s")
if exception is not None:
_logger.debug("\n".join(lines), *log_args, exc_info=exception)
else:
_logger.debug("\n".join(lines), *log_args)
_HAS_TLS_VERSION = hasattr(ssl, "TLSVersion")
_SSL_PROTOCOL_VERSION_ATTRS = ("TLSv1", "TLSv1_1", "TLSv1_2")
_SSL_PROTOCOL_VERSION_DEFAULT = getattr(ssl, "OP_NO_SSLv2", 0) | getattr(
ssl, "OP_NO_SSLv3", 0
)
_SSL_PROTOCOL_VERSION_TO_OPTIONS = {}
_SSL_PROTOCOL_VERSION_TO_TLS_VERSION = {}
for i, _protocol_attr in enumerate(_SSL_PROTOCOL_VERSION_ATTRS):
try:
_protocol_value = getattr(ssl, f"PROTOCOL_{_protocol_attr}")
except AttributeError:
continue
if _HAS_TLS_VERSION:
_tls_version_value = getattr(ssl.TLSVersion, _protocol_attr)
_SSL_PROTOCOL_VERSION_TO_TLS_VERSION[_protocol_value] = _tls_version_value
_SSL_PROTOCOL_VERSION_TO_TLS_VERSION[_tls_version_value] = _tls_version_value
# Because we're setting a minimum version we binary OR all the options together.
_SSL_PROTOCOL_VERSION_TO_OPTIONS[_protocol_value] = (
_SSL_PROTOCOL_VERSION_DEFAULT
| sum(
getattr(ssl, f"OP_NO_{_attr}", 0)
for _attr in _SSL_PROTOCOL_VERSION_ATTRS[:i]
)
)
# TLSv1.3 is unique, doesn't have a PROTOCOL_TLSvX counterpart. So we have to set it manually.
if _HAS_TLS_VERSION:
try:
_SSL_PROTOCOL_VERSION_TO_TLS_VERSION[ssl.TLSVersion.TLSv1_3] = (
ssl.TLSVersion.TLSv1_3
)
except AttributeError: # pragma: nocover
pass
def ssl_context_from_node_config(node_config: NodeConfig) -> ssl.SSLContext:
if node_config.ssl_context:
ctx = node_config.ssl_context
else:
ctx = ssl.create_default_context()
# Enable/disable certificate verification in these orders
# to avoid 'ValueErrors' from SSLContext. We only do this
# step if the user doesn't pass a preconfigured SSLContext.
if node_config.verify_certs:
ctx.verify_mode = ssl.CERT_REQUIRED
ctx.check_hostname = not is_ipaddress(node_config.host)
else:
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
# Enable logging of TLS session keys for use with Wireshark.
if hasattr(ctx, "keylog_filename"):
sslkeylogfile = os.environ.get("SSLKEYLOGFILE", "")
if sslkeylogfile:
ctx.keylog_filename = sslkeylogfile
# Apply the 'ssl_version' if given, otherwise default to TLSv1.2+
ssl_version = node_config.ssl_version
if ssl_version is None:
if _HAS_TLS_VERSION:
ssl_version = ssl.TLSVersion.TLSv1_2
else:
ssl_version = ssl.PROTOCOL_TLSv1_2
try:
if _HAS_TLS_VERSION:
ctx.minimum_version = _SSL_PROTOCOL_VERSION_TO_TLS_VERSION[ssl_version]
else:
ctx.options |= _SSL_PROTOCOL_VERSION_TO_OPTIONS[ssl_version]
except KeyError:
raise ValueError(
f"Unsupported value for 'ssl_version': {ssl_version!r}. Must be "
"either 'ssl.PROTOCOL_TLSvX' or 'ssl.TLSVersion.TLSvX'"
) from None
return ctx