# elastic_enterprise_search/_utils.py (340 lines of code) (raw):
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import base64
import inspect
import re
import sys
import typing as t
import warnings
from datetime import date, datetime
from functools import wraps
from pathlib import Path
from dateutil import parser, tz
from elastic_transport import HttpHeaders, NodeConfig
from elastic_transport.client_utils import (
DEFAULT,
DefaultType,
client_meta_version,
create_user_agent,
percent_encode,
url_to_node_config,
)
from ._version import __version__
# Names exported by 'from <this module> import *'.
__all__ = [
    "DEFAULT",
    "SKIP_IN_PATH",
    "format_datetime",
    "parse_datetime",
    "resolve_auth_headers",
]
# Type variable used to annotate the decorator returned by '_rewrite_parameters'
# so the decorated callable keeps its own type.
F = t.TypeVar("F")
# NOTE(review): presumably the values treated as "not provided" when building
# URL paths; no use is visible in this part of the file -- confirm at call sites.
SKIP_IN_PATH = {None, "", b""}
# Pair of the service identifier "ent" and the client-meta form of this
# package's version.
CLIENT_META_SERVICE = ("ent", client_meta_version(__version__))
# Default 'user-agent' header value; applied by 'client_node_configs' when
# the caller supplied none.
USER_AGENT = create_user_agent("enterprise-search-python", __version__)
# Keyword names that '_rewrite_parameters' lifts out of API-method kwargs and
# forwards to the client's '.options()' call instead of the API itself.
_TRANSPORT_OPTIONS = {
    "http_auth",
    "request_timeout",
    "opaque_id",
    "headers",
    "ignore_status",
}
def format_datetime(value):
    # type: (datetime) -> str
    """Render a datetime as an RFC 3339 timestamp with second precision.

    A datetime without tzinfo is interpreted in the local timezone.
    A zero UTC offset is written as 'Z', any other offset as '[+-]HH:MM'.
    """
    # Timezone-unaware values are assumed to be local time.
    aware = value if value.tzinfo is not None else value.replace(tzinfo=tz.tzlocal())
    total_seconds = aware.utcoffset().total_seconds()
    stamp = aware.strftime("%Y-%m-%dT%H:%M:%S")

    if total_seconds == 0:
        return stamp + "Z"

    sign = "-" if total_seconds < 0 else "+"
    hours, leftover = divmod(int(abs(total_seconds)), 3600)
    return stamp + f"{sign}{hours:02}:{leftover // 60:02}"
def parse_datetime(value):
    # type: (str) -> datetime
    """Convert an RFC 3339 string into a datetime with tzinfo.

    The accepted shape is 'YYYY-MM-DD', a 'T' or a space, 'HH:MM:SS' and a
    timezone that is either 'Z' or '[+-]HH:MM'. Fractional seconds are not
    accepted.

    Raises ValueError when the string does not have exactly that shape.
    """
    # 'fullmatch' is used rather than 'match' with a '$' anchor because '$'
    # also matches just before a trailing newline, which would let
    # "...Z\n" slip through validation.
    if not re.fullmatch(
        r"[0-9]{4}-[0-9]{2}-[0-9]{2}[T ][0-9]{2}:[0-9]{2}:[0-9]{2}(?:Z|[+\-][0-9]{2}:[0-9]{2})",
        value,
    ):
        raise ValueError(
            "Datetime must match format '(YYYY)-(MM)-(DD)T(HH):(MM):(SS)(TZ)' was '%s'"
            % value
        )
    return parser.isoparse(value)
