# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# 	http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import asyncio
import json
import logging
import warnings
from collections.abc import Iterable, Mapping
from typing import Any, Optional

import aiohttp
from aiohttp import BaseConnector, RequestInfo
from aiohttp.client_proto import ResponseHandler
from aiohttp.helpers import BaseTimerContext
from elastic_transport import (
    AiohttpHttpNode,
    ApiResponse,
    AsyncTransport,
    BinaryApiResponse,
    HeadApiResponse,
    ListApiResponse,
    ObjectApiResponse,
    TextApiResponse,
)
from elastic_transport.client_utils import DEFAULT
from elasticsearch import AsyncElasticsearch
from elasticsearch._async.client import IlmClient
from elasticsearch.compat import warn_stacklevel
from elasticsearch.exceptions import HTTP_EXCEPTIONS, ApiError, ElasticsearchWarning
from multidict import CIMultiDict, CIMultiDictProxy
from yarl import URL

from esrally.client.common import _WARNING_RE, _mimetype_header_to_compat, _quote_query
from esrally.client.context import RequestContextHolder
from esrally.utils import io, versions


class StaticTransport:
    """
    Minimal stand-in for an asyncio transport used by ``StaticConnector``.

    It performs no I/O at all; it only remembers whether ``close()`` (or ``abort()``) has been called.
    """

    def __init__(self):
        self.closed: bool = False

    def is_closing(self):
        # always reports "not closing", regardless of whether close() was called
        return False

    def close(self):
        self.closed = True

    def abort(self):
        # without any pending I/O, aborting is the same as closing
        self.close()


class StaticConnector(BaseConnector):
    """
    aiohttp connector that never opens a socket.

    Every "connection" it hands out is a ``ResponseHandler`` whose transport is a ``StaticTransport``.
    """

    async def _create_connection(self, req: "ClientRequest", traces: list["Trace"], timeout: "ClientTimeout") -> ResponseHandler:
        proto = ResponseHandler(self._loop)
        # no network I/O: the transport merely tracks close()/abort() calls
        proto.transport = StaticTransport()
        proto.protocol = ""
        return proto


class StaticRequest(aiohttp.ClientRequest):
    """aiohttp request that is answered from ``RESPONSES`` instead of being written to the connection."""

    # shared ResponseMatcher; RallyAiohttpHttpNode assigns it once when static responses are enabled
    RESPONSES = None

    async def send(self, conn: "Connection") -> "ClientResponse":
        # build the response object exactly like aiohttp would, but never touch ``conn``
        self.response = response = self.response_class(
            self.method,
            self.original_url,
            writer=self._writer,
            continue100=self._continue,
            timer=self._timer,
            request_info=self.request_info,
            traces=self._traces,
            loop=self.loop,
            session=self._session,
        )
        # the canned body is selected by the URL path only
        response.static_body = StaticRequest.RESPONSES.response(self.original_url.path)
        return response


# We subclass EmptyStreamReader because it already provides no-op
# implementations of the StreamReader methods; we only need to override read().
class StaticStreamReader(aiohttp.streams.EmptyStreamReader):
    """Response stream that serves a fixed body handed in by ``StaticResponse``."""

    def __init__(self, body):
        super().__init__()
        # the body string; it is encoded to UTF-8 on every read()
        self.body = body

    async def read(self, n: int = -1) -> bytes:
        # ``n`` is ignored: the complete body is returned in one go
        text = self.body
        return text.encode("utf-8")


class StaticResponse(aiohttp.ClientResponse):
    """aiohttp response whose status, headers and body are synthesized locally rather than read from a socket."""

    def __init__(
        self,
        method: str,
        url: URL,
        *,
        writer: "asyncio.Task[None]",
        continue100: Optional["asyncio.Future[bool]"],
        timer: BaseTimerContext,
        request_info: RequestInfo,
        traces: list["Trace"],
        loop: asyncio.AbstractEventLoop,
        session: "ClientSession",
    ) -> None:
        super().__init__(
            method,
            url,
            writer=writer,
            continue100=continue100,
            timer=timer,
            request_info=request_info,
            traces=traces,
            loop=loop,
            session=session,
        )
        # filled in by StaticRequest.send() right after construction; served as the content in start()
        self.static_body = None

    async def start(self, connection: "Connection") -> "ClientResponse":
        # populate what start() would otherwise read from the wire: always 200, no headers, static content
        self.status = 200
        self._headers = CIMultiDictProxy(CIMultiDict())
        self.content = StaticStreamReader(self.static_body)
        self._connection = connection
        self._protocol = connection.protocol
        self._closed = False
        return self


