# elasticsearch/_async/client/_base.py
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
import warnings
from typing import (
Any,
Callable,
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
)
from elastic_transport import (
ApiResponse,
AsyncTransport,
BinaryApiResponse,
HeadApiResponse,
HttpHeaders,
ListApiResponse,
NodeConfig,
ObjectApiResponse,
OpenTelemetrySpan,
SniffOptions,
TextApiResponse,
)
from elastic_transport.client_utils import DEFAULT, DefaultType
from ..._otel import OpenTelemetry
from ..._version import __versionstr__
from ...compat import warn_stacklevel
from ...exceptions import (
HTTP_EXCEPTIONS,
ApiError,
ConnectionError,
ElasticsearchWarning,
SerializationError,
UnsupportedProductError,
)
from .utils import _TYPE_ASYNC_SNIFF_CALLBACK, _base64_auth_header, _quote_query
# Pulls every double-quoted message out of a 'Warning' response header.
_WARNING_RE = re.compile(r"\"([^\"]*)\"")

# Compatibility media type. The '%s' slot receives the media subtype
# (e.g. 'json'); 'compatible-with' is the client's major version number.
_COMPAT_MIMETYPE_TEMPLATE = (
    f"application/vnd.elasticsearch+%s; compatible-with={__versionstr__.partition('.')[0]}"
)

# Plain media types that are rewritten into their compatibility form.
_COMPAT_MIMETYPE_RE = re.compile(r"application/(json|x-ndjson|vnd\.mapbox-vector-tile)")

# Replacement string for _COMPAT_MIMETYPE_RE.sub(); '\g<1>' re-inserts the
# captured subtype into the compatibility template.
_COMPAT_MIMETYPE_SUB = _COMPAT_MIMETYPE_TEMPLATE % r"\g<1>"
def resolve_auth_headers(
    headers: Optional[Mapping[str, str]],
    http_auth: Union[DefaultType, None, Tuple[str, str], str] = DEFAULT,
    api_key: Union[DefaultType, None, Tuple[str, str], str] = DEFAULT,
    basic_auth: Union[DefaultType, None, Tuple[str, str], str] = DEFAULT,
    bearer_auth: Union[DefaultType, None, str] = DEFAULT,
) -> HttpHeaders:
    """Return a new ``HttpHeaders`` with the requested authorization applied.

    The result is always a fresh object, so the caller's ``headers`` mapping
    is never modified, even when it is already an ``HttpHeaders`` instance.

    :param headers: Base headers to copy, or ``None`` to start empty.
    :param http_auth: Deprecated alias for ``basic_auth``. Emits a
        ``DeprecationWarning`` when given.
    :param api_key: Encoded with ``_base64_auth_header`` and sent as
        ``ApiKey <value>``.
    :param basic_auth: Encoded with ``_base64_auth_header`` and sent as
        ``Basic <value>``.
    :param bearer_auth: Sent verbatim as ``Bearer <value>``.
    :returns: The resolved headers.
    :raises ValueError: If both ``http_auth`` and ``basic_auth`` are given, if
        more than one of ``api_key``/``basic_auth``/``bearer_auth`` is given,
        or if ``headers`` already has an ``authorization`` entry while one of
        those options is set.
    :raises TypeError: If ``http_auth`` is neither a string nor a list/tuple
        of strings.
    """
    # Always copy: the 'authorization' entry is written below and must not
    # leak into a caller-owned HttpHeaders object.
    if headers is None:
        headers = HttpHeaders()
    elif isinstance(headers, HttpHeaders):
        headers = headers.copy()
    else:
        headers = HttpHeaders(headers)

    resolved_http_auth = http_auth if http_auth is not DEFAULT else None
    resolved_basic_auth = basic_auth if basic_auth is not DEFAULT else None
    if resolved_http_auth is not None:
        if resolved_basic_auth is not None:
            raise ValueError(
                "Can't specify both 'http_auth' and 'basic_auth', "
                "instead only specify 'basic_auth'"
            )
        if isinstance(resolved_http_auth, str) or (
            isinstance(resolved_http_auth, (list, tuple))
            and all(isinstance(x, str) for x in resolved_http_auth)
        ):
            # 'http_auth' is treated exactly like 'basic_auth' from here on.
            resolved_basic_auth = resolved_http_auth
        else:
            raise TypeError(
                "The deprecated 'http_auth' parameter must be either 'Tuple[str, str]' or 'str'. "
                "Use either the 'basic_auth' parameter instead"
            )

        warnings.warn(
            "The 'http_auth' parameter is deprecated. "
            "Use 'basic_auth' or 'bearer_auth' parameters instead",
            category=DeprecationWarning,
            stacklevel=warn_stacklevel(),
        )

    resolved_api_key = api_key if api_key is not DEFAULT else None
    resolved_bearer_auth = bearer_auth if bearer_auth is not DEFAULT else None
    if resolved_api_key or resolved_basic_auth or resolved_bearer_auth:
        provided = (resolved_api_key, resolved_basic_auth, resolved_bearer_auth)
        if sum(x is not None for x in provided) > 1:
            raise ValueError(
                "Can only set one of 'api_key', 'basic_auth', and 'bearer_auth'"
            )
        if headers and headers.get("authorization", None) is not None:
            raise ValueError(
                "Can't set 'Authorization' HTTP header with other authentication options"
            )
        if resolved_api_key:
            headers["authorization"] = f"ApiKey {_base64_auth_header(resolved_api_key)}"
        if resolved_basic_auth:
            headers["authorization"] = (
                f"Basic {_base64_auth_header(resolved_basic_auth)}"
            )
        if resolved_bearer_auth:
            headers["authorization"] = f"Bearer {resolved_bearer_auth}"

    return headers