def resolve_auth_headers(
    *,
    headers: t.Optional[t.Mapping[str, str]],
    http_auth: t.Union[DefaultType, None, t.Tuple[str, str], str] = DEFAULT,
    basic_auth: t.Union[DefaultType, None, t.Tuple[str, str], str] = DEFAULT,
    bearer_auth: t.Union[DefaultType, None, str] = DEFAULT,
) -> HttpHeaders:
    """Fold the authentication options into an 'HttpHeaders' instance.

    'headers' is wrapped in 'HttpHeaders' unless it already is one (in which
    case that same instance is updated and returned). At most one of
    'http_auth', 'basic_auth' and 'bearer_auth' may be given; the deprecated
    'http_auth' is mapped onto 'basic_auth' (list/tuple of strings or None)
    or 'bearer_auth' (string) and triggers a DeprecationWarning.

    The returned headers can hold 'None' for 'authorization'. Such values
    are meant to be popped from an existing HttpHeaders instance when two
    instances are combined.

    Raises ValueError when several auth options are given or when an
    'Authorization' header is already set, and TypeError when an auth option
    has an unsupported type.
    """
    if headers is None or headers is DEFAULT:
        headers = HttpHeaders()
    elif not isinstance(headers, HttpHeaders):
        headers = HttpHeaders(headers)

    # Work out which of the special authentication options were supplied.
    supplied = [
        option
        for option in (http_auth, basic_auth, bearer_auth)
        if option is not DEFAULT
    ]
    if len(supplied) > 1:
        # Several authentication parameters would conflict with each other.
        raise ValueError(
            "Can't specify more than one authentication parameter (basic_auth/bearer_auth)"
        )
    if not supplied:
        return headers

    # Exactly one option from here on: translate it into a header.
    if http_auth is not DEFAULT:
        # Deprecated 'http_auth' becomes either 'basic_auth' or 'bearer_auth'.
        if http_auth is None:
            basic_auth = None
        elif isinstance(http_auth, str):
            bearer_auth = http_auth
        elif isinstance(http_auth, (list, tuple)) and all(
            isinstance(part, str) for part in http_auth
        ):
            basic_auth = http_auth
        else:
            raise TypeError(
                "The deprecated 'http_auth' parameter must be either 'Tuple[str, str]' or 'str'. "
                "Use the 'basic_auth' or 'bearer_auth' parameters instead"
            )
        warnings.warn(
            "The 'http_auth' parameter is deprecated. "
            "Use 'basic_auth' or 'bearer_auth' parameters instead",
            category=DeprecationWarning,
            stacklevel=warn_stacklevel(),
        )

    # An explicit 'Authorization' header can't be combined with auth options.
    if headers and headers.get("authorization", None) is not None:
        raise ValueError(
            "Can't set 'Authorization' HTTP header with other authentication options"
        )

    # Basic auth
    if basic_auth is None:
        headers["authorization"] = None
    elif isinstance(basic_auth, str):
        headers["authorization"] = f"Basic {basic_auth}"
    elif isinstance(basic_auth, (list, tuple)):
        credentials = ":".join(basic_auth).encode("utf-8")
        encoded = base64.b64encode(credentials).decode("ascii")
        headers["authorization"] = f"Basic {encoded}"
    elif basic_auth is not DEFAULT:
        raise TypeError(
            "'basic_auth' must be a string or 2 item list/tuple of strings"
        )

    # Bearer auth
    if bearer_auth is None:
        headers["authorization"] = None
    elif isinstance(bearer_auth, str):
        headers["authorization"] = f"Bearer {bearer_auth}"
    elif bearer_auth is not DEFAULT:
        raise TypeError("'bearer_auth' must be a string")

    return headers
def client_node_configs(hosts, **kwargs) -> t.List[NodeConfig]:
    """Build one 'NodeConfig' per host URL with the shared options applied.

    'hosts' may be None/DEFAULT (falls back to 'http://localhost:3002'), a
    single URL string, or a list/tuple of URL strings. Keyword arguments
    whose value is DEFAULT are dropped; the rest are applied to every node
    through 'NodeConfig.replace()'. A 'headers' keyword is merged over each
    node's own headers, and a 'user-agent' header is added when missing.

    Raises TypeError when a host is not a string.
    """
    if hosts is None or hosts is DEFAULT:
        hosts = ["http://localhost:3002"]
    if not isinstance(hosts, (list, tuple)):
        hosts = [hosts]

    node_configs = []
    for host in hosts:
        if not isinstance(host, str):
            raise TypeError("URLs must be of type 'str'")
        node_configs.append(
            url_to_node_config(host, use_default_ports_for_scheme=True)
        )

    # Remove all values which are 'DEFAULT' to avoid overwriting actual defaults.
    node_options = {k: v for k, v in kwargs.items() if v is not DEFAULT}

    # Headers shared by every node, including the 'User-Agent' default.
    # They are kept apart from 'node_options' and never modified afterwards,
    # so headers that belong to one node's own config can't leak into the
    # nodes processed after it.
    shared_headers = HttpHeaders(node_options.pop("headers", ()))
    shared_headers.setdefault("user-agent", USER_AGENT)

    def apply_node_options(node_config: NodeConfig) -> NodeConfig:
        """Needs special handling of headers since .replace() wipes out existing headers"""
        headers = node_config.headers.copy()  # type: ignore[attr-defined]
        # Shared headers take precedence over the node's own headers.
        headers.update(shared_headers)
        headers.freeze()
        return node_config.replace(headers=headers, **node_options)

    return [apply_node_options(node_config) for node_config in node_configs]