class ResponseMatcher:
    """
    Maps request paths to canned, JSON-serialized response bodies.

    ``responses`` is an iterable of ``{"path": ..., "body": ...}`` entries. A path pattern may be

    * ``"*"``: matches every path,
    * ``"*suffix"``: matches paths ending with ``suffix``,
    * ``"prefix*"``: matches paths starting with ``prefix``,
    * anything else: matches that exact path.

    Entries are checked in their original order; the first match wins.
    """

    def __init__(self, responses):
        self.logger = logging.getLogger(__name__)
        # (pattern, predicate, serialized body) triples in declaration order
        self.responses = [
            (entry["path"], ResponseMatcher._matcher_for(entry["path"]), json.dumps(entry["body"])) for entry in responses
        ]

    @staticmethod
    def _matcher_for(path):
        # Choose the predicate for a single path pattern; "*" must be checked before the prefix/suffix forms.
        if path == "*":
            return ResponseMatcher.always()
        if path.startswith("*"):
            return ResponseMatcher.endswith(path[1:])
        if path.endswith("*"):
            return ResponseMatcher.startswith(path[:-1])
        return ResponseMatcher.equals(path)

    @staticmethod
    def always():
        return lambda p: True

    @staticmethod
    def startswith(path_pattern):
        return lambda p: p.startswith(path_pattern)

    @staticmethod
    def endswith(path_pattern):
        return lambda p: p.endswith(path_pattern)

    @staticmethod
    def equals(path_pattern):
        return lambda p: p == path_pattern

    def response(self, path):
        """Return the serialized body of the first entry matching ``path``, or ``None`` if nothing matches."""
        hit = next(((pattern, body) for pattern, matcher, body in self.responses if matcher(path)), None)
        if hit is None:
            return None
        path_pattern, body = hit
        self.logger.debug("Path pattern [%s] matches path [%s].", path_pattern, path)
        return body


class RallyTCPConnector(aiohttp.TCPConnector):
    """
    TCP connector that spreads the connections of different clients across all IPs a host name resolves to.

    Accepts an extra ``client_id`` keyword argument (an ``int`` or ``None``); all other arguments are passed on to
    ``aiohttp.TCPConnector``.
    """

    def __init__(self, *args, **kwargs):
        # id of the client owning this connector; used to deterministically pick one of the resolved hosts
        self.client_id = kwargs.pop("client_id", None)
        self.logger = logging.getLogger(__name__)
        super().__init__(*args, **kwargs)

    async def _resolve_host(self, *args, **kwargs):
        hosts = await super()._resolve_host(*args, **kwargs)
        self.logger.debug("client id [%s] resolved hosts [%s]", self.client_id, hosts)
        # super()._resolve_host() does actually return all the IPs a given name resolves to, but the underlying
        # super()._create_direct_connection() logic only ever selects the first successful host from this list from
        # which to establish a connection
        #
        # here we use the factory assigned client_id to deterministically return a IP from this list, which we then swap
        # to the beginning of the list to evenly distribute connections across _all_ clients
        # see https://github.com/elastic/rally/issues/1598
        if self.client_id is None or not hosts:
            # nothing to redistribute (no client id assigned, or no resolved host); keep aiohttp's order instead of
            # failing with a TypeError / ZeroDivisionError
            return hosts
        idx = self.client_id % len(hosts)
        self.logger.debug("client id [%s] selected host [%s]", self.client_id, hosts[idx])
        # swap order of hosts
        hosts[0], hosts[idx] = hosts[idx], hosts[0]
        return hosts


class RallyAiohttpHttpNode(AiohttpHttpNode):
    """
    ``AiohttpHttpNode`` with Rally-specific aiohttp session setup.

    Compared to the base node it can

    * answer requests from a static responses file instead of the network (``static_responses``),
    * pass a ``client_id`` to ``RallyTCPConnector`` to pick one of the resolved IPs per client,
    * register aiohttp ``trace_configs`` and toggle ``enable_cleanup_closed`` on the connector.
    """

    def __init__(self, config):
        super().__init__(config)
        self._loop = None
        # defaults; these attributes are read when the aiohttp session is created and may be overridden before that
        self.client_id = None
        self.trace_configs = None
        self.enable_cleanup_closed = False
        self._static_responses = None
        self._request_class = aiohttp.ClientRequest
        self._response_class = aiohttp.ClientResponse

    @property
    def static_responses(self):
        """Path of the static responses file, or a falsy value when requests go over the network."""
        return self._static_responses

    @static_responses.setter
    def static_responses(self, static_responses):
        self._static_responses = static_responses
        if self._static_responses:
            # read static responses once and reuse them
            if StaticRequest.RESPONSES is None:
                # JSON text is UTF-8 (RFC 8259); don't depend on the platform's default encoding
                with open(io.normalize_path(self._static_responses), encoding="utf-8") as f:
                    StaticRequest.RESPONSES = ResponseMatcher(json.load(f))

            self._request_class = StaticRequest
            self._response_class = StaticResponse
        else:
            # static responses disabled: switch back to the regular classes so that they stay consistent with the
            # (TCP) connector chosen in _create_aiohttp_session()
            self._request_class = aiohttp.ClientRequest
            self._response_class = aiohttp.ClientResponse

    def _create_aiohttp_session(self):
        if self._loop is None:
            self._loop = asyncio.get_running_loop()

        if self._static_responses:
            connector = StaticConnector(limit_per_host=self._connections_per_node, enable_cleanup_closed=self.enable_cleanup_closed)
        else:
            connector = RallyTCPConnector(
                limit_per_host=self._connections_per_node,
                use_dns_cache=True,
                ssl=self._ssl_context,
                enable_cleanup_closed=self.enable_cleanup_closed,
                client_id=self.client_id,
            )

        self.session = aiohttp.ClientSession(
            headers=self.headers,
            auto_decompress=True,
            loop=self._loop,
            cookie_jar=aiohttp.DummyCookieJar(),
            request_class=self._request_class,
            response_class=self._response_class,
            connector=connector,
            trace_configs=self.trace_configs,
        )


