# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# 	http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import contextlib
import logging
import time

import certifi
from urllib3.connection import is_ipaddress

from esrally import doc_link, exceptions
from esrally.utils import console, convert, versions


class EsClientFactory:
    """
    Abstracts how the Elasticsearch client is created and customizes the client for backwards
    compatibility guarantees that are broader than the library's defaults.
    """

    def __init__(self, hosts, client_options, distribution_version=None, distribution_flavor=None):
        def host_string(host):
            # protocol can be set at either host or client opts level
            protocol = "https" if client_options.get("use_ssl") or host.get("use_ssl") else "http"
            return f"{protocol}://{host['host']}:{host['port']}"

        self.hosts = [host_string(h) for h in hosts]
        self.client_options = dict(client_options)
        self.ssl_context = None
        # This attribute is necessary for the backwards-compatibility logic contained in
        # RallySyncElasticsearch.perform_request() and RallyAsyncElasticsearch.perform_request(), and also for
        # identification of whether or not a client is 'serverless'.
        self.distribution_version = distribution_version
        self.distribution_flavor = distribution_flavor
        self.logger = logging.getLogger(__name__)

        masked_client_options = dict(client_options)
        if "basic_auth_password" in masked_client_options:
            masked_client_options["basic_auth_password"] = "*****"
        if "http_auth" in masked_client_options:
            masked_client_options["http_auth"] = (masked_client_options["http_auth"][0], "*****")
        if "api_key" in masked_client_options:
            masked_client_options["api_key"] = "*****"
        self.logger.info("Creating ES client connected to %s with options [%s]", hosts, masked_client_options)

        # we're using an SSL context now and it is not allowed to have use_ssl present in client options anymore
        if self.client_options.pop("use_ssl", False):
            # pylint: disable=import-outside-toplevel
            import ssl

            self.logger.debug("SSL support: on")

            self.ssl_context = ssl.create_default_context(
                ssl.Purpose.SERVER_AUTH, cafile=self.client_options.pop("ca_certs", certifi.where())
            )

            # We call get() here instead of pop() in order to pass verify_certs through as a kwarg
            # to the elasticsearch.Elasticsearch constructor. Setting the ssl_context's verify_mode to
            # ssl.CERT_NONE is insufficient with version 8.0+ of elasticsearch-py.
            if not self.client_options.get("verify_certs", True):
                self.logger.debug("SSL certificate verification: off")
                # order matters to avoid ValueError: check_hostname needs a SSL context with either CERT_OPTIONAL or CERT_REQUIRED
                self.ssl_context.check_hostname = False
                self.ssl_context.verify_mode = ssl.CERT_NONE
                self.client_options["ssl_show_warn"] = False

                self.logger.debug(
                    "User has enabled SSL but disabled certificate verification. This is dangerous but may be ok for a benchmark."
                )
            else:
                # check_hostname should not be set when host is an IP address
                self.ssl_context.check_hostname = self._only_hostnames(hosts)
                self.ssl_context.verify_mode = ssl.CERT_REQUIRED
                self.logger.debug("SSL certificate verification: on")

            # When using SSL_context, all SSL related kwargs in client options get ignored
            client_cert = self.client_options.pop("client_cert", False)
            client_key = self.client_options.pop("client_key", False)

            if not client_cert and not client_key:
                self.logger.debug("SSL client authentication: off")
            elif bool(client_cert) != bool(client_key):
                self.logger.error("Supplied client-options contain only one of client_cert/client_key. ")
                defined_client_ssl_option = "client_key" if client_key else "client_cert"
                missing_client_ssl_option = "client_cert" if client_key else "client_key"
                console.println(
                    "'{}' is missing from client-options but '{}' has been specified.\n"
                    "If your Elasticsearch setup requires client certificate verification both need to be supplied.\n"
                    "Read the documentation at {}\n".format(
                        missing_client_ssl_option,
                        defined_client_ssl_option,
                        console.format.link(doc_link("command_line_reference.html#client-options")),
                    )
                )
                raise exceptions.SystemSetupError(
                    "Cannot specify '{}' without also specifying '{}' in client-options.".format(
                        defined_client_ssl_option, missing_client_ssl_option
                    )
                )
            elif client_cert and client_key:
                self.logger.debug("SSL client authentication: on")
                self.ssl_context.load_cert_chain(certfile=client_cert, keyfile=client_key)
        else:
            self.logger.debug("SSL support: off")

        if self._is_set(self.client_options, "create_api_key_per_client"):
            self.client_options.pop("create_api_key_per_client")
            basic_auth_user = self.client_options.get("basic_auth_user", False)
            basic_auth_password = self.client_options.get("basic_auth_password", False)
            provided_auth = {"basic_auth_user": basic_auth_user, "basic_auth_password": basic_auth_password}
            missing_auth = [k for k, v in provided_auth.items() if not v]
            if missing_auth:
                console.println(
                    "Basic auth credentials are required in order to create API keys.\n"
                    f"Missing basic auth client options are: {missing_auth}\n"
                    f"Read the documentation at {console.format.link(doc_link('command_line_reference.html#client-options'))}"
                )
                raise exceptions.SystemSetupError(
                    "You must provide the 'basic_auth_user' and 'basic_auth_password' client options in addition "
                    "to 'create_api_key_per_client' in order to create client API keys."
                )
            self.logger.debug("Automatic creation of client API keys: on")
        else:
            self.logger.debug("Automatic creation of client API keys: off")

        if self._is_set(self.client_options, "basic_auth_user") and self._is_set(self.client_options, "basic_auth_password"):
            self.client_options["basic_auth"] = (self.client_options.pop("basic_auth_user"), self.client_options.pop("basic_auth_password"))
            self.logger.debug("HTTP basic authentication: on")
        else:
            self.logger.debug("HTTP basic authentication: off")

        if self._is_set(self.client_options, "api_key"):
            self.logger.debug("API key authentication: on")
        else:
            self.logger.debug("API key authentication: off")

        if self._is_set(self.client_options, "compressed"):
            console.warn("You set the deprecated client option 'compressed‘. Please use 'http_compress' instead.", logger=self.logger)
            self.client_options["http_compress"] = self.client_options.pop("compressed")

        if self._is_set(self.client_options, "http_compress"):
            self.logger.debug("HTTP compression: on")
        else:
            self.logger.debug("HTTP compression: off")

        self.enable_cleanup_closed = convert.to_bool(self.client_options.pop("enable_cleanup_closed", True))
        self.max_connections = max(256, self.client_options.pop("max_connections", 0))
        self.static_responses = self.client_options.pop("static_responses", None)

        if self._is_set(self.client_options, "timeout"):
            self.client_options["request_timeout"] = self.client_options.pop("timeout")

    @staticmethod
    def _only_hostnames(hosts):
        has_ip = False
        has_hostname = False
        for host in hosts:
            is_ip = is_ipaddress(host["host"])
            if is_ip:
                has_ip = True
            else:
                has_hostname = True

        if has_ip and has_hostname:
            raise exceptions.SystemSetupError("Cannot verify certs with mixed IP addresses and hostnames")

        return has_hostname

    def _is_set(self, client_opts, k):
        try:
            return client_opts[k]
        except KeyError:
            return False

    def create(self):
        # pylint: disable=import-outside-toplevel
        from esrally.client.synchronous import RallySyncElasticsearch

        return RallySyncElasticsearch(
            distribution_version=self.distribution_version,
            distribution_flavor=self.distribution_flavor,
            hosts=self.hosts,
            ssl_context=self.ssl_context,
            **self.client_options,
        )

    def create_async(self, api_key=None, client_id=None):
        # pylint: disable=import-outside-toplevel
        import io

        import aiohttp
        from elasticsearch.serializer import JSONSerializer

        from esrally.client.asynchronous import (
            RallyAsyncElasticsearch,
            RallyAsyncTransport,
        )

        class LazyJSONSerializer(JSONSerializer):
            def loads(self, data):
                meta = RallyAsyncElasticsearch.request_context.get()
                if "raw_response" in meta:
                    return io.BytesIO(data)
                else:
                    return super().loads(data)

        async def on_request_start(session, trace_config_ctx, params):
            RallyAsyncElasticsearch.on_request_start()

        async def on_request_end(session, trace_config_ctx, params):
            RallyAsyncElasticsearch.on_request_end()

        trace_config = aiohttp.TraceConfig()
        trace_config.on_request_start.append(on_request_start)
        # It is tempting to register this callback on `TraceConfig.on_request_end()`. However, aiohttp will call
        # `TraceConfig.on_request_end()` when the *first* chunk of the response has been received. However, this can
        # skew service time significantly if the response is large *and* it is streamed by Elasticsearch
        # (see ChunkedToXContent in the Elasticsearch code base).
        #
        # Therefore, we register for `TraceConfig.on_response_chunk_received()` which is called multiple times. As
        # Rally's implementation of the `on_request_end` callback handler updates the timestamp on every call, Rally
        # will ultimately record the time when it received the *last* chunk. This is what we want because any code
        # that is using the Elasticsearch client library can only act on the response once it is fully received.
        #
        # We also keep registering for `TraceConfig.on_request_end()` instead of relying on
        # `TraceConfig.on_response_chunk_received()` only to handle corner cases during client timeout when aiohttp does
        # not call request exception handler, but does call request end handler. See
        # https://github.com/elastic/rally/issues/1860 for details.
        trace_config.on_response_chunk_received.append(on_request_end)
        trace_config.on_request_end.append(on_request_end)
        # ensure that we also stop the timer when a request "ends" with an exception (e.g. a timeout)
        trace_config.on_request_exception.append(on_request_end)

        # override the builtin JSON serializer
        self.client_options["serializer"] = LazyJSONSerializer()

        if api_key is not None:
            self.client_options.pop("http_auth", None)
            self.client_options.pop("basic_auth", None)
            self.client_options["api_key"] = api_key

        async_client = RallyAsyncElasticsearch(
            distribution_version=self.distribution_version,
            distribution_flavor=self.distribution_flavor,
            hosts=self.hosts,
            transport_class=RallyAsyncTransport,
            ssl_context=self.ssl_context,
            maxsize=self.max_connections,
            **self.client_options,
        )

        # the AsyncElasticsearch constructor automatically creates the corresponding NodeConfig objects, so we set
        # their instance attributes after they've been instantiated
        for node_connection in async_client.transport.node_pool.all():
            node_connection.trace_configs = [trace_config]
            node_connection.enable_cleanup_closed = self.enable_cleanup_closed
            node_connection.static_responses = self.static_responses
            node_connection.client_id = client_id

        return async_client