def warn_stacklevel() -> int:
    """Work out the 'stacklevel' for a warning from the current call stack.

    Returns the index of the first stack frame whose file lies outside this
    module's root package (or isn't this file, when the root module is a
    single file). Returns 0 when the root module can't be found in
    'sys.modules' or every frame belongs to the package.
    """
    try:
        # The root module is the first component of this module's '__name__'.
        root_name = __name__.partition(".")[0]
        root_path = Path(sys.modules[root_name].__file__)  # type: ignore[arg-type]

        # A package is matched by directory, a plain module by its exact file.
        root_is_package = root_path.name == "__init__.py"
        if root_is_package:
            root_path = root_path.parent

        # Walk outwards until a frame comes from a file that isn't ours.
        for depth, info in enumerate(inspect.stack()):
            caller_file = Path(info.filename)
            # Drop the frame reference as soon as it is no longer needed.
            del info

            if root_is_package:
                is_external = root_path not in caller_file.parents
            else:
                is_external = root_path != caller_file
            if is_external:
                return depth
    except KeyError:
        pass
    return 0
def _escape(value: t.Any) -> str:
"""Escape a value into a string"""
if isinstance(value, date):
return value.isoformat()
elif isinstance(value, datetime):
return format_datetime(value)
elif isinstance(value, bytes):
return value.decode("utf-8", "surrogatepass")
if not isinstance(value, str):
return str(value)
return value
def _quote(value: t.Any) -> str:
    """Escape a value to a string, then percent-encode it.

    The characters ',*[]:-' are passed to 'percent_encode' as the set that
    Enterprise Search accepts un-encoded.
    """
    escaped = _escape(value)
    return percent_encode(escaped, ",*[]:-")