class RallyAsyncTransport(AsyncTransport):
    """``AsyncTransport`` whose nodes are always ``RallyAiohttpHttpNode`` instances."""

    def __init__(self, *transport_args, **transport_kwargs):
        # pin the node implementation; everything else is forwarded unchanged
        super().__init__(*transport_args, node_class=RallyAiohttpHttpNode, **transport_kwargs)


class RallyIlmClient(IlmClient):
    """ILM client that keeps accepting the pre-8.x call styles of ``put_lifecycle``."""

    async def put_lifecycle(self, *args, **kwargs):
        """
        The 'elasticsearch-py' 8.x method signature renames the 'policy' param to 'name', and the previously so-called
        'body' param becomes 'policy'.

        This override translates the old call styles:

        * ``put_lifecycle(<name>)`` and ``put_lifecycle(<name>, <body>)`` (positional, old parameter order)
        * ``put_lifecycle(policy=<name>, body={"policy": {...}})`` (old keyword names)

        Calls using the 8.x keywords (``name=...``, ``policy={...}``) are passed through unchanged.
        """
        if args:
            kwargs["name"] = args[0]
            # the old signature took the body as second positional argument; don't silently drop it
            if len(args) > 1 and "body" not in kwargs:
                kwargs["body"] = args[1]

        # the old 'policy' keyword carried the policy *name* (a str); in 8.x 'policy' is the policy definition
        if isinstance(kwargs.get("policy"), str) and "name" not in kwargs:
            kwargs["name"] = kwargs.pop("policy")

        if body := kwargs.pop("body", None):
            kwargs["policy"] = body.get("policy", {})
        # pylint: disable=missing-kwoa
        return await IlmClient.put_lifecycle(self, **kwargs)


