# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import dataclasses
import inspect
import logging
import time
import warnings
from platform import python_version
from typing import (
Any,
Callable,
Collection,
Dict,
List,
Mapping,
NamedTuple,
Optional,
Tuple,
Type,
Union,
cast,
)

from ._compat import Lock, warn_stacklevel
from ._exceptions import (
ConnectionError,
ConnectionTimeout,
SniffingError,
TransportError,
TransportWarning,
)
from ._models import (
DEFAULT,
ApiResponseMeta,
DefaultType,
HttpHeaders,
NodeConfig,
SniffOptions,
)
from ._node import (
AiohttpHttpNode,
BaseNode,
HttpxAsyncHttpNode,
RequestsHttpNode,
Urllib3HttpNode,
)
from ._node_pool import NodePool, NodeSelector
from ._otel import OpenTelemetrySpan
from ._serializer import DEFAULT_SERIALIZERS, Serializer, SerializerCollection
from ._version import __version__
from .client_utils import client_meta_version, resolve_default

# Allows for using a node_class by name rather than import.
NODE_CLASS_NAMES: Dict[str, Type[BaseNode]] = {
"urllib3": Urllib3HttpNode,
"requests": RequestsHttpNode,
"aiohttp": AiohttpHttpNode,
"httpxasync": HttpxAsyncHttpNode,
}
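
# For example, Transport(node_configs, node_class="requests") selects
# RequestsHttpNode without importing it directly.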

# These are HTTP status errors that shouldn't be considered
# 'errors' for marking a node as dead. These errors typically
# mean everything is fine server-wise and instead the API call
# in question responded successfully.
NOT_DEAD_NODE_HTTP_STATUSES = {None, 400, 401, 402, 403, 404, 409}

DEFAULT_CLIENT_META_SERVICE = ("et", client_meta_version(__version__))

_logger = logging.getLogger("elastic_transport.transport")


class TransportApiResponse(NamedTuple):
meta: ApiResponseMeta
body: Any


class Transport:
    """
    Encapsulation of transport-related logic. Handles instantiation of the
    individual nodes as well as creating a node pool to hold them.

    The main interface is the :meth:`elastic_transport.Transport.perform_request` method.
    """
def __init__(
self,
node_configs: List[NodeConfig],
node_class: Union[str, Type[BaseNode]] = Urllib3HttpNode,
node_pool_class: Type[NodePool] = NodePool,
randomize_nodes_in_pool: bool = True,
node_selector_class: Optional[Union[str, Type[NodeSelector]]] = None,
dead_node_backoff_factor: Optional[float] = None,
max_dead_node_backoff: Optional[float] = None,
serializers: Optional[Mapping[str, Serializer]] = None,
default_mimetype: str = "application/json",
max_retries: int = 3,
retry_on_status: Collection[int] = (429, 502, 503, 504),
retry_on_timeout: bool = False,
sniff_on_start: bool = False,
sniff_before_requests: bool = False,
sniff_on_node_failure: bool = False,
sniff_timeout: Optional[float] = 0.5,
min_delay_between_sniffing: float = 10.0,
        sniff_callback: Optional[
            Callable[["Transport", "SniffOptions"], List[NodeConfig]]
        ] = None,
meta_header: bool = True,
client_meta_service: Tuple[str, str] = DEFAULT_CLIENT_META_SERVICE,
):
"""
        :arg node_configs: List of 'NodeConfig' instances to create the initial set of nodes.
        :arg node_class: subclass of :class:`~elastic_transport.BaseNode` to use
            or the name of the node class (e.g. 'urllib3', 'requests')
:arg node_pool_class: subclass of :class:`~elastic_transport.NodePool` to use
:arg randomize_nodes_in_pool: Set to false to not randomize nodes within the pool.
Defaults to true.
:arg node_selector_class: Class to be used to select nodes within
the :class:`~elastic_transport.NodePool`.
:arg dead_node_backoff_factor: Exponential backoff factor to calculate the amount
of time to timeout a node after an unsuccessful API call.
:arg max_dead_node_backoff: Maximum amount of time to timeout a node after an
unsuccessful API call.
        :arg serializers: optional dict of serializer instances that will be
            used for serializing request bodies and deserializing data coming
            from the server. (key is the mimetype)
        :arg max_retries: Maximum number of retries for an API call.
            Set to 0 to disable retries. Defaults to ``3``.
        :arg retry_on_status: Set of HTTP status codes on which we should retry
            on a different node. Defaults to ``(429, 502, 503, 504)``
:arg retry_on_timeout: should timeout trigger a retry on different
node? (default ``False``)
:arg sniff_on_start: If ``True`` will sniff for additional nodes as soon
as possible, guaranteed before the first request.
        :arg sniff_on_node_failure: If ``True`` will sniff for additional nodes
            after a node is marked as dead in the pool.
:arg sniff_before_requests: If ``True`` will occasionally sniff for additional
nodes as requests are sent.
        :arg sniff_timeout: Timeout value in seconds to use for sniffing requests.
            Defaults to 0.5 seconds.
:arg min_delay_between_sniffing: Number of seconds to wait between calls to
:meth:`elastic_transport.Transport.sniff` to avoid sniffing too frequently.
Defaults to 10 seconds.
:arg sniff_callback: Function that is passed a :class:`elastic_transport.Transport` and
:class:`elastic_transport.SniffOptions` and should do node discovery and
return a list of :class:`elastic_transport.NodeConfig` instances.
:arg meta_header: If set to False the ``X-Elastic-Client-Meta`` HTTP header won't be sent.
Defaults to True.
:arg client_meta_service: Key-value pair for the service field of the client metadata header.
Defaults to the service key-value for Elastic Transport.
"""
if isinstance(node_class, str):
if node_class not in NODE_CLASS_NAMES:
options = "', '".join(sorted(NODE_CLASS_NAMES.keys()))
raise ValueError(
f"Unknown option for node_class: '{node_class}'. "
f"Available options are: '{options}'"
)
node_class = NODE_CLASS_NAMES[node_class]
        # Verify that the node_class we're passed is
        # the same async/sync flavor as the transport.
is_transport_async = inspect.iscoroutinefunction(self.perform_request)
is_node_async = inspect.iscoroutinefunction(node_class.perform_request)
if is_transport_async != is_node_async:
raise ValueError(
f"Specified 'node_class' {'is' if is_node_async else 'is not'} async, "
f"should be {'async' if is_transport_async else 'sync'} instead"
)
validate_sniffing_options(
node_configs=node_configs,
sniff_on_start=sniff_on_start,
sniff_before_requests=sniff_before_requests,
sniff_on_node_failure=sniff_on_node_failure,
sniff_callback=sniff_callback,
)
# Create the default metadata for the x-elastic-client-meta
# HTTP header. Only requires adding the (service, service_version)
# tuple to the beginning of the client_meta
self._transport_client_meta: Tuple[Tuple[str, str], ...] = (
client_meta_service,
("py", client_meta_version(python_version())),
("t", client_meta_version(__version__)),
)
        # Grab the '_CLIENT_META_HTTP_CLIENT' property from the node class
http_client_meta = cast(
Optional[Tuple[str, str]],
getattr(node_class, "_CLIENT_META_HTTP_CLIENT", None),
)
if http_client_meta:
self._transport_client_meta += (http_client_meta,)
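
        # The pairs collected above are later rendered into the comma-separated
        # 'x-elastic-client-meta' header, e.g. "et=8.1.2,py=3.11.4,t=8.1.2,ur=2.1.0"
        # (values illustrative).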
if not isinstance(meta_header, bool):
raise TypeError("'meta_header' must be of type bool")
self.meta_header = meta_header
# serialization config
_serializers = DEFAULT_SERIALIZERS.copy()
# if custom serializers map has been supplied, override the defaults with it
if serializers:
_serializers.update(serializers)
# Create our collection of serializers
self.serializers = SerializerCollection(
_serializers, default_mimetype=default_mimetype
)
# Set of default request options
self.max_retries = max_retries
self.retry_on_status = retry_on_status
self.retry_on_timeout = retry_on_timeout
# Build the NodePool from all the options
node_pool_kwargs: Dict[str, Any] = {}
if node_selector_class is not None:
node_pool_kwargs["node_selector_class"] = node_selector_class
if dead_node_backoff_factor is not None:
node_pool_kwargs["dead_node_backoff_factor"] = dead_node_backoff_factor
if max_dead_node_backoff is not None:
node_pool_kwargs["max_dead_node_backoff"] = max_dead_node_backoff
self.node_pool: NodePool = node_pool_class(
node_configs,
node_class=node_class,
randomize_nodes=randomize_nodes_in_pool,
**node_pool_kwargs,
)
self._sniff_on_start = sniff_on_start
self._sniff_before_requests = sniff_before_requests
self._sniff_on_node_failure = sniff_on_node_failure
self._sniff_timeout = sniff_timeout
self._sniff_callback = sniff_callback
self._sniffing_lock = Lock() # Used to track whether we're currently sniffing.
self._min_delay_between_sniffing = min_delay_between_sniffing
self._last_sniffed_at = 0.0
if sniff_on_start:
self.sniff(True)

    def perform_request(  # type: ignore[return]
self,
method: str,
target: str,
*,
body: Optional[Any] = None,
headers: Union[Mapping[str, Any], DefaultType] = DEFAULT,
max_retries: Union[int, DefaultType] = DEFAULT,
retry_on_status: Union[Collection[int], DefaultType] = DEFAULT,
retry_on_timeout: Union[bool, DefaultType] = DEFAULT,
request_timeout: Union[Optional[float], DefaultType] = DEFAULT,
client_meta: Union[Tuple[Tuple[str, str], ...], DefaultType] = DEFAULT,
otel_span: Union[OpenTelemetrySpan, DefaultType] = DEFAULT,
) -> TransportApiResponse:
"""
Perform the actual request. Retrieve a node from the node
pool, pass all the information to it's perform_request method and
return the data.
If an exception was raised, mark the node as failed and retry (up
to ``max_retries`` times).
If the operation was successful and the node used was previously
marked as dead, mark it as live, resetting it's failure count.
:arg method: HTTP method to use
:arg target: HTTP request target
:arg body: body of the request, will be serialized using serializer and
passed to the node
:arg headers: Additional headers to send with the request.
:arg max_retries: Maximum number of retries before giving up on a request.
Set to ``0`` to disable retries.
:arg retry_on_status: Collection of HTTP status codes to retry.
:arg retry_on_timeout: Set to true to retry after timeout errors.
        :arg request_timeout: Amount of time to wait for a response before
            failing with a timeout error.
:arg client_meta: Extra client metadata key-value pairs to send in the client meta header.
:arg otel_span: OpenTelemetry span used to add metadata to the span.
        :returns: Named tuple of the :class:`elastic_transport.ApiResponseMeta`
            and the deserialized response body.
"""
if headers is DEFAULT:
request_headers = HttpHeaders()
else:
request_headers = HttpHeaders(headers)
max_retries = resolve_default(max_retries, self.max_retries)
retry_on_timeout = resolve_default(retry_on_timeout, self.retry_on_timeout)
retry_on_status = resolve_default(retry_on_status, self.retry_on_status)
otel_span = resolve_default(otel_span, OpenTelemetrySpan(None))
if self.meta_header:
request_headers["x-elastic-client-meta"] = ",".join(
f"{k}={v}"
for k, v in self._transport_client_meta
+ resolve_default(client_meta, ())
)
# Serialize the request body to bytes based on the given mimetype.
request_body: Optional[bytes]
if body is not None:
if "content-type" not in request_headers:
raise ValueError(
"Must provide a 'Content-Type' header to requests with bodies"
)
request_body = self.serializers.dumps(
body, mimetype=request_headers["content-type"]
)
otel_span.set_db_statement(request_body)
else:
request_body = None
# Errors are stored from (oldest->newest)
errors: List[Exception] = []
for attempt in range(max_retries + 1):
# If we sniff before requests are made we want to do so before
# 'node_pool.get()' is called so our sniffed nodes show up in the pool.
if self._sniff_before_requests:
self.sniff(False)
retry = False
node_failure = False
last_response: Optional[TransportApiResponse] = None
node = self.node_pool.get()
start_time = time.time()
try:
otel_span.set_node_metadata(node.host, node.port, node.base_url, target)
resp = node.perform_request(
method,
target,
body=request_body,
headers=request_headers,
request_timeout=request_timeout,
)
_logger.info(
"%s %s%s [status:%s duration:%.3fs]"
% (
method,
node.base_url,
target,
resp.meta.status,
time.time() - start_time,
)
)
if method != "HEAD":
body = self.serializers.loads(resp.body, resp.meta.mimetype)
else:
body = None
if resp.meta.status in retry_on_status:
retry = True
# Keep track of the last response we see so we can return
# it in case the retried request returns with a transport error.
last_response = TransportApiResponse(resp.meta, body)
except TransportError as e:
_logger.info(
"%s %s%s [status:%s duration:%.3fs]"
% (
method,
node.base_url,
target,
"N/A",
time.time() - start_time,
)
)
if isinstance(e, ConnectionTimeout):
retry = retry_on_timeout
node_failure = True
elif isinstance(e, ConnectionError):
retry = True
node_failure = True
# If the error was determined to be a node failure
# we mark it dead in the node pool to allow for
# other nodes to be retried.
if node_failure:
self.node_pool.mark_dead(node)
if self._sniff_on_node_failure:
try:
self.sniff(False)
except TransportError:
                            # If sniffing on failure, it could fail too. Catch the
                            # exception so it doesn't interrupt the retries.
pass
if not retry or attempt >= max_retries:
                    # Since we've exhausted our retries but have previously
                    # received some sort of response from the API,
                    # we should forward that along instead of the
                    # transport error. Likely to be more actionable.
if last_response is not None:
return last_response
e.errors = tuple(errors)
raise
else:
_logger.warning(
"Retrying request after failure (attempt %d of %d)",
attempt,
max_retries,
exc_info=e,
)
errors.append(e)
else:
# If we got back a response we need to check if that status
# is indicative of a healthy node even if it's a non-2XX status
                if (
                    200 <= resp.meta.status < 300
                    or resp.meta.status in NOT_DEAD_NODE_HTTP_STATUSES
                ):
self.node_pool.mark_live(node)
else:
self.node_pool.mark_dead(node)
if self._sniff_on_node_failure:
try:
self.sniff(False)
except TransportError:
                            # If sniffing on failure, it could fail too. Catch the
                            # exception so it doesn't interrupt the retries.
pass
# We either got a response we're happy with or
# we've exhausted all of our retries so we return it.
if not retry or attempt >= max_retries:
return TransportApiResponse(resp.meta, body)
else:
_logger.warning(
"Retrying request after non-successful status %d (attempt %d of %d)",
resp.meta.status,
attempt,
max_retries,
)

    def sniff(self, is_initial_sniff: bool = False) -> None:
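        """Sniff for additional nodes using the configured ``sniff_callback``
        and add them to the node pool. Rate-limited by ``min_delay_between_sniffing``
        (skipped for the initial sniff) and guarded by ``_sniffing_lock``.
        """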
previously_sniffed_at = self._last_sniffed_at
should_sniff = self._should_sniff(is_initial_sniff)
try:
if should_sniff:
_logger.info("Started sniffing for additional nodes")
self._last_sniffed_at = time.time()
options = SniffOptions(
is_initial_sniff=is_initial_sniff, sniff_timeout=self._sniff_timeout
)
assert self._sniff_callback is not None
node_configs = self._sniff_callback(self, options)
if not node_configs and is_initial_sniff:
raise SniffingError(
"No viable nodes were discovered on the initial sniff attempt"
)
prev_node_pool_size = len(self.node_pool)
for node_config in node_configs:
self.node_pool.add(node_config)
# Do some math to log which nodes are new/existing
sniffed_nodes = len(node_configs)
new_nodes = sniffed_nodes - (len(self.node_pool) - prev_node_pool_size)
existing_nodes = sniffed_nodes - new_nodes
_logger.debug(
"Discovered %d nodes during sniffing (%d new nodes, %d already in pool)",
sniffed_nodes,
new_nodes,
existing_nodes,
)
# If sniffing failed for any reason we
# want to allow retrying immediately.
except Exception as e:
_logger.warning("Encountered an error during sniffing", exc_info=e)
self._last_sniffed_at = previously_sniffed_at
raise
# If we started a sniff we need to release the lock.
finally:
if should_sniff:
self._sniffing_lock.release()

    def close(self) -> None:
        """
        Explicitly closes all nodes in the transport's pool.
        """
for node in self.node_pool.all():
node.close()

    def _should_sniff(self, is_initial_sniff: bool) -> bool:
        """Decide if we should sniff or not. If we return ``True`` from this
        method the caller has a responsibility to release the ``_sniffing_lock``.
        """
if not is_initial_sniff and (
time.time() - self._last_sniffed_at < self._min_delay_between_sniffing
):
return False
return self._sniffing_lock.acquire(False)


def validate_sniffing_options(
*,
node_configs: List[NodeConfig],
sniff_before_requests: bool,
sniff_on_start: bool,
sniff_on_node_failure: bool,
sniff_callback: Optional[Any],
) -> None:
"""Validates the Transport configurations for sniffing"""
sniffing_enabled = sniff_before_requests or sniff_on_start or sniff_on_node_failure
if sniffing_enabled and not sniff_callback:
raise ValueError("Enabling sniffing requires specifying a 'sniff_callback'")
if not sniffing_enabled and sniff_callback:
raise ValueError(
"Using 'sniff_callback' requires enabling sniffing via 'sniff_on_start', "
"'sniff_before_requests' or 'sniff_on_node_failure'"
)
    # If we're sniffing we want to warn the user about non-homogeneous NodeConfigs.
if sniffing_enabled and len(node_configs) > 1:
warn_if_varying_node_config_options(node_configs)


def warn_if_varying_node_config_options(node_configs: List[NodeConfig]) -> None:
    """Function which detects situations where sniffing may produce incorrect configs"""
exempt_attrs = {"host", "port", "connections_per_node", "_extras", "ssl_context"}
match_attr_dict = None
for node_config in node_configs:
attr_dict = {
field.name: getattr(node_config, field.name)
for field in dataclasses.fields(node_config)
if field.name not in exempt_attrs
}
if match_attr_dict is None:
match_attr_dict = attr_dict
# Detected two nodes that have different config, warn the user.
elif match_attr_dict != attr_dict:
warnings.warn(
"Detected NodeConfig instances with different options. "
"It's recommended to keep all options except for "
"'host' and 'port' the same for sniffing to work reliably.",
category=TransportWarning,
stacklevel=warn_stacklevel(),
)
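

# For instance (an illustrative sketch), two NodeConfig instances that differ
# in any non-exempt option would trigger the warning above:
#
#   warn_if_varying_node_config_options([
#       NodeConfig("https", "host1", 9200, verify_certs=True),
#       NodeConfig("https", "host2", 9200, verify_certs=False),
#   ])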