elastic_transport/client_utils.py (164 lines of code) (raw):
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import base64
import binascii
import dataclasses
import re
import urllib.parse
from platform import python_version
from typing import Optional, Tuple, TypeVar, Union
from urllib.parse import quote as _quote
from urllib3.exceptions import LocationParseError
from urllib3.util import parse_url
from ._models import DEFAULT, DefaultType, NodeConfig
from ._utils import fixup_module_metadata
from ._version import __version__
# Names exported by ``from elastic_transport.client_utils import *``.
# ``DEFAULT`` and ``DefaultType`` are imported from ``._models`` above and
# ``dataclasses`` is the standard library module; the remaining names are
# defined in this module.
__all__ = [
    "CloudId",
    "DEFAULT",
    "DefaultType",
    "basic_auth_to_header",
    "client_meta_version",
    "create_user_agent",
    "dataclasses",
    "parse_cloud_id",
    "percent_encode",
    "resolve_default",
    "to_bytes",
    "to_str",
    "url_to_node_config",
]

# Type variable used by ``resolve_default()`` so that the type of the
# 'default' argument is also the declared return type.
T = TypeVar("T")
def resolve_default(val: Union[DefaultType, T], default: T) -> T:
    """Returns ``default`` when ``val`` is the ``DEFAULT`` sentinel,
    otherwise returns ``val`` unchanged.

    :param val: Either the ``DEFAULT`` sentinel or a concrete value.
    :param default: Value to fall back on when ``val`` is ``DEFAULT``.
    """
    # Identity comparison: only the sentinel object itself counts as "unset".
    if val is DEFAULT:
        return default
    return val
def create_user_agent(name: str, version: str) -> str:
    """Builds the value of the 'User-Agent' header for a client library.

    :param name: Name of the client library.
    :param version: Version of the client library.
    """
    client_part = f"{name}/{version}"
    # The parenthesised part reports the running Python version and the
    # version of this package.
    runtime_part = f"Python/{python_version()}; elastic-transport/{__version__}"
    return f"{client_part} ({runtime_part})"
def client_meta_version(version: str) -> str:
    """Converts a Python version into a version string
    compatible with the ``X-Elastic-Client-Meta`` HTTP header.

    The leading numeric release segment (for example ``8.1.0``) is kept
    as-is. Any trailing suffix other than a ``.post`` / ``.postN`` marker
    causes a single ``p`` to be appended to that segment.

    :param version: Version string such as ``"8.1.0"`` or ``"8.1.0a1"``.
    :returns: The numeric release segment, with ``p`` appended when the
        version carries a non-post-release suffix.
    :raises ValueError: If ``version`` does not start with a digit.
    """
    match = re.match(r"^([0-9][0-9.]*[0-9]|[0-9])(.*)$", version)
    if match is None:
        # The message must be an f-string so the offending value is reported.
        raise ValueError(
            f"Version {version!r} not formatted like a Python version string"
        )
    release, suffix = match.groups()

    # Don't treat post-releases as pre-releases.
    if suffix and not re.search(r"^\.post[0-9]*$", suffix):
        release += "p"
    return release
@dataclasses.dataclass(frozen=True)
class CloudId:
    """Immutable container for the components of an Elastic Cloud ID."""

    #: Cluster name component of the Cloud ID
    cluster_name: str

    #: ``(host, port)`` of the Elasticsearch instance, or ``None`` when absent
    es_address: Optional[Tuple[str, int]]

    #: ``(host, port)`` of the Kibana instance, or ``None`` when absent
    kibana_address: Optional[Tuple[str, int]]
def parse_cloud_id(cloud_id: str) -> CloudId:
    """Splits an Elastic Cloud ID into its cluster name and addresses.

    The expected shape is ``<cluster_name>:<base64 payload>`` where the
    decoded payload is ``<parent_dn>[:<port>]$<es_uuid>$<kibana_uuid>``.
    The Elasticsearch and Kibana parts are optional and the port defaults
    to 443 when the parent domain doesn't carry one.

    :param cloud_id: The Cloud ID to parse.
    :raises ValueError: If the Cloud ID is not properly formatted.
    """
    try:
        cluster_name, _, encoded = to_str(cloud_id).partition(":")
        decoded = to_str(binascii.a2b_base64(to_bytes(encoded, "ascii")), "ascii")
        segments = decoded.split("$")

        parent_dn = segments[0]
        if not parent_dn:
            raise ValueError()  # Converted into the common error below

        segment_count = len(segments)
        es_uuid: Optional[str] = segments[1] if segment_count > 1 else None
        kibana_uuid: Optional[str] = (
            (segments[2] or None) if segment_count > 2 else None
        )

        port = 443
        if ":" in parent_dn:
            # An explicit port is everything after the last ':'
            parent_dn, _, raw_port = parent_dn.rpartition(":")
            port = int(raw_port)
    except (ValueError, IndexError, UnicodeError):
        raise ValueError("Cloud ID is not properly formatted") from None

    # An empty or missing UUID means the corresponding address is unknown.
    es_address = (f"{es_uuid}.{parent_dn}", port) if es_uuid else None
    kibana_address = (f"{kibana_uuid}.{parent_dn}", port) if kibana_uuid else None
    return CloudId(
        cluster_name=cluster_name,
        es_address=es_address,
        kibana_address=kibana_address,
    )
