elasticsearch/_async/client/_base.py (352 lines of code) (raw):

# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import re import warnings from typing import ( Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, Tuple, Union, ) from elastic_transport import ( ApiResponse, AsyncTransport, BinaryApiResponse, HeadApiResponse, HttpHeaders, ListApiResponse, NodeConfig, ObjectApiResponse, OpenTelemetrySpan, SniffOptions, TextApiResponse, ) from elastic_transport.client_utils import DEFAULT, DefaultType from ..._otel import OpenTelemetry from ..._version import __versionstr__ from ...compat import warn_stacklevel from ...exceptions import ( HTTP_EXCEPTIONS, ApiError, ConnectionError, ElasticsearchWarning, SerializationError, UnsupportedProductError, ) from .utils import _TYPE_ASYNC_SNIFF_CALLBACK, _base64_auth_header, _quote_query _WARNING_RE = re.compile(r"\"([^\"]*)\"") _COMPAT_MIMETYPE_TEMPLATE = "application/vnd.elasticsearch+%s; compatible-with=" + str( __versionstr__.partition(".")[0] ) _COMPAT_MIMETYPE_RE = re.compile(r"application/(json|x-ndjson|vnd\.mapbox-vector-tile)") _COMPAT_MIMETYPE_SUB = _COMPAT_MIMETYPE_TEMPLATE % (r"\g<1>",) def resolve_auth_headers( headers: Optional[Mapping[str, str]], http_auth: Union[DefaultType, None, Tuple[str, str], str] = DEFAULT, api_key: Union[DefaultType, None, Tuple[str, str], str] = DEFAULT, basic_auth: Union[DefaultType, None, Tuple[str, str], str] = DEFAULT, bearer_auth: Union[DefaultType, None, str] = DEFAULT, ) -> HttpHeaders: if headers is None: headers = HttpHeaders() elif not isinstance(headers, HttpHeaders): headers = HttpHeaders(headers) resolved_http_auth = http_auth if http_auth is not DEFAULT else None resolved_basic_auth = basic_auth if basic_auth is not DEFAULT else None if resolved_http_auth is not None: if resolved_basic_auth is not None: raise ValueError( "Can't specify both 'http_auth' and 'basic_auth', " "instead only specify 'basic_auth'" ) if isinstance(http_auth, str) or ( isinstance(resolved_http_auth, (list, tuple)) and all(isinstance(x, str) for x in resolved_http_auth) ): resolved_basic_auth = resolved_http_auth else: raise TypeError( "The deprecated 'http_auth' parameter must be either 'Tuple[str, str]' or 'str'. " "Use either the 'basic_auth' parameter instead" ) warnings.warn( "The 'http_auth' parameter is deprecated. " "Use 'basic_auth' or 'bearer_auth' parameters instead", category=DeprecationWarning, stacklevel=warn_stacklevel(), ) resolved_api_key = api_key if api_key is not DEFAULT else None resolved_bearer_auth = bearer_auth if bearer_auth is not DEFAULT else None if resolved_api_key or resolved_basic_auth or resolved_bearer_auth: if ( sum( x is not None for x in ( resolved_api_key, resolved_basic_auth, resolved_bearer_auth, ) ) > 1 ): raise ValueError( "Can only set one of 'api_key', 'basic_auth', and 'bearer_auth'" ) if headers and headers.get("authorization", None) is not None: raise ValueError( "Can't set 'Authorization' HTTP header with other authentication options" ) if resolved_api_key: headers["authorization"] = f"ApiKey {_base64_auth_header(resolved_api_key)}" if resolved_basic_auth: headers["authorization"] = ( f"Basic {_base64_auth_header(resolved_basic_auth)}" ) if resolved_bearer_auth: headers["authorization"] = f"Bearer {resolved_bearer_auth}" return headers def create_sniff_callback( host_info_callback: Optional[ Callable[[Dict[str, Any], Dict[str, Any]], Optional[Dict[str, Any]]] ] = None, sniffed_node_callback: Optional[ Callable[[Dict[str, Any], NodeConfig], Optional[NodeConfig]] ] = None, ) -> _TYPE_ASYNC_SNIFF_CALLBACK: assert (host_info_callback is None) != (sniffed_node_callback is None) # Wrap the deprecated 'host_info_callback' into 'sniffed_node_callback' if host_info_callback is not None: def _sniffed_node_callback( node_info: Dict[str, Any], node_config: NodeConfig ) -> Optional[NodeConfig]: assert host_info_callback is not None if ( host_info_callback( # type ignore[misc] node_info, {"host": node_config.host, "port": node_config.port} ) is None ): return None return node_config sniffed_node_callback = _sniffed_node_callback async def sniff_callback( transport: AsyncTransport, sniff_options: SniffOptions ) -> List[NodeConfig]: for _ in transport.node_pool.all(): try: meta, node_infos = await transport.perform_request( "GET", "/_nodes/_all/http", headers={ "accept": "application/vnd.elasticsearch+json; compatible-with=9" }, request_timeout=( sniff_options.sniff_timeout if not sniff_options.is_initial_sniff else None ), ) except (SerializationError, ConnectionError): continue if not 200 <= meta.status <= 299: continue node_configs = [] for node_info in node_infos.get("nodes", {}).values(): address = node_info.get("http", {}).get("publish_address") if not address or ":" not in address: continue if "/" in address: # Support 7.x host/ip:port behavior where http.publish_host has been set. fqdn, ipaddress = address.split("/", 1) host = fqdn _, port_str = ipaddress.rsplit(":", 1) port = int(port_str) else: host, port_str = address.rsplit(":", 1) port = int(port_str) assert sniffed_node_callback is not None sniffed_node = sniffed_node_callback( node_info, meta.node.replace(host=host, port=port) ) if sniffed_node is None: continue # Use the node which was able to make the request as a base. node_configs.append(sniffed_node) if node_configs: return node_configs return [] return sniff_callback def _default_sniffed_node_callback( node_info: Dict[str, Any], node_config: NodeConfig ) -> Optional[NodeConfig]: if node_info.get("roles", []) == ["master"]: return None return node_config default_sniff_callback = create_sniff_callback( sniffed_node_callback=_default_sniffed_node_callback ) class BaseClient: def __init__(self, _transport: AsyncTransport) -> None: self._transport = _transport self._client_meta: Union[DefaultType, Tuple[Tuple[str, str], ...]] = DEFAULT self._headers = HttpHeaders() self._request_timeout: Union[DefaultType, Optional[float]] = DEFAULT self._ignore_status: Union[DefaultType, Collection[int]] = DEFAULT self._max_retries: Union[DefaultType, int] = DEFAULT self._retry_on_timeout: Union[DefaultType, bool] = DEFAULT self._retry_on_status: Union[DefaultType, Collection[int]] = DEFAULT self._verified_elasticsearch = False self._otel = OpenTelemetry() @property def transport(self) -> AsyncTransport: return self._transport async def perform_request( self, method: str, path: str, *, params: Optional[Mapping[str, Any]] = None, headers: Optional[Mapping[str, str]] = None, body: Optional[Any] = None, endpoint_id: Optional[str] = None, path_parts: Optional[Mapping[str, Any]] = None, ) -> ApiResponse[Any]: with self._otel.span( method, endpoint_id=endpoint_id, path_parts=path_parts or {}, ) as otel_span: response = await self._perform_request( method, path, params=params, headers=headers, body=body, otel_span=otel_span, ) otel_span.set_elastic_cloud_metadata(response.meta.headers) return response async def _perform_request( self, method: str, path: str, *, params: Optional[Mapping[str, Any]] = None, headers: Optional[Mapping[str, str]] = None, body: Optional[Any] = None, otel_span: OpenTelemetrySpan, ) -> ApiResponse[Any]: if headers: request_headers = self._headers.copy() request_headers.update(headers) else: request_headers = self._headers def mimetype_header_to_compat(header: str) -> None: # Converts all parts of a Accept/Content-Type headers # from application/X -> application/vnd.elasticsearch+X mimetype = request_headers.get(header, None) if mimetype: request_headers[header] = _COMPAT_MIMETYPE_RE.sub( _COMPAT_MIMETYPE_SUB, mimetype ) mimetype_header_to_compat("Accept") mimetype_header_to_compat("Content-Type") if params: target = f"{path}?{_quote_query(params)}" else: target = path meta, resp_body = await self.transport.perform_request( method, target, headers=request_headers, body=body, request_timeout=self._request_timeout, max_retries=self._max_retries, retry_on_status=self._retry_on_status, retry_on_timeout=self._retry_on_timeout, client_meta=self._client_meta, otel_span=otel_span, ) # HEAD with a 404 is returned as a normal response # since this is used as an 'exists' functionality. if not (method == "HEAD" and meta.status == 404) and ( not 200 <= meta.status < 299 and ( self._ignore_status is DEFAULT or self._ignore_status is None or meta.status not in self._ignore_status ) ): message = str(resp_body) # If the response is an error response try parsing # the raw Elasticsearch error before raising. if isinstance(resp_body, dict): try: error = resp_body.get("error", message) if isinstance(error, dict) and "type" in error: error = error["type"] message = error except (ValueError, KeyError, TypeError): pass raise HTTP_EXCEPTIONS.get(meta.status, ApiError)( message=message, meta=meta, body=resp_body ) # 'X-Elastic-Product: Elasticsearch' should be on every 2XX response. if not self._verified_elasticsearch: # If the header is set we mark the server as verified. if meta.headers.get("x-elastic-product", "") == "Elasticsearch": self._verified_elasticsearch = True # Otherwise we only raise an error on 2XX responses. elif meta.status >= 200 and meta.status < 300: raise UnsupportedProductError( message=( "The client noticed that the server is not Elasticsearch " "and we do not support this unknown product" ), meta=meta, body=resp_body, ) # 'Warning' headers should be reraised as 'ElasticsearchWarning' if "warning" in meta.headers: warning_header = (meta.headers.get("warning") or "").strip() warning_messages: Iterable[str] = _WARNING_RE.findall(warning_header) or ( warning_header, ) stacklevel = warn_stacklevel() for warning_message in warning_messages: warnings.warn( warning_message, category=ElasticsearchWarning, stacklevel=stacklevel, ) if method == "HEAD": response = HeadApiResponse(meta=meta) elif isinstance(resp_body, dict): response = ObjectApiResponse(body=resp_body, meta=meta) # type: ignore[assignment] elif isinstance(resp_body, list): response = ListApiResponse(body=resp_body, meta=meta) # type: ignore[assignment] elif isinstance(resp_body, str): response = TextApiResponse( # type: ignore[assignment] body=resp_body, meta=meta, ) elif isinstance(resp_body, bytes): response = BinaryApiResponse(body=resp_body, meta=meta) # type: ignore[assignment] else: response = ApiResponse(body=resp_body, meta=meta) # type: ignore[assignment] return response class NamespacedClient(BaseClient): def __init__(self, client: "BaseClient") -> None: self._client = client super().__init__(self._client.transport) async def perform_request( self, method: str, path: str, *, params: Optional[Mapping[str, Any]] = None, headers: Optional[Mapping[str, str]] = None, body: Optional[Any] = None, endpoint_id: Optional[str] = None, path_parts: Optional[Mapping[str, Any]] = None, ) -> ApiResponse[Any]: # Use the internal clients .perform_request() implementation # so we take advantage of their transport options. return await self._client.perform_request( method, path, params=params, headers=headers, body=body, endpoint_id=endpoint_id, path_parts=path_parts, )