esrally/client/asynchronous.py (311 lines of code) (raw):
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import asyncio
import json
import logging
import warnings
from collections.abc import Iterable, Mapping
from typing import Any, Optional
import aiohttp
from aiohttp import BaseConnector, RequestInfo
from aiohttp.client_proto import ResponseHandler
from aiohttp.helpers import BaseTimerContext
from elastic_transport import (
AiohttpHttpNode,
ApiResponse,
AsyncTransport,
BinaryApiResponse,
HeadApiResponse,
ListApiResponse,
ObjectApiResponse,
TextApiResponse,
)
from elastic_transport.client_utils import DEFAULT
from elasticsearch import AsyncElasticsearch
from elasticsearch._async.client import IlmClient
from elasticsearch.compat import warn_stacklevel
from elasticsearch.exceptions import HTTP_EXCEPTIONS, ApiError, ElasticsearchWarning
from multidict import CIMultiDict, CIMultiDictProxy
from yarl import URL
from esrally.client.common import _WARNING_RE, _mimetype_header_to_compat, _quote_query
from esrally.client.context import RequestContextHolder
from esrally.utils import io, versions
class StaticTransport:
    """Minimal transport stand-in that only remembers whether it has been closed.

    It never reports itself as closing, and aborting is handled exactly like closing.
    """

    def __init__(self):
        # flipped to True by close() (and therefore also by abort())
        self.closed = False

    def is_closing(self):
        """Always answer ``False``, regardless of the ``closed`` flag."""
        return False

    def close(self):
        """Mark this transport as closed."""
        self.closed = True

    def abort(self):
        """Abort the transport, which simply delegates to :meth:`close`."""
        self.close()
class StaticConnector(BaseConnector):
    """Connector whose connections are response handlers backed by a ``StaticTransport``."""

    async def _create_connection(self, req: "ClientRequest", traces: list["Trace"], timeout: "ClientTimeout") -> ResponseHandler:
        """Create a ``ResponseHandler`` on this connector's loop and attach a ``StaticTransport`` to it.

        None of the arguments (``req``, ``traces``, ``timeout``) are consulted.
        """
        static_handler = ResponseHandler(self._loop)
        static_handler.transport = StaticTransport()
        static_handler.protocol = ""
        return static_handler
class StaticRequest(aiohttp.ClientRequest):
    """Request whose ``send`` creates the response object itself and attaches a static body to it."""

    # Shared lookup object; ``send`` calls ``RESPONSES.response(path)`` on it, so it has to be
    # assigned before the first request is sent.
    RESPONSES = None

    async def send(self, conn: "Connection") -> "ClientResponse":
        """Instantiate the response class and store the static body for the request path on it.

        :param conn: The connection for this request; it is not used here.
        :return: The newly created response (also stored as ``self.response``).
        """
        request_path = self.original_url.path
        self.response = self.response_class(
            self.method,
            self.original_url,
            writer=self._writer,
            continue100=self._continue,
            timer=self._timer,
            request_info=self.request_info,
            traces=self._traces,
            loop=self.loop,
            session=self._session,
        )
        # the body is looked up by path only; method and query string play no role
        self.response.static_body = StaticRequest.RESPONSES.response(request_path)
        return self.response
# we use EmptyStreamReader here because it overrides all methods with
# no-op implementations that we need.
class StaticStreamReader(aiohttp.streams.EmptyStreamReader):
    """Stream reader that serves a fixed text body held in memory."""

    def __init__(self, body):
        super().__init__()
        # text body; encoded on every call to read()
        self.body = body

    async def read(self, n: int = -1) -> bytes:
        """Return the complete body encoded as UTF-8; ``n`` is ignored."""
        encoded_body = self.body.encode("utf-8")
        return encoded_body