def wait_for_rest_layer(es, max_attempts=40):
    """
    Waits for ``max_attempts`` until Elasticsearch's REST API is available.

    :param es: Elasticsearch client to use for connecting.
    :param max_attempts: The maximum number of attempts to check whether the REST API is available.
    :return: True iff Elasticsearch's REST API is available.
    """
    # assume that at least the hosts that we expect to contact should be available. Note that this is not 100%
    # bullet-proof as a cluster could have e.g. dedicated masters which are not contained in our list of target hosts
    # but this is still better than just checking for any random node's REST API being reachable.
    expected_node_count = len(es.transport.node_pool)
    logger = logging.getLogger(__name__)
    attempt = 0
    while attempt <= max_attempts:
        attempt += 1
        # pylint: disable=import-outside-toplevel
        from elastic_transport import (
            ApiError,
            ConnectionError,
            SerializationError,
            TlsError,
            TransportError,
        )

        try:
            # see also WaitForHttpResource in Elasticsearch tests. Contrary to the ES tests we consider the API also
            # available when the cluster status is RED (as long as all required nodes are present)
            es.cluster.health(wait_for_nodes=f">={expected_node_count}")
            logger.debug("REST API is available for >= [%s] nodes after [%s] attempts.", expected_node_count, attempt)
            return True
        except SerializationError as e:
            if "Client sent an HTTP request to an HTTPS server" in str(e):
                raise exceptions.SystemSetupError(
                    "Rally sent an HTTP request to an HTTPS server. Are you sure this is an HTTP endpoint?", e
                )

            if attempt <= max_attempts:
                logger.debug("Got serialization error [%s] on attempt [%s]. Sleeping...", e, attempt)
                time.sleep(3)
            else:
                raise
        except TlsError as e:
            raise exceptions.SystemSetupError("Could not connect to cluster via HTTPS. Are you sure this is an HTTPS endpoint?", e)
        except ConnectionError as e:
            if "ProtocolError" in str(e):
                raise exceptions.SystemSetupError(
                    "Received a protocol error. Are you sure you're using the correct scheme (HTTP or HTTPS)?", e
                )

            if attempt <= max_attempts:
                logger.debug("Got connection error on attempt [%s]. Sleeping...", attempt)
                time.sleep(3)
            else:
                raise
        except TransportError as e:
            if attempt <= max_attempts:
                logger.debug("Got transport error on attempt [%s]. Sleeping...", attempt)
                time.sleep(3)
            else:
                raise
        except ApiError as e:
            # cluster block, x-pack not initialized yet, our wait condition is not reached
            if e.status_code in (503, 401, 408) and attempt <= max_attempts:
                logger.debug("Got status code [%s] on attempt [%s]. Sleeping...", e.message, attempt)
                time.sleep(3)
            else:
                logger.warning("Got unexpected status code [%s] on attempt [%s].", e.message, attempt)
                raise
    return False