def _quote_query(
query: t.Union[t.Mapping[str, t.Any], t.Iterable[t.Tuple[str, t.Any]]]
) -> str:
"""Quote an iterable or mapping of key-value pairs into a querystring"""
unquoted_kvs = query.items() if hasattr(query, "items") else query
kvs: t.List[t.Tuple[str, str]] = []
for k, v in unquoted_kvs:
if isinstance(v, (list, tuple, dict)):
if k.endswith("[]"):
k = k[:-2]
kvs.extend(_quote_query_deep_object(k, v))
else:
kvs.append((k, _quote(v)))
return "&".join([f"{k}={v}" for k, v in kvs])
def _quote_query_deep_object(
prefix: str, value: t.Any
) -> t.Iterable[t.Tuple[str, str]]:
"""Quote a list or mapping object into a querystring"""
if not isinstance(value, (list, tuple, dict)):
yield (prefix, _quote(value))
elif isinstance(value, (list, tuple)):
for item in value:
yield from _quote_query_deep_object(f"{prefix}[]", item)
else:
for key, val in value.items():
yield from _quote_query_deep_object(f"{prefix}[{key}]", val)
def _quote_query_form(key: str, value: t.Union[t.List[str], t.Tuple[str, ...]]) -> str:
if not isinstance(value, (tuple, list)):
raise ValueError(f"{key!r} must be of type list or tuple")
return ",".join(map(str, value))
def _merge_kwargs_no_duplicates(
kwargs: t.Dict[str, t.Any], values: t.Dict[str, t.Any]
) -> None:
for key, val in values.items():
if key in kwargs:
raise ValueError(
f"Received multiple values for '{key}', specify parameters "
"directly instead of using 'body' or 'params'"
)
kwargs[key] = val
def _rewrite_parameters(
    body_name: t.Optional[str] = None,
    body_fields: bool = False,
    parameter_aliases: t.Optional[t.Dict[str, str]] = None,
    ignore_deprecated_options: t.Optional[t.Set[str]] = None,
) -> t.Callable[[F], F]:
    """Decorator factory that maps deprecated call styles onto keyword arguments.

    The decorated API method is called with keyword arguments only (the sole
    positional argument allowed is the client itself). Before the call:

    * 'params' (a mapping) is merged into the keyword arguments.
    * Transport options ('_TRANSPORT_OPTIONS') are removed from the keyword
      arguments and applied through 'client.options(...)'; 'http_auth' is
      renamed to 'bearer_auth' (str) or 'basic_auth' (2-item list/tuple).
    * 'body' is either assigned to the 'body_name' parameter or, when
      'body_fields' is true, merged field by field into the keyword
      arguments. 'page: {current, size}' becomes 'current_page'/'page_size'.
    * Each alias in 'parameter_aliases' is renamed to its target name.

    Every deprecated style emits a DeprecationWarning. Names listed in
    'ignore_deprecated_options' are left untouched. Conflicting or
    ill-typed inputs raise TypeError or ValueError.
    """

    def wrapper(api: F) -> F:
        @wraps(api)
        def wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
            # Let's give a nicer error message when users pass positional arguments.
            if len(args) >= 2:
                raise TypeError(
                    "Positional arguments can't be used with client API methods. "
                    "Instead only use keyword arguments."
                )

            # We merge 'params' first as transport options can be specified using params.
            if "params" in kwargs and (
                not ignore_deprecated_options
                or "params" not in ignore_deprecated_options
            ):
                params = kwargs.pop("params")
                if params:
                    if not hasattr(params, "items"):
                        raise ValueError(
                            "Couldn't merge 'params' with other parameters as it wasn't a mapping. "
                            "Instead of using 'params' use individual API parameters"
                        )
                    warnings.warn(
                        "The 'params' parameter is deprecated and will be removed "
                        "in a future version. Instead use individual parameters.",
                        category=DeprecationWarning,
                        stacklevel=warn_stacklevel(),
                    )
                    _merge_kwargs_no_duplicates(kwargs, params)

            maybe_transport_options = _TRANSPORT_OPTIONS.intersection(kwargs)
            if maybe_transport_options:
                transport_options = {}
                for option in maybe_transport_options:
                    if (
                        ignore_deprecated_options
                        and option in ignore_deprecated_options
                    ):
                        continue

                    # 'http_auth' needs to be aliased to 'basic_auth' or 'bearer_auth'.
                    transport_option = option
                    if option == "http_auth":
                        if isinstance(kwargs["http_auth"], str):
                            transport_option = "bearer_auth"
                        elif (
                            isinstance(kwargs["http_auth"], (list, tuple))
                            and len(kwargs["http_auth"]) == 2
                        ):
                            transport_option = "basic_auth"
                        else:
                            raise TypeError(
                                "'http_auth' must be either a str or a 2-tuple of strings"
                            )

                    try:
                        transport_options[transport_option] = kwargs.pop(option)
                    except KeyError:
                        pass

                if transport_options:
                    # Swap the client for one carrying the transport options.
                    client = args[0].options(**transport_options)
                    warnings.warn(
                        "Passing transport options in the API method is deprecated. "
                        f"Use '{type(client).__name__}.options()' instead.",
                        category=DeprecationWarning,
                        stacklevel=warn_stacklevel(),
                    )
                    args = (client,) + args[1:]

            if "body" in kwargs and (
                not ignore_deprecated_options or "body" not in ignore_deprecated_options
            ):
                body = kwargs.pop("body")
                if body is not None:
                    if body_name:
                        if body_name in kwargs:
                            raise TypeError(
                                f"Can't use '{body_name}' and 'body' parameters together because '{body_name}' "
                                "is an alias for 'body'. Instead you should only use the "
                                f"'{body_name}' parameter."
                            )
                        warnings.warn(
                            "The 'body' parameter is deprecated and will be removed "
                            f"in a future version. Instead use the '{body_name}' parameter.",
                            category=DeprecationWarning,
                            stacklevel=warn_stacklevel(),
                        )
                        kwargs[body_name] = body

                    elif body_fields:
                        if not hasattr(body, "items"):
                            raise ValueError(
                                "Couldn't merge 'body' with other parameters as it wasn't a mapping. "
                                "Instead of using 'body' use individual API parameters"
                            )
                        warnings.warn(
                            "The 'body' parameter is deprecated and will be removed "
                            "in a future version. Instead use individual parameters.",
                            category=DeprecationWarning,
                            stacklevel=warn_stacklevel(),
                        )

                        # Work on a shallow copy: removing 'page' below must not
                        # alter the mapping the caller passed in (and a mapping
                        # that only offers '.items()' has no '.pop()' anyway).
                        body = dict(body.items())

                        # Special handling of page:{current:1,size:20} -> current_page=1, page_size=20
                        if "page" in body and set(body["page"]).issubset(
                            {"current", "size"}
                        ):
                            page = body.pop("page")
                            if "current" in page:
                                kwargs["current_page"] = page["current"]
                            if "size" in page:
                                kwargs["page_size"] = page["size"]

                        _merge_kwargs_no_duplicates(kwargs, body)

            if parameter_aliases:
                for alias, rename_to in parameter_aliases.items():
                    try:
                        kwargs[rename_to] = kwargs.pop(alias)
                    except KeyError:
                        pass

            return api(*args, **kwargs)

        return wrapped  # type: ignore[return-value]

    return wrapper