elastic_transport/client_utils.py (164 lines of code) (raw):

# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import base64 import binascii import dataclasses import re import urllib.parse from platform import python_version from typing import Optional, Tuple, TypeVar, Union from urllib.parse import quote as _quote from urllib3.exceptions import LocationParseError from urllib3.util import parse_url from ._models import DEFAULT, DefaultType, NodeConfig from ._utils import fixup_module_metadata from ._version import __version__ __all__ = [ "CloudId", "DEFAULT", "DefaultType", "basic_auth_to_header", "client_meta_version", "create_user_agent", "dataclasses", "parse_cloud_id", "percent_encode", "resolve_default", "to_bytes", "to_str", "url_to_node_config", ] T = TypeVar("T") def resolve_default(val: Union[DefaultType, T], default: T) -> T: """Resolves a value that could be the ``DEFAULT`` sentinel into either the given value or the default value. """ return val if val is not DEFAULT else default def create_user_agent(name: str, version: str) -> str: """Creates the 'User-Agent' header given the library name and version""" return ( f"{name}/{version} (Python/{python_version()}; elastic-transport/{__version__})" ) def client_meta_version(version: str) -> str: """Converts a Python version into a version string compatible with the ``X-Elastic-Client-Meta`` HTTP header. """ match = re.match(r"^([0-9][0-9.]*[0-9]|[0-9])(.*)$", version) if match is None: raise ValueError( "Version {version!r} not formatted like a Python version string" ) version, version_suffix = match.groups() # Don't treat post-releases as pre-releases. if re.search(r"^\.post[0-9]*$", version_suffix): return version if version_suffix: version += "p" return version @dataclasses.dataclass(frozen=True, repr=True) class CloudId: #: Name of the cluster in Elastic Cloud cluster_name: str #: Host and port of the Elasticsearch instance es_address: Optional[Tuple[str, int]] #: Host and port of the Kibana instance kibana_address: Optional[Tuple[str, int]] def parse_cloud_id(cloud_id: str) -> CloudId: """Parses an Elastic Cloud ID into its components""" try: cloud_id = to_str(cloud_id) cluster_name, _, cloud_id = cloud_id.partition(":") parts = to_str(binascii.a2b_base64(to_bytes(cloud_id, "ascii")), "ascii").split( "$" ) parent_dn = parts[0] if not parent_dn: raise ValueError() # Caught and re-raised properly below es_uuid: Optional[str] kibana_uuid: Optional[str] try: es_uuid = parts[1] except IndexError: es_uuid = None try: kibana_uuid = parts[2] or None except IndexError: kibana_uuid = None if ":" in parent_dn: parent_dn, _, parent_port = parent_dn.rpartition(":") port = int(parent_port) else: port = 443 except (ValueError, IndexError, UnicodeError): raise ValueError("Cloud ID is not properly formatted") from None es_host = f"{es_uuid}.{parent_dn}" if es_uuid else None kibana_host = f"{kibana_uuid}.{parent_dn}" if kibana_uuid else None return CloudId( cluster_name=cluster_name, es_address=(es_host, port) if es_host else None, kibana_address=(kibana_host, port) if kibana_host else None, ) def to_str( value: Union[str, bytes], encoding: str = "utf-8", errors: str = "strict" ) -> str: if isinstance(value, bytes): return value.decode(encoding, errors) return value def to_bytes( value: Union[str, bytes], encoding: str = "utf-8", errors: str = "strict" ) -> bytes: if isinstance(value, str): return value.encode(encoding, errors) return value # Python 3.7 added '~' to the safe list for urllib.parse.quote() _QUOTE_ALWAYS_SAFE = frozenset( "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_.-~" ) def percent_encode( string: Union[bytes, str], safe: str = "/", encoding: Optional[str] = None, errors: Optional[str] = None, ) -> str: """Percent-encodes a string so it can be used in an HTTP request target""" # Redefines 'urllib.parse.quote()' to always have the '~' character # within the 'ALWAYS_SAFE' list. The character was added in Python 3.7 safe = "".join(_QUOTE_ALWAYS_SAFE.union(set(safe))) return _quote(string, safe, encoding=encoding, errors=errors) # type: ignore[arg-type] def basic_auth_to_header(basic_auth: Tuple[str, str]) -> str: """Converts a 2-tuple into a 'Basic' HTTP Authorization header""" if ( not isinstance(basic_auth, tuple) or len(basic_auth) != 2 or any(not isinstance(item, (str, bytes)) for item in basic_auth) ): raise ValueError( "'basic_auth' must be a 2-tuple of str/bytes (username, password)" ) return ( f"Basic {base64.b64encode(b':'.join(to_bytes(x) for x in basic_auth)).decode()}" ) def url_to_node_config( url: str, use_default_ports_for_scheme: bool = False ) -> NodeConfig: """Constructs a :class:`elastic_transport.NodeConfig` instance from a URL. If a username/password are specified in the URL they are converted to an 'Authorization' header. Always fills in a default port for HTTPS. :param url: URL to transform into a NodeConfig. :param use_default_ports_for_scheme: If 'True' will resolve default ports for HTTP. """ try: parsed_url = parse_url(url) except LocationParseError: raise ValueError(f"Could not parse URL {url!r}") from None parsed_port: Optional[int] = parsed_url.port if parsed_url.port is None and parsed_url.scheme is not None: # Always fill in a default port for HTTPS if parsed_url.scheme == "https": parsed_port = 443 # Only fill HTTP default port when asked to explicitly elif parsed_url.scheme == "http" and use_default_ports_for_scheme: parsed_port = 80 if any( component in (None, "") for component in (parsed_url.scheme, parsed_url.host, parsed_port) ): raise ValueError( "URL must include a 'scheme', 'host', and 'port' component (ie 'https://localhost:9200')" ) assert parsed_url.scheme is not None assert parsed_url.host is not None assert parsed_port is not None headers = {} if parsed_url.auth: # `urllib3.util.url_parse` ensures `parsed_url` is correctly # percent-encoded but does not percent-decode userinfo, so we have to # do it ourselves to build the basic auth header correctly. encoded_username, _, encoded_password = parsed_url.auth.partition(":") username = urllib.parse.unquote(encoded_username) password = urllib.parse.unquote(encoded_password) headers["authorization"] = basic_auth_to_header((username, password)) host = parsed_url.host.strip("[]") if not parsed_url.path or parsed_url.path == "/": path_prefix = "" else: path_prefix = parsed_url.path return NodeConfig( scheme=parsed_url.scheme, host=host, port=parsed_port, path_prefix=path_prefix, headers=headers, ) fixup_module_metadata(__name__, globals()) del fixup_module_metadata