class StaticResponse(aiohttp.ClientResponse):
    """Response that serves ``static_body`` with status 200 and an empty header set."""

    def __init__(
        self,
        method: str,
        url: URL,
        *,
        writer: "asyncio.Task[None]",
        continue100: Optional["asyncio.Future[bool]"],
        timer: BaseTimerContext,
        request_info: RequestInfo,
        traces: list["Trace"],
        loop: asyncio.AbstractEventLoop,
        session: "ClientSession",
    ) -> None:
        """Forward all arguments unchanged to the parent class and start without a body."""
        forwarded = {
            "writer": writer,
            "continue100": continue100,
            "timer": timer,
            "request_info": request_info,
            "traces": traces,
            "loop": loop,
            "session": session,
        }
        super().__init__(method, url, **forwarded)
        # set by StaticRequest.send() after construction and consumed in start()
        self.static_body = None

    async def start(self, connection: "Connection") -> "ClientResponse":
        """Populate the response state directly from ``static_body``.

        :param connection: The connection this response is bound to.
        :return: This response, with status 200, no headers and ``content`` serving ``static_body``.
        """
        self._closed = False
        self._protocol = connection.protocol
        self._connection = connection
        # an empty, read-only header set
        self._headers = CIMultiDictProxy(CIMultiDict())
        self.content = StaticStreamReader(self.static_body)
        self.status = 200
        return self
class ResponseMatcher:
    """Looks up pre-serialized JSON bodies by request path using simple wildcard patterns.

    Each entry in ``responses`` needs a ``path`` and a ``body`` key. Supported path patterns:

    * ``*`` matches every path
    * ``*suffix`` matches paths ending with ``suffix``
    * ``prefix*`` matches paths starting with ``prefix``
    * anything else must equal the path exactly

    Entries are checked in the order given and the first match wins.
    """

    def __init__(self, responses):
        self.logger = logging.getLogger(__name__)
        # list of (original pattern, predicate, JSON-encoded body) in definition order
        self.responses = []
        for spec in responses:
            pattern = spec["path"]
            predicate = self._predicate_for(pattern)
            serialized = json.dumps(spec["body"])
            self.responses.append((pattern, predicate, serialized))

    @staticmethod
    def _predicate_for(pattern):
        """Choose the predicate that implements the wildcard semantics of ``pattern``."""
        if pattern == "*":
            return ResponseMatcher.always()
        if pattern.startswith("*"):
            return ResponseMatcher.endswith(pattern[1:])
        if pattern.endswith("*"):
            return ResponseMatcher.startswith(pattern[:-1])
        return ResponseMatcher.equals(pattern)

    @staticmethod
    def always():
        """Return a predicate that accepts every path."""
        return lambda p: True

    @staticmethod
    def startswith(path_pattern):
        """Return a predicate that accepts paths beginning with ``path_pattern``."""
        return lambda p: p.startswith(path_pattern)

    @staticmethod
    def endswith(path_pattern):
        """Return a predicate that accepts paths ending with ``path_pattern``."""
        return lambda p: p.endswith(path_pattern)

    @staticmethod
    def equals(path_pattern):
        """Return a predicate that accepts only paths equal to ``path_pattern``."""
        return lambda p: p == path_pattern

    def response(self, path):
        """Return the JSON body of the first entry matching ``path``, or ``None`` if nothing matches."""
        for path_pattern, matcher, body in self.responses:
            if not matcher(path):
                continue
            self.logger.debug("Path pattern [%s] matches path [%s].", path_pattern, path)
            return body
        return None