def cluster_distribution_version(hosts, client_options, client_factory=EsClientFactory):
    """
    Attempt to get the target cluster's distribution version, build flavor, and build hash by creating and using
    a 'sync' Elasticsearch client.

    :param hosts: The host(s) to connect to.
    :param client_options: The client options to customize the Elasticsearch client.
    :param client_factory: Factory class that creates the Elasticsearch client.
    :return: The cluster's build flavor, version number, and build hash. For Serverless Elasticsearch these may all be
      the build flavor value. Also returns the operator status (always False for stateful).
    """
    # no way for us to know whether we're talking to a serverless elasticsearch or not, so we default to the sync client
    es = client_factory(hosts, client_options).create()

    # wait_for_rest_layer  calls the Cluster Health API, which is not available for unprivileged users on Serverless
    # As a result, we need to call the info API first to know if we can call wait_for_rest_layer().
    version = es.info()["version"]

    version_build_flavor = version.get("build_flavor", "oss")
    # if build hash is not available default to build flavor
    version_build_hash = version.get("build_hash", version_build_flavor)
    # if version number is not available default to build flavor
    version_number = version.get("number", version_build_flavor)

    # assume non-operator serverless privileges by default
    serverless_operator = False
    if versions.is_serverless(version_build_flavor):
        # overwrite static serverless version number
        version_number = "serverless"

        # determine serverless operator status if security enabled
        # pylint: disable=import-outside-toplevel
        from elasticsearch.exceptions import ApiError

        with contextlib.suppress(ApiError):
            authentication_info = es.perform_request(method="GET", path="/_security/_authenticate")
            serverless_operator = authentication_info.body.get("operator", False)

    if not versions.is_serverless(version_build_flavor) or serverless_operator is True:
        # if available, unconditionally wait for the REST layer - if it's not up, we'll intentionally raise the original error
        wait_for_rest_layer(es)

    return version_build_flavor, version_number, version_build_hash, serverless_operator