def create_sniff_callback(
    host_info_callback: Optional[
        Callable[[Dict[str, Any], Dict[str, Any]], Optional[Dict[str, Any]]]
    ] = None,
    sniffed_node_callback: Optional[
        Callable[[Dict[str, Any], NodeConfig], Optional[NodeConfig]]
    ] = None,
) -> _TYPE_ASYNC_SNIFF_CALLBACK:
    """Build an async sniff callback that discovers nodes via ``/_nodes/_all/http``.

    The returned coroutine function sends ``GET /_nodes/_all/http`` through the
    transport. It makes at most one attempt per node in the transport's node
    pool. An attempt is skipped on ``ConnectionError``/``SerializationError`` or
    a non-2XX status.

    From the first successful response that yields at least one node, it
    returns one ``NodeConfig`` per node that has a ``host:port`` style
    ``http.publish_address``. Each config is the answering node's config with
    ``host``/``port`` replaced, then filtered through ``sniffed_node_callback``.
    If no attempt yields any node, ``[]`` is returned.

    Exactly one of the two callbacks must be given (checked with ``assert``).

    :param host_info_callback: Deprecated. Called with the node info and a
        ``{"host": ..., "port": ...}`` dict. Returning ``None`` drops the node;
        any other return value is ignored.
    :param sniffed_node_callback: Called with the node info and a candidate
        ``NodeConfig``. Returns the config to use, or ``None`` to drop the node.
    """
    assert (host_info_callback is None) != (sniffed_node_callback is None)

    # Wrap the deprecated 'host_info_callback' into 'sniffed_node_callback'
    if host_info_callback is not None:

        def _sniffed_node_callback(
            node_info: Dict[str, Any], node_config: NodeConfig
        ) -> Optional[NodeConfig]:
            assert host_info_callback is not None
            host_info: Dict[str, Any] = {
                "host": node_config.host,
                "port": node_config.port,
            }
            if host_info_callback(node_info, host_info) is None:
                return None
            return node_config

        sniffed_node_callback = _sniffed_node_callback

    # Derive the compatibility Accept header from the client's major version
    # (same template as request headers) instead of hard-coding it.
    sniff_accept = _COMPAT_MIMETYPE_TEMPLATE % ("json",)

    async def sniff_callback(
        transport: AsyncTransport, sniff_options: SniffOptions
    ) -> List[NodeConfig]:
        # The number of attempts is bounded by the size of the node pool.
        for _ in transport.node_pool.all():
            try:
                meta, node_infos = await transport.perform_request(
                    "GET",
                    "/_nodes/_all/http",
                    headers={"accept": sniff_accept},
                    # The initial sniff is sent without a request timeout.
                    request_timeout=(
                        sniff_options.sniff_timeout
                        if not sniff_options.is_initial_sniff
                        else None
                    ),
                )
            except (SerializationError, ConnectionError):
                continue

            if not 200 <= meta.status <= 299:
                continue

            node_configs: List[NodeConfig] = []
            for node_info in node_infos.get("nodes", {}).values():
                address = node_info.get("http", {}).get("publish_address")
                if not address or ":" not in address:
                    continue

                if "/" in address:
                    # Support 7.x host/ip:port behavior where http.publish_host has been set.
                    host, ip_and_port = address.split("/", 1)
                    _, port_str = ip_and_port.rsplit(":", 1)
                else:
                    host, port_str = address.rsplit(":", 1)
                port = int(port_str)

                assert sniffed_node_callback is not None
                # Use the node which was able to make the request as a base.
                sniffed_node = sniffed_node_callback(
                    node_info, meta.node.replace(host=host, port=port)
                )
                if sniffed_node is None:
                    continue
                node_configs.append(sniffed_node)

            if node_configs:
                return node_configs

        return []

    return sniff_callback