class RallyTCPConnector(aiohttp.TCPConnector):
    """TCP connector that moves a resolved host chosen by ``client_id`` to the front of the host list.

    :param client_id: Optional keyword argument (removed from ``kwargs`` before the parent constructor
        runs). It is used as an index, modulo the number of resolved hosts, to pick the preferred host.
        When it is ``None`` the resolved hosts are returned in their original order.
    """

    def __init__(self, *args, **kwargs):
        self.client_id = kwargs.pop("client_id", None)
        self.logger = logging.getLogger(__name__)
        super().__init__(*args, **kwargs)

    async def _resolve_host(self, *args, **kwargs):
        """Resolve hosts via the parent class and move the host selected for this client to index 0.

        :return: The list returned by the parent implementation, with the selected host swapped to the front.
        """
        hosts = await super()._resolve_host(*args, **kwargs)
        self.logger.debug("client id [%s] resolved hosts [{%s}]", self.client_id, hosts)
        # Without a client id there is nothing to base the selection on (the modulo below would raise a
        # TypeError), and an empty result would raise a ZeroDivisionError. Leave the list untouched in both cases.
        if self.client_id is None or not hosts:
            return hosts
        # super()._resolve_host() does actually return all the IPs a given name resolves to, but the underlying
        # super()._create_direct_connection() logic only ever selects the first successful host from this list from which
        # to establish a connection
        #
        # here we use the factory assigned client_id to deterministically return an IP from this list, which we then swap
        # to the beginning of the list to evenly distribute connections across _all_ clients
        # see https://github.com/elastic/rally/issues/1598
        idx = self.client_id % len(hosts)
        host = hosts[idx]
        self.logger.debug("client id [%s] selected host [{%s}]", self.client_id, host)
        # swap order of hosts
        hosts[0], hosts[idx] = hosts[idx], hosts[0]
        return hosts
class RallyAiohttpHttpNode(AiohttpHttpNode):
    """HTTP node that builds its aiohttp session with Rally-specific connectors.

    After construction, callers may set ``client_id``, ``trace_configs``, ``enable_cleanup_closed`` and
    ``static_responses``; all of them are read when the session is created.
    """

    def __init__(self, config):
        super().__init__(config)
        self._loop = None
        self.client_id = None
        self.trace_configs = None
        self.enable_cleanup_closed = False
        self._static_responses = None
        self._request_class = aiohttp.ClientRequest
        self._response_class = aiohttp.ClientResponse

    @property
    def static_responses(self):
        """Path of the static responses file, or ``None`` if static responses are not in use."""
        return self._static_responses

    @static_responses.setter
    def static_responses(self, static_responses):
        """Store the path and, if it is truthy, switch to the static request and response classes.

        The file is parsed only if ``StaticRequest.RESPONSES`` is not yet set; the parsed result is
        stored on the ``StaticRequest`` class and reused afterwards.
        """
        self._static_responses = static_responses
        if self._static_responses:
            # read static responses once and reuse them
            if not StaticRequest.RESPONSES:
                # decode explicitly as UTF-8 instead of relying on the locale's default encoding
                with open(io.normalize_path(self._static_responses), encoding="utf-8") as f:
                    StaticRequest.RESPONSES = ResponseMatcher(json.load(f))
            self._request_class = StaticRequest
            self._response_class = StaticResponse

    def _create_aiohttp_session(self):
        """Create ``self.session``, using a ``StaticConnector`` if static responses are configured
        and a ``RallyTCPConnector`` otherwise.

        Must be called while an event loop is running, because the loop is taken from
        ``asyncio.get_running_loop()`` the first time this runs.
        """
        if self._loop is None:
            self._loop = asyncio.get_running_loop()

        if self._static_responses:
            connector = StaticConnector(limit_per_host=self._connections_per_node, enable_cleanup_closed=self.enable_cleanup_closed)
        else:
            connector = RallyTCPConnector(
                limit_per_host=self._connections_per_node,
                use_dns_cache=True,
                ssl=self._ssl_context,
                enable_cleanup_closed=self.enable_cleanup_closed,
                client_id=self.client_id,
            )

        self.session = aiohttp.ClientSession(
            headers=self.headers,
            auto_decompress=True,
            loop=self._loop,
            cookie_jar=aiohttp.DummyCookieJar(),
            request_class=self._request_class,
            response_class=self._response_class,
            connector=connector,
            trace_configs=self.trace_configs,
        )