def create_api_key(es, client_id, max_attempts=5):
    """
    Creates an API key for the provided ``client_id``.

    :param es: Elasticsearch client to use for connecting.
    :param client_id: ID of the client for which the API key is being created.
    :param max_attempts: The maximum number of attempts to create the API key.
    :return: A dict with at least the following keys: ``id``, ``name``, ``api_key``.
    """
    logger = logging.getLogger(__name__)

    for attempt in range(1, max_attempts + 1):
        # pylint: disable=import-outside-toplevel
        import elasticsearch

        try:
            logger.debug("Creating ES API key for client ID [%s]", client_id)
            return es.security.create_api_key(name=f"rally-client-{client_id}")
        except elasticsearch.TransportError as e:
            logger.debug("Got transport error [%s] on attempt [%s]. Sleeping...", str(e), attempt)
            time.sleep(1)
        except elasticsearch.ApiError as e:
            if e.meta.status == 405:
                # We don't retry on 405 since it indicates a misconfigured benchmark candidate and isn't recoverable
                raise exceptions.SystemSetupError(
                    "Got status code 405 when attempting to create API keys. Is Elasticsearch Security enabled?", e
                )
            logger.debug("Got status code [%s] on attempt [%s] of [%s]. Sleeping...", e.status_code, attempt, max_attempts)
            time.sleep(1)


