elastic_transport/_node/_http_aiohttp.py (216 lines of code) (raw):

# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import asyncio import base64 import functools import gzip import os import re import ssl import sys import warnings from typing import Optional, TypedDict, Union from .._compat import warn_stacklevel from .._exceptions import ConnectionError, ConnectionTimeout, SecurityWarning, TlsError from .._models import ApiResponseMeta, HttpHeaders, NodeConfig from ..client_utils import DEFAULT, DefaultType, client_meta_version from ._base import ( BUILTIN_EXCEPTIONS, DEFAULT_CA_CERTS, RERAISE_EXCEPTIONS, NodeApiResponse, ssl_context_from_node_config, ) from ._base_async import BaseAsyncNode try: import aiohttp import aiohttp.client_exceptions as aiohttp_exceptions _AIOHTTP_AVAILABLE = True _AIOHTTP_META_VERSION = client_meta_version(aiohttp.__version__) _version_parts = [] for _version_part in aiohttp.__version__.split(".")[:3]: try: _version_parts.append(int(re.search(r"^([0-9]+)", _version_part).group(1))) # type: ignore[union-attr] except (AttributeError, ValueError): break _AIOHTTP_SEMVER_VERSION = tuple(_version_parts) # See aio-libs/aiohttp#1769 and #5012 _AIOHTTP_FIXED_HEAD_BUG = _AIOHTTP_SEMVER_VERSION >= (3, 7, 0) class RequestKwarg(TypedDict, total=False): ssl: aiohttp.Fingerprint except ImportError: # pragma: nocover _AIOHTTP_AVAILABLE = False _AIOHTTP_META_VERSION = "" _AIOHTTP_FIXED_HEAD_BUG = False # Avoid aiohttp enabled_cleanup_closed warning: https://github.com/aio-libs/aiohttp/pull/9726 _NEEDS_CLEANUP_CLOSED_313 = (3, 13, 0) <= sys.version_info < (3, 13, 1) _NEEDS_CLEANUP_CLOSED = _NEEDS_CLEANUP_CLOSED_313 or sys.version_info < (3, 12, 7) class AiohttpHttpNode(BaseAsyncNode): """Default asynchronous node class using the ``aiohttp`` library via HTTP""" _CLIENT_META_HTTP_CLIENT = ("ai", _AIOHTTP_META_VERSION) def __init__(self, config: NodeConfig): if not _AIOHTTP_AVAILABLE: # pragma: nocover raise ValueError("You must have 'aiohttp' installed to use AiohttpHttpNode") super().__init__(config) self._ssl_assert_fingerprint = config.ssl_assert_fingerprint ssl_context: Optional[ssl.SSLContext] = None if config.scheme == "https": if config.ssl_context is not None: ssl_context = ssl_context_from_node_config(config) else: ssl_context = ssl_context_from_node_config(config) ca_certs = ( DEFAULT_CA_CERTS if config.ca_certs is None else config.ca_certs ) if config.verify_certs: if not ca_certs: raise ValueError( "Root certificates are missing for certificate " "validation. Either pass them in using the ca_certs parameter or " "install certifi to use it automatically." ) else: if config.ssl_show_warn: warnings.warn( f"Connecting to {self.base_url!r} using TLS with verify_certs=False is insecure", stacklevel=warn_stacklevel(), category=SecurityWarning, ) if ca_certs is not None: if os.path.isfile(ca_certs): ssl_context.load_verify_locations(cafile=ca_certs) elif os.path.isdir(ca_certs): ssl_context.load_verify_locations(capath=ca_certs) else: raise ValueError("ca_certs parameter is not a path") # Use client_cert and client_key variables for SSL certificate configuration. if config.client_cert and not os.path.isfile(config.client_cert): raise ValueError("client_cert is not a path to a file") if config.client_key and not os.path.isfile(config.client_key): raise ValueError("client_key is not a path to a file") if config.client_cert and config.client_key: ssl_context.load_cert_chain(config.client_cert, config.client_key) elif config.client_cert: ssl_context.load_cert_chain(config.client_cert) self._loop: asyncio.AbstractEventLoop = None # type: ignore[assignment] self.session: Optional[aiohttp.ClientSession] = None # Parameters for creating an aiohttp.ClientSession later. self._connections_per_node = config.connections_per_node self._ssl_context = ssl_context async def perform_request( # type: ignore[override] self, method: str, target: str, body: Optional[bytes] = None, headers: Optional[HttpHeaders] = None, request_timeout: Union[DefaultType, Optional[float]] = DEFAULT, ) -> NodeApiResponse: global _AIOHTTP_FIXED_HEAD_BUG if self.session is None: self._create_aiohttp_session() assert self.session is not None url = self.base_url + target is_head = False # There is a bug in aiohttp<3.7 that disables the re-use # of the connection in the pool when method=HEAD. # See: aio-libs/aiohttp#1769 if method == "HEAD" and not _AIOHTTP_FIXED_HEAD_BUG: method = "GET" is_head = True # total=0 means no timeout for aiohttp resolved_timeout: Optional[float] = ( self.config.request_timeout if request_timeout is DEFAULT else request_timeout ) aiohttp_timeout = aiohttp.ClientTimeout( total=resolved_timeout if resolved_timeout is not None else 0 ) request_headers = self._headers.copy() if headers: request_headers.update(headers) body_to_send: Optional[bytes] if body: if self._http_compress: body_to_send = gzip.compress(body) request_headers["content-encoding"] = "gzip" else: body_to_send = body else: body_to_send = None kwargs: RequestKwarg = {} if self._ssl_assert_fingerprint: kwargs["ssl"] = aiohttp_fingerprint(self._ssl_assert_fingerprint) try: start = self._loop.time() async with self.session.request( method, url, data=body_to_send, headers=request_headers, timeout=aiohttp_timeout, **kwargs, ) as response: if is_head: # We actually called 'GET' so throw away the data. await response.release() raw_data = b"" else: raw_data = await response.read() duration = self._loop.time() - start # We want to reraise a cancellation or recursion error. except RERAISE_EXCEPTIONS: raise except Exception as e: err: Exception if isinstance( e, (asyncio.TimeoutError, aiohttp_exceptions.ServerTimeoutError) ): err = ConnectionTimeout( "Connection timed out during request", errors=(e,) ) elif isinstance(e, (ssl.SSLError, aiohttp_exceptions.ClientSSLError)): err = TlsError(str(e), errors=(e,)) elif isinstance(e, BUILTIN_EXCEPTIONS): raise else: err = ConnectionError(str(e), errors=(e,)) self._log_request( method="HEAD" if is_head else method, target=target, headers=request_headers, body=body, exception=err, ) raise err from None meta = ApiResponseMeta( node=self.config, duration=duration, http_version="1.1", status=response.status, headers=HttpHeaders(response.headers), ) self._log_request( method="HEAD" if is_head else method, target=target, headers=request_headers, body=body, meta=meta, response=raw_data, ) return NodeApiResponse( meta, raw_data, ) async def close(self) -> None: # type: ignore[override] if self.session: await self.session.close() self.session = None def _create_aiohttp_session(self) -> None: """Creates an aiohttp.ClientSession(). This is delayed until the first call to perform_request() so that AsyncTransport has a chance to set AiohttpHttpNode.loop """ if self._loop is None: self._loop = asyncio.get_running_loop() self.session = aiohttp.ClientSession( headers=self.headers, skip_auto_headers=("accept", "accept-encoding", "user-agent"), auto_decompress=True, loop=self._loop, cookie_jar=aiohttp.DummyCookieJar(), connector=aiohttp.TCPConnector( limit_per_host=self._connections_per_node, use_dns_cache=True, enable_cleanup_closed=_NEEDS_CLEANUP_CLOSED, ssl=self._ssl_context or False, ), ) @functools.lru_cache(maxsize=64, typed=True) def aiohttp_fingerprint(ssl_assert_fingerprint: str) -> "aiohttp.Fingerprint": """Changes 'ssl_assert_fingerprint' into a configured 'aiohttp.Fingerprint' instance. Uses a cache to prevent creating tons of objects needlessly. """ return aiohttp.Fingerprint( base64.b16decode(ssl_assert_fingerprint.replace(":", ""), casefold=True) )