class RallyAsyncTransport(AsyncTransport):
    """Async transport that always uses ``RallyAiohttpHttpNode`` as its node class."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent constructor with ``node_class`` fixed to ``RallyAiohttpHttpNode``."""
        super().__init__(*args, node_class=RallyAiohttpHttpNode, **kwargs)
class RallyIlmClient(IlmClient):
    """ILM client that accepts the pre-8.x calling conventions of ``put_lifecycle``."""

    async def put_lifecycle(self, *args, **kwargs):
        """
        The 'elasticsearch-py' 8.x method signature renames the 'policy' param to 'name', and the previously so-called
        'body' param becomes 'policy'

        Accepted old-style forms, all translated to the 8.x keyword arguments:

        * first positional argument: used as 'name'
        * second positional argument: used as 'body' unless a 'body' keyword is given
        * 'policy' keyword holding a string (and no 'name'): used as 'name'
        * 'body' keyword: its 'policy' entry (or an empty dict) becomes 'policy'
        """
        if args:
            kwargs["name"] = args[0]
        elif "name" not in kwargs and isinstance(kwargs.get("policy"), str):
            # old-style keyword call: 'policy' carried the policy name; move it before 'policy' gets
            # replaced by the policy definition below
            kwargs["name"] = kwargs.pop("policy")

        body = kwargs.pop("body", None)
        if body is None and len(args) > 1:
            # old-style positional call: put_lifecycle(<name>, <body>)
            body = args[1]
        if body:
            kwargs["policy"] = body.get("policy", {})
        # pylint: disable=missing-kwoa
        return await IlmClient.put_lifecycle(self, **kwargs)