def delete_api_keys(es, ids, max_attempts=5):
    """
    Deletes the provided list of API key IDs.

    :param es: Elasticsearch client to use for connecting.
    :param ids: List of API key IDs to delete.
    :param max_attempts: The maximum number of attempts to delete the API keys.
    :return: True iff all provided key IDs were successfully deleted.
    """
    logger = logging.getLogger(__name__)

    def raise_exception(failed_ids, cause=None):
        msg = f"Could not delete API keys with the following IDs: {failed_ids}"
        if cause is not None:
            raise exceptions.RallyError(msg) from cause
        raise exceptions.RallyError(msg)

    # Before ES 7.10, deleting API keys by ID had to be done individually.
    # After ES 7.10, a list of API key IDs can be deleted in one request.
    version = es.info()["version"]
    current_version = versions.Version.from_string(version.get("number", "7.10.0"))
    minimum_version = versions.Version.from_string("7.10.0")

    deleted = []
    remaining = ids

    for attempt in range(1, max_attempts + 1):
        # pylint: disable=import-outside-toplevel
        import elasticsearch

        try:
            if current_version >= minimum_version or es.is_serverless:
                resp = es.security.invalidate_api_key(ids=remaining)
                deleted += resp["invalidated_api_keys"]
                remaining = [i for i in ids if i not in deleted]
                # Like bulk indexing requests, we can get an HTTP 200, but the
                # response body could still contain an array of individual errors.
                # So, we have to handle the case were some keys weren't deleted, but
                # the request overall succeeded (i.e. we didn't encounter an exception)
                if attempt < max_attempts:
                    if resp["error_count"] > 0:
                        logger.debug(
                            "Got the following errors on attempt [%s] of [%s]: [%s]. Sleeping...",
                            attempt,
                            max_attempts,
                            resp["error_details"],
                        )
                else:
                    if remaining:
                        logger.warning(
                            "Got the following errors on final attempt to delete API keys: [%s]",
                            resp["error_details"],
                        )
                        raise_exception(remaining)
            else:
                remaining = [i for i in ids if i not in deleted]
                if attempt < max_attempts:
                    for i in remaining:
                        es.security.invalidate_api_key(id=i)
                        deleted.append(i)
                else:
                    if remaining:
                        raise_exception(remaining)
            return True

        except elasticsearch.ApiError as e:
            if attempt < max_attempts:
                logger.debug("Got status code [%s] on attempt [%s] of [%s]. Sleeping...", e.meta.status, attempt, max_attempts)
                time.sleep(1)
            else:
                raise_exception(remaining, cause=e)
        except Exception as e:
            if attempt < max_attempts:
                logger.debug("Got error on attempt [%s] of [%s]. Sleeping...", attempt, max_attempts)
                time.sleep(1)
            else:
                raise_exception(remaining, cause=e)