def to_str(
    value: Union[str, bytes], encoding: str = "utf-8", errors: str = "strict"
) -> str:
    """Decodes ``value`` when it is ``bytes``; any other value is
    returned unchanged.

    :param value: Text or encoded bytes.
    :param encoding: Codec used to decode ``bytes`` input.
    :param errors: Error handling scheme passed to ``bytes.decode()``.
    """
    return value.decode(encoding, errors) if isinstance(value, bytes) else value
def to_bytes(
    value: Union[str, bytes], encoding: str = "utf-8", errors: str = "strict"
) -> bytes:
    """Encodes ``value`` when it is ``str``; any other value is
    returned unchanged.

    :param value: Text or already-encoded bytes.
    :param encoding: Codec used to encode ``str`` input.
    :param errors: Error handling scheme passed to ``str.encode()``.
    """
    return value.encode(encoding, errors) if isinstance(value, str) else value
# Characters that are never percent-encoded. '~' is listed explicitly
# because urllib.parse.quote() only started treating it as safe in Python 3.7
_QUOTE_ALWAYS_SAFE = frozenset(
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_.-~"
)


def percent_encode(
    string: Union[bytes, str],
    safe: str = "/",
    encoding: Optional[str] = None,
    errors: Optional[str] = None,
) -> str:
    """Percent-encodes a string so it can be used in an HTTP request target

    :param string: Text or bytes to encode.
    :param safe: Extra characters to leave unencoded.
    :param encoding: Passed through to ``urllib.parse.quote()``.
    :param errors: Passed through to ``urllib.parse.quote()``.
    """
    # Same as 'urllib.parse.quote()' except that the always-safe set
    # (including '~') is merged into the caller's 'safe' characters.
    unreserved = _QUOTE_ALWAYS_SAFE | frozenset(safe)
    safe_chars = "".join(unreserved)
    return _quote(string, safe_chars, encoding=encoding, errors=errors)  # type: ignore[arg-type]
def basic_auth_to_header(basic_auth: Tuple[str, str]) -> str:
    """Builds the value of a 'Basic' HTTP Authorization header.

    :param basic_auth: 2-tuple of ``(username, password)``; each item
        may be ``str`` or ``bytes``.
    :raises ValueError: If ``basic_auth`` isn't a 2-tuple of str/bytes.
    """
    well_formed = (
        isinstance(basic_auth, tuple)
        and len(basic_auth) == 2
        and all(isinstance(part, (str, bytes)) for part in basic_auth)
    )
    if not well_formed:
        raise ValueError(
            "'basic_auth' must be a 2-tuple of str/bytes (username, password)"
        )

    username, password = basic_auth
    credentials = to_bytes(username) + b":" + to_bytes(password)
    token = base64.b64encode(credentials).decode()
    return f"Basic {token}"
def url_to_node_config(
    url: str, use_default_ports_for_scheme: bool = False
) -> NodeConfig:
    """Builds a :class:`elastic_transport.NodeConfig` instance out of a URL.

    Credentials embedded in the URL are percent-decoded and converted into
    an 'Authorization' header. A missing port always defaults to 443 for HTTPS.

    :param url: URL to transform into a NodeConfig.
    :param use_default_ports_for_scheme: If 'True' a missing port also
        defaults to 80 for HTTP.
    :raises ValueError: If the URL can't be parsed or is missing a scheme,
        host or port.
    """
    try:
        parsed_url = parse_url(url)
    except LocationParseError:
        raise ValueError(f"Could not parse URL {url!r}") from None

    scheme = parsed_url.scheme
    parsed_port: Optional[int] = parsed_url.port
    if parsed_port is None:
        if scheme == "https":
            # HTTPS always gets its default port
            parsed_port = 443
        elif scheme == "http" and use_default_ports_for_scheme:
            # HTTP only gets its default port when explicitly requested
            parsed_port = 80

    # Only 'None' and the empty string count as missing; this keeps a
    # numeric port of 0 acceptable.
    for required in (scheme, parsed_url.host, parsed_port):
        if required is None or required == "":
            raise ValueError(
                "URL must include a 'scheme', 'host', and 'port' component (ie 'https://localhost:9200')"
            )

    # Narrow the Optional types for the type checker.
    assert parsed_url.scheme is not None
    assert parsed_url.host is not None
    assert parsed_port is not None

    headers = {}
    userinfo = parsed_url.auth
    if userinfo:
        # The userinfo is still percent-encoded at this point, so each half
        # is decoded before being placed in the basic auth header.
        raw_user, _, raw_password = userinfo.partition(":")
        credentials = (
            urllib.parse.unquote(raw_user),
            urllib.parse.unquote(raw_password),
        )
        headers["authorization"] = basic_auth_to_header(credentials)

    # Strip any surrounding square brackets from the host
    bare_host = parsed_url.host.strip("[]")

    # A missing path and a bare "/" both mean "no prefix"
    url_path = parsed_url.path or ""
    path_prefix = "" if url_path == "/" else url_path

    return NodeConfig(
        scheme=parsed_url.scheme,
        host=bare_host,
        port=parsed_port,
        path_prefix=path_prefix,
        headers=headers,
    )
# Run the package's metadata fix-up helper over this module's namespace.
# NOTE(review): presumably this adjusts ``__module__`` on the public objects
# defined above -- confirm against the implementation in ``._utils``.
fixup_module_metadata(__name__, globals())
# Remove the helper from this module's namespace now that it has run.
del fixup_module_metadata