class RallyAsyncElasticsearch(AsyncElasticsearch, RequestContextHolder):
    """Async Elasticsearch client with Rally's request handling.

    Besides the regular client arguments, the constructor accepts the keyword arguments
    ``distribution_version`` and ``distribution_flavor``, which are stored on the instance and
    carried over to clients created via :meth:`options`.
    """

    def __init__(self, *args, **kwargs):
        distribution_version = kwargs.pop("distribution_version", None)
        distribution_flavor = kwargs.pop("distribution_flavor", None)
        super().__init__(*args, **kwargs)
        # skip verification at this point; we've already verified this earlier with the synchronous client.
        # The async client is used in the hot code path and we use customized overrides (such as that we don't
        # parse response bodies in some cases for performance reasons, e.g. when using the bulk API).
        self._verified_elasticsearch = True
        self.distribution_version = distribution_version
        self.distribution_flavor = distribution_flavor

        # some ILM method signatures changed in 'elasticsearch-py' 8.x,
        # so we override method(s) here to provide BWC for any custom
        # runners that aren't using the new kwargs
        self.ilm = RallyIlmClient(self)

    @property
    def is_serverless(self):
        """Whether ``distribution_flavor`` denotes a serverless distribution."""
        return versions.is_serverless(self.distribution_flavor)

    def options(self, *args, **kwargs):
        """Create a client via the parent ``options()`` and copy the distribution attributes onto it."""
        new_self = super().options(*args, **kwargs)
        new_self.distribution_version = self.distribution_version
        new_self.distribution_flavor = self.distribution_flavor
        return new_self

    async def perform_request(
        self,
        method: str,
        path: str,
        *,
        params: Optional[Mapping[str, Any]] = None,
        headers: Optional[Mapping[str, str]] = None,
        body: Optional[Any] = None,
    ) -> ApiResponse[Any]:
        """Send a request through the transport and wrap the result in an ``ApiResponse``.

        :param method: HTTP method.
        :param path: Request path; ``params``, if given, are appended as the query string.
        :param params: Optional query parameters.
        :param headers: Optional per-request headers. They take precedence over the client's headers
            and are never modified.
        :param body: Optional request body. If present, ``content-type`` and ``accept`` default to
            ``application/json`` unless ``headers`` provides them.
        :return: ``HeadApiResponse`` for HEAD requests, otherwise a response type chosen by the type of the body.
        :raises ApiError: (or the subclass registered for the status in ``HTTP_EXCEPTIONS``) if the status is
            outside of the accepted range and not ignored. A 404 on a HEAD request is not an error.
        """
        # We need to ensure that we provide content-type and accept headers
        # The defaults are collected separately so that the caller's mapping is left untouched.
        default_headers = {}
        if body is not None:
            for header_name in ("content-type", "accept"):
                if headers is None or headers.get(header_name) is None:
                    default_headers[header_name] = "application/json"

        if headers or default_headers:
            request_headers = self._headers.copy()
            if headers:
                request_headers.update(headers)
            # applied last: a default only exists for a header the caller did not provide
            request_headers.update(default_headers)
        else:
            request_headers = self._headers

        # Converts all parts of a Accept/Content-Type headers
        # from application/X -> application/vnd.elasticsearch+X
        # see https://github.com/elastic/elasticsearch/issues/51816
        # Not applicable to serverless
        if not self.is_serverless:
            if versions.is_version_identifier(self.distribution_version) and (
                versions.Version.from_string(self.distribution_version) >= versions.Version.from_string("8.0.0")
            ):
                _mimetype_header_to_compat("Accept", request_headers)
                _mimetype_header_to_compat("Content-Type", request_headers)

        if params:
            target = f"{path}?{_quote_query(params)}"
        else:
            target = path

        meta, resp_body = await self.transport.perform_request(
            method,
            target,
            headers=request_headers,
            body=body,
            request_timeout=self._request_timeout,
            max_retries=self._max_retries,
            retry_on_status=self._retry_on_status,
            retry_on_timeout=self._retry_on_timeout,
            client_meta=self._client_meta,
        )

        # HEAD with a 404 is returned as a normal response
        # since this is used as an 'exists' functionality.
        if not (method == "HEAD" and meta.status == 404) and (
            not 200 <= meta.status < 299
            and (self._ignore_status is DEFAULT or self._ignore_status is None or meta.status not in self._ignore_status)
        ):
            message = str(resp_body)

            # If the response is an error response try parsing
            # the raw Elasticsearch error before raising.
            if isinstance(resp_body, dict):
                try:
                    error = resp_body.get("error", message)
                    if isinstance(error, dict) and "type" in error:
                        error = error["type"]
                    message = error
                except (ValueError, KeyError, TypeError):
                    pass

            raise HTTP_EXCEPTIONS.get(meta.status, ApiError)(message=message, meta=meta, body=resp_body)

        # 'Warning' headers should be reraised as 'ElasticsearchWarning'
        if "warning" in meta.headers:
            warning_header = (meta.headers.get("warning") or "").strip()
            # fall back to the raw header value if the pattern does not find any individual messages
            warning_messages: Iterable[str] = _WARNING_RE.findall(warning_header) or (warning_header,)
            stacklevel = warn_stacklevel()
            for warning_message in warning_messages:
                warnings.warn(
                    warning_message,
                    category=ElasticsearchWarning,
                    stacklevel=stacklevel,
                )

        if method == "HEAD":
            response = HeadApiResponse(meta=meta)
        elif isinstance(resp_body, dict):
            response = ObjectApiResponse(body=resp_body, meta=meta)  # type: ignore[assignment]
        elif isinstance(resp_body, list):
            response = ListApiResponse(body=resp_body, meta=meta)  # type: ignore[assignment]
        elif isinstance(resp_body, str):
            response = TextApiResponse(  # type: ignore[assignment]
                body=resp_body,
                meta=meta,
            )
        elif isinstance(resp_body, bytes):
            response = BinaryApiResponse(body=resp_body, meta=meta)  # type: ignore[assignment]
        else:
            response = ApiResponse(body=resp_body, meta=meta)  # type: ignore[assignment]

        return response