def _default_sniffed_node_callback(
    node_info: Dict[str, Any], node_config: NodeConfig
) -> Optional[NodeConfig]:
    """Keep a sniffed node unless its only role is ``master``.

    :param node_info: The node's entry from the sniff response.
    :param node_config: Candidate configuration for that node.
    :returns: ``node_config`` unchanged, or ``None`` when the node's ``roles``
        list is exactly ``["master"]`` (a missing ``roles`` key counts as
        ``[]`` and is kept).
    """
    roles = node_info.get("roles", [])
    master_only = roles == ["master"]
    return None if master_only else node_config
# Module-level sniff callback built around _default_sniffed_node_callback,
# i.e. it drops sniffed nodes whose roles list is exactly ["master"].
default_sniff_callback = create_sniff_callback(
    host_info_callback=None,
    sniffed_node_callback=_default_sniffed_node_callback,
)
class BaseClient:
    """Shared request pipeline for the async Elasticsearch clients.

    Holds the transport and per-client request options, sends requests
    through the transport, and turns the result into a typed ``ApiResponse``
    or raises the matching ``ApiError`` subclass.
    """

    def __init__(self, _transport: AsyncTransport) -> None:
        self._transport = _transport
        # Per-client request options; DEFAULT means "not set on this client".
        self._client_meta: Union[DefaultType, Tuple[Tuple[str, str], ...]] = DEFAULT
        self._headers = HttpHeaders()
        self._request_timeout: Union[DefaultType, Optional[float]] = DEFAULT
        # Error statuses that are returned instead of raised (checked locally).
        self._ignore_status: Union[DefaultType, Collection[int]] = DEFAULT
        self._max_retries: Union[DefaultType, int] = DEFAULT
        self._retry_on_timeout: Union[DefaultType, bool] = DEFAULT
        self._retry_on_status: Union[DefaultType, Collection[int]] = DEFAULT
        # Becomes True once a response carries 'X-Elastic-Product: Elasticsearch'.
        self._verified_elasticsearch = False
        self._otel = OpenTelemetry()

    @property
    def transport(self) -> AsyncTransport:
        """The transport used to send requests."""
        return self._transport

    async def perform_request(
        self,
        method: str,
        path: str,
        *,
        params: Optional[Mapping[str, Any]] = None,
        headers: Optional[Mapping[str, str]] = None,
        body: Optional[Any] = None,
        endpoint_id: Optional[str] = None,
        path_parts: Optional[Mapping[str, Any]] = None,
    ) -> ApiResponse[Any]:
        """Send a request inside an OpenTelemetry span and return the response.

        On success, the response headers are handed to the span's
        ``set_elastic_cloud_metadata()``. Errors propagate from
        ``_perform_request()``.
        """
        with self._otel.span(
            method,
            endpoint_id=endpoint_id,
            path_parts=path_parts or {},
        ) as otel_span:
            response = await self._perform_request(
                method,
                path,
                params=params,
                headers=headers,
                body=body,
                otel_span=otel_span,
            )
            otel_span.set_elastic_cloud_metadata(response.meta.headers)
            return response

    async def _perform_request(
        self,
        method: str,
        path: str,
        *,
        params: Optional[Mapping[str, Any]] = None,
        headers: Optional[Mapping[str, str]] = None,
        body: Optional[Any] = None,
        otel_span: OpenTelemetrySpan,
    ) -> ApiResponse[Any]:
        """Send one request through the transport and wrap the result.

        :raises ApiError: The ``HTTP_EXCEPTIONS`` subclass for the status (or
            ``ApiError`` itself) when the status is outside 2XX, is not in
            ``_ignore_status``, and is not a ``HEAD`` 404.
        :raises UnsupportedProductError: On an unverified 2XX response without
            ``X-Elastic-Product: Elasticsearch``.
        :returns: A ``HeadApiResponse`` for ``HEAD``. For other methods, an
            ``ApiResponse`` subclass chosen by the body's Python type.
        """
        # Always work on a copy so per-request headers and the compat
        # mimetype rewrite below never modify the client-wide self._headers.
        request_headers = self._headers.copy()
        if headers:
            request_headers.update(headers)

        def mimetype_header_to_compat(header: str) -> None:
            # Rewrites application/json, application/x-ndjson and
            # application/vnd.mapbox-vector-tile into the
            # application/vnd.elasticsearch+X compatibility form.
            mimetype = request_headers.get(header, None)
            if mimetype:
                request_headers[header] = _COMPAT_MIMETYPE_RE.sub(
                    _COMPAT_MIMETYPE_SUB, mimetype
                )

        mimetype_header_to_compat("Accept")
        mimetype_header_to_compat("Content-Type")

        target = f"{path}?{_quote_query(params)}" if params else path

        meta, resp_body = await self.transport.perform_request(
            method,
            target,
            headers=request_headers,
            body=body,
            request_timeout=self._request_timeout,
            max_retries=self._max_retries,
            retry_on_status=self._retry_on_status,
            retry_on_timeout=self._retry_on_timeout,
            client_meta=self._client_meta,
            otel_span=otel_span,
        )

        # 2XX is [200, 300). The previous '< 299' wrongly treated 299 as an error.
        is_success = 200 <= meta.status < 300

        # HEAD with a 404 is returned as a normal response
        # since this is used as an 'exists' functionality.
        if not (method == "HEAD" and meta.status == 404) and (
            not is_success
            and (
                self._ignore_status is DEFAULT
                or self._ignore_status is None
                or meta.status not in self._ignore_status
            )
        ):
            message = str(resp_body)

            # If the response is an error response try parsing
            # the raw Elasticsearch error before raising.
            if isinstance(resp_body, dict):
                try:
                    error = resp_body.get("error", message)
                    if isinstance(error, dict) and "type" in error:
                        error = error["type"]
                    message = error
                except (ValueError, KeyError, TypeError):
                    pass

            raise HTTP_EXCEPTIONS.get(meta.status, ApiError)(
                message=message, meta=meta, body=resp_body
            )

        # 'X-Elastic-Product: Elasticsearch' should be on every 2XX response.
        if not self._verified_elasticsearch:
            # If the header is set we mark the server as verified.
            if meta.headers.get("x-elastic-product", "") == "Elasticsearch":
                self._verified_elasticsearch = True
            # Otherwise we only raise an error on 2XX responses.
            elif is_success:
                raise UnsupportedProductError(
                    message=(
                        "The client noticed that the server is not Elasticsearch "
                        "and we do not support this unknown product"
                    ),
                    meta=meta,
                    body=resp_body,
                )

        # 'Warning' headers should be reraised as 'ElasticsearchWarning'
        if "warning" in meta.headers:
            warning_header = (meta.headers.get("warning") or "").strip()
            # Fall back to the whole header when it has no quoted messages.
            warning_messages: Iterable[str] = _WARNING_RE.findall(
                warning_header
            ) or (warning_header,)
            stacklevel = warn_stacklevel()
            for warning_message in warning_messages:
                warnings.warn(
                    warning_message,
                    category=ElasticsearchWarning,
                    stacklevel=stacklevel,
                )

        if method == "HEAD":
            response = HeadApiResponse(meta=meta)
        elif isinstance(resp_body, dict):
            response = ObjectApiResponse(body=resp_body, meta=meta)  # type: ignore[assignment]
        elif isinstance(resp_body, list):
            response = ListApiResponse(body=resp_body, meta=meta)  # type: ignore[assignment]
        elif isinstance(resp_body, str):
            response = TextApiResponse(  # type: ignore[assignment]
                body=resp_body,
                meta=meta,
            )
        elif isinstance(resp_body, bytes):
            response = BinaryApiResponse(body=resp_body, meta=meta)  # type: ignore[assignment]
        else:
            response = ApiResponse(body=resp_body, meta=meta)  # type: ignore[assignment]

        return response
class NamespacedClient(BaseClient):
    """Client for one API namespace that delegates requests to a parent client.

    It is built on the parent's transport, and ``perform_request()`` forwards
    to the parent's ``perform_request()`` so the parent's options are used.
    """

    def __init__(self, client: "BaseClient") -> None:
        self._client = client
        super().__init__(client.transport)

    async def perform_request(
        self,
        method: str,
        path: str,
        *,
        params: Optional[Mapping[str, Any]] = None,
        headers: Optional[Mapping[str, str]] = None,
        body: Optional[Any] = None,
        endpoint_id: Optional[str] = None,
        path_parts: Optional[Mapping[str, Any]] = None,
    ) -> ApiResponse[Any]:
        # Route through the parent client rather than this instance, so the
        # parent's per-client transport options apply to the request.
        parent = self._client
        return await parent.perform_request(
            method,
            path,
            body=body,
            headers=headers,
            params=params,
            path_parts=path_parts,
            endpoint_id=endpoint_id,
        )