class RallyAsyncElasticsearch(AsyncElasticsearch, RequestContextHolder):
    """
    Async Elasticsearch client used in Rally's hot code path.

    It accepts two extra keyword arguments, ``distribution_version`` and ``distribution_flavor``, which decide
    whether compatibility mimetypes are sent. Product verification is skipped, and ``perform_request`` builds the
    ``ApiResponse`` objects itself.
    """

    def __init__(self, *args, **kwargs):
        distribution_version = kwargs.pop("distribution_version", None)
        distribution_flavor = kwargs.pop("distribution_flavor", None)
        super().__init__(*args, **kwargs)
        # skip verification at this point; we've already verified this earlier with the synchronous client.
        # The async client is used in the hot code path and we use customized overrides (such as that we don't
        # parse response bodies in some cases for performance reasons, e.g. when using the bulk API).
        self._verified_elasticsearch = True
        self.distribution_version = distribution_version
        self.distribution_flavor = distribution_flavor

        # some ILM method signatures changed in 'elasticsearch-py' 8.x,
        # so we override method(s) here to provide BWC for any custom
        # runners that aren't using the new kwargs
        self.ilm = RallyIlmClient(self)

    @property
    def is_serverless(self):
        """Whether ``distribution_flavor`` denotes serverless, as decided by ``versions.is_serverless``."""
        return versions.is_serverless(self.distribution_flavor)

    def options(self, *args, **kwargs):
        """Like ``AsyncElasticsearch.options`` but carries over the distribution version and flavor to the copy."""
        new_self = super().options(*args, **kwargs)
        new_self.distribution_version = self.distribution_version
        new_self.distribution_flavor = self.distribution_flavor
        return new_self

    async def perform_request(
        self,
        method: str,
        path: str,
        *,
        params: Optional[Mapping[str, Any]] = None,
        headers: Optional[Mapping[str, str]] = None,
        body: Optional[Any] = None,
    ) -> ApiResponse[Any]:
        """
        Send a request through the transport and wrap the result in an ``ApiResponse``.

        :param method: HTTP method, e.g. ``"GET"``.
        :param path: Request path; ``params`` (if any) are appended as a quoted query string.
        :param params: Optional query parameters.
        :param headers: Optional per-request headers, merged over the client's default headers. The given mapping
            is never modified.
        :param body: Optional request body. If present, ``content-type`` and ``accept`` default to
            ``application/json`` unless the caller already provided them (header names compared case-insensitively).
        :return: ``HeadApiResponse`` for ``HEAD`` requests, otherwise the ``ApiResponse`` subclass matching the type
            of the response body.
        :raises ApiError: (or its status-specific subclass from ``HTTP_EXCEPTIONS``) for non-2xx responses whose
            status is not ignored via ``ignore_status``. A ``HEAD`` request answered with 404 is returned normally.
        """
        # We need to ensure that we provide content-type and accept headers. Work on a new dict: the caller's
        # mapping may be reused across requests or may not be mutable at all.
        if body is not None:
            headers = dict(headers) if headers else {}
            # header names are case-insensitive; a caller-provided "Content-Type" must not be shadowed (and later
            # overridden in the merged headers) by an extra lowercase "content-type" default
            provided = {name.lower(): value for name, value in headers.items()}
            if provided.get("content-type") is None:
                headers["content-type"] = "application/json"
            if provided.get("accept") is None:
                headers["accept"] = "application/json"

        if headers:
            request_headers = self._headers.copy()
            request_headers.update(headers)
        else:
            request_headers = self._headers

        # Converts all parts of a Accept/Content-Type headers
        # from application/X -> application/vnd.elasticsearch+X
        # see https://github.com/elastic/elasticsearch/issues/51816
        # Not applicable to serverless
        if not self.is_serverless:
            if versions.is_version_identifier(self.distribution_version) and (
                versions.Version.from_string(self.distribution_version) >= versions.Version.from_string("8.0.0")
            ):
                _mimetype_header_to_compat("Accept", request_headers)
                _mimetype_header_to_compat("Content-Type", request_headers)

        if params:
            target = f"{path}?{_quote_query(params)}"
        else:
            target = path

        meta, resp_body = await self.transport.perform_request(
            method,
            target,
            headers=request_headers,
            body=body,
            request_timeout=self._request_timeout,
            max_retries=self._max_retries,
            retry_on_status=self._retry_on_status,
            retry_on_timeout=self._retry_on_timeout,
            client_meta=self._client_meta,
        )

        # HEAD with a 404 is returned as a normal response
        # since this is used as an 'exists' functionality.
        # Successful (2xx) statuses are the half-open range [200, 300).
        if not (method == "HEAD" and meta.status == 404) and (
            not 200 <= meta.status < 300
            and (self._ignore_status is DEFAULT or self._ignore_status is None or meta.status not in self._ignore_status)
        ):
            message = str(resp_body)

            # If the response is an error response try parsing
            # the raw Elasticsearch error before raising.
            if isinstance(resp_body, dict):
                try:
                    error = resp_body.get("error", message)
                    if isinstance(error, dict) and "type" in error:
                        error = error["type"]
                    message = error
                except (ValueError, KeyError, TypeError):
                    pass

            raise HTTP_EXCEPTIONS.get(meta.status, ApiError)(message=message, meta=meta, body=resp_body)

        # 'Warning' headers should be reraised as 'ElasticsearchWarning'
        if "warning" in meta.headers:
            warning_header = (meta.headers.get("warning") or "").strip()
            warning_messages: Iterable[str] = _WARNING_RE.findall(warning_header) or (warning_header,)
            stacklevel = warn_stacklevel()
            for warning_message in warning_messages:
                warnings.warn(
                    warning_message,
                    category=ElasticsearchWarning,
                    stacklevel=stacklevel,
                )

        if method == "HEAD":
            response = HeadApiResponse(meta=meta)
        elif isinstance(resp_body, dict):
            response = ObjectApiResponse(body=resp_body, meta=meta)  # type: ignore[assignment]
        elif isinstance(resp_body, list):
            response = ListApiResponse(body=resp_body, meta=meta)  # type: ignore[assignment]
        elif isinstance(resp_body, str):
            response = TextApiResponse(  # type: ignore[assignment]
                body=resp_body,
                meta=meta,
            )
        elif isinstance(resp_body, bytes):
            response = BinaryApiResponse(body=resp_body, meta=meta)  # type: ignore[assignment]
        else:
            response = ApiResponse(body=resp_body, meta=meta)  # type: ignore[assignment]

        return response
