#  BSD 3-Clause License
#
#  Copyright (c) 2012, the Sentry Team, see AUTHORS for more details
#  Copyright (c) 2019, Elasticsearch BV
#  All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  * Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
#  * Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
#  * Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#  DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#  OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


from __future__ import absolute_import

import asyncio
import functools
from typing import Dict, Optional

import starlette
from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.routing import Match, Mount
from starlette.types import ASGIApp, Message

import elasticapm
import elasticapm.instrumentation.control
from elasticapm.base import Client, get_client
from elasticapm.conf import constants
from elasticapm.contrib.asyncio.traces import set_context
from elasticapm.contrib.starlette.utils import get_body, get_data_from_request, get_data_from_response
from elasticapm.utils.disttracing import TraceParent
from elasticapm.utils.encoding import long_field
from elasticapm.utils.logging import get_logger

logger = get_logger("elasticapm.errors.client")


def make_apm_client(config: Optional[Dict] = None, client_cls=Client, **defaults) -> Client:
    """Create an ElasticAPM client pre-tagged with the Starlette framework info.

    When the caller does not supply ``framework_name``, both ``framework_name``
    and ``framework_version`` are filled in with Starlette's name and installed
    version before the client is instantiated.

    Args:
        config (dict): Dictionary of Client configuration. All keys must be uppercase. See `elasticapm.conf.Config`.
        client_cls (Client): Class to instantiate; must be Client or a subclass of it.
        **defaults: Additional parameters for Client. See `elasticapm.base.Client`

    Returns:
        Client: the newly built client instance.
    """
    caller_named_framework = "framework_name" in defaults
    if not caller_named_framework:
        # Name and version are always set together so they stay consistent.
        defaults.update(framework_name="starlette", framework_version=starlette.__version__)

    return client_cls(config, **defaults)


class ElasticAPM:
    """
    Starlette / FastAPI middleware for Elastic APM capturing.

    Build a client and register the middleware::

    >>> apm = make_apm_client({
    ...     'SERVICE_NAME': 'myapp',
    ...     'DEBUG': True,
    ...     'SERVER_URL': 'http://localhost:8200',
    ...     'CAPTURE_HEADERS': True,
    ...     'CAPTURE_BODY': 'all'
    ... })
    >>> app.add_middleware(ElasticAPM, client=apm)

    Pass an arbitrary SERVICE_NAME and SECRET_TOKEN (keyword arguments are only
    used when no client is passed and ``get_client()`` does not return one)::

    >>> middleware = ElasticAPM(app, service_name='myapp', secret_token='asdasdasd')

    Pass an explicit client (don't pass in additional options in this case)::

    >>> middleware = ElasticAPM(app, client=client)

    Capture an exception (``capture_exception`` is a coroutine function and
    must be awaited)::

    >>> try:
    ...     1 / 0
    ... except ZeroDivisionError:
    ...     await middleware.capture_exception()

    Capture a message::

    >>> await middleware.capture_message('hello, world!')
    """

    def __init__(self, app: ASGIApp, client: Optional[Client] = None, **kwargs) -> None:
        """Wrap ``app`` and resolve the APM client to report to.

        Args:
            app (ASGIApp): Starlette app
            client (Client): ElasticAPM Client. When omitted, the client returned
                by ``get_client()`` is used; if that yields nothing either, a new
                client is built with ``make_apm_client(**kwargs)``.
            **kwargs: Options forwarded to ``make_apm_client`` when a new client
                has to be created; ignored otherwise.
        """
        # Preference order: explicit client, already existing client, new client.
        # The `or` chain short-circuits, so later candidates are only evaluated
        # when the earlier ones are falsy.
        self.client = client or get_client() or make_apm_client(**kwargs)

        if self.client.config.instrument and self.client.config.enabled:
            elasticapm.instrumentation.control.instrument()

        # If we ever make this a general-use ASGI middleware we should use
        # `asgiref.compatibility.guarantee_single_callable(app)` here
        self.app = app

    async def __call__(self, scope, receive, send):
        """ASGI entry point: trace one HTTP request through the wrapped app.

        Non-HTTP scopes and URLs the client is configured to ignore are passed
        straight through to the wrapped app without starting a transaction.

        Args:
            scope: ASGI scope dictionary
            receive: receive awaitable callable
            send: send awaitable callable

        Raises:
            Exception: any exception raised by the wrapped app is captured,
                recorded as an "HTTP 5xx" failure and then re-raised.
        """
        # we only handle the http scope, skip anything else.
        if scope["type"] != "http" or self.client.should_ignore_url(scope["path"]):
            await self.app(scope, receive, send)
            return

        @functools.wraps(send)
        async def wrapped_send(message) -> None:
            # The response start message carries the status code, so this is
            # where response context, result and outcome are recorded.
            if message.get("type") == "http.response.start":
                await set_context(
                    lambda: get_data_from_response(message, self.client.config, constants.TRANSACTION), "response"
                )
                result = "HTTP {}xx".format(message["status"] // 100)
                elasticapm.set_transaction_result(result, override=False)
                elasticapm.set_transaction_outcome(http_status_code=message["status"], override=False)
            await send(message)

        _mocked_receive = None
        _request_receive = None

        # begin the transaction before capturing the body to get that time accounted
        trace_parent = TraceParent.from_headers(dict(Headers(scope=scope)))
        self.client.begin_transaction("request", trace_parent=trace_parent)

        if self.client.config.capture_body != "off":

            # When we consume the body from receive, we replace the streaming
            # mechanism with a mocked version -- this workaround came from
            # https://github.com/encode/starlette/issues/495#issuecomment-513138055
            body = []
            while True:
                message = await receive()
                if not message:
                    break
                if message["type"] == "http.request":
                    b = message.get("body", b"")
                    if b:
                        body.append(b)
                    if not message.get("more_body", False):
                        break
                elif message["type"] == "http.disconnect":
                    break

            joined_body = b"".join(body)

            # Used by the Request object this middleware inspects for APM
            # context; the body is wrapped with `long_field`.
            async def mocked_receive() -> Message:
                await asyncio.sleep(0)
                return {"type": "http.request", "body": long_field(joined_body)}

            _mocked_receive = mocked_receive

            # Handed to the wrapped app: replays the complete, unmodified body
            # on every call.
            async def request_receive() -> Message:
                await asyncio.sleep(0)
                return {"type": "http.request", "body": joined_body}

            _request_receive = request_receive

        request = Request(scope, receive=_mocked_receive or receive)
        await self._request_started(request)

        # We don't end the transaction here, we rely on the starlette
        # instrumentation of ServerErrorMiddleware to end the transaction
        try:
            await self.app(scope, _request_receive or receive, wrapped_send)
            elasticapm.set_transaction_outcome(constants.OUTCOME.SUCCESS, override=False)
        except Exception:
            await self.capture_exception(
                context={"request": await get_data_from_request(request, self.client.config, constants.ERROR)}
            )
            elasticapm.set_transaction_result("HTTP 5xx", override=False)
            elasticapm.set_transaction_outcome(constants.OUTCOME.FAILURE, override=False)
            elasticapm.set_context({"status_code": 500}, "response")

            raise

    async def capture_exception(self, *args, **kwargs) -> None:
        """Captures your exception.

        Thin coroutine wrapper that forwards to ``self.client.capture_exception``.

        Args:
            *args: Positional arguments passed through to the client.
            **kwargs: Keyword arguments passed through to the client.
        """
        self.client.capture_exception(*args, **kwargs)

    async def capture_message(self, *args, **kwargs) -> None:
        """Captures your message.

        Thin coroutine wrapper that forwards to ``self.client.capture_message``.

        Args:
            *args: Positional arguments passed through to the client.
            **kwargs: Keyword arguments passed through to the client.
        """
        self.client.capture_message(*args, **kwargs)

    async def _request_started(self, request: Request) -> None:
        """Captures the begin of the request processing to APM.

        Sets the "request" context of the current transaction and names the
        transaction "<METHOD> <route or path>".

        Args:
            request (Request): request built for APM inspection.
        """
        # When we consume the body, we replace the streaming mechanism with
        # a mocked version -- this workaround came from
        # https://github.com/encode/starlette/issues/495#issuecomment-513138055
        # and we call the workaround here to make sure that regardless of
        # `capture_body` settings, we will have access to the body if we need it.
        if self.client.config.capture_body != "off":
            await get_body(request)

        await set_context(lambda: get_data_from_request(request, self.client.config, constants.TRANSACTION), "request")
        # Fall back to the raw URL path when no route could be resolved.
        transaction_name = self.get_route_name(request) or request.url.path
        elasticapm.set_transaction_name("{} {}".format(request.method, transaction_name), override=False)

    def get_route_name(self, request: Request) -> Optional[str]:
        """Resolve the route path that serves ``request``.

        Args:
            request (Request): incoming request; its ``app`` and ``scope`` are
                used for route matching.

        Returns:
            The path of the fully matching route, or ``None`` when no route
            matches (also after retrying with the trailing slash toggled).
        """
        app = request.app
        scope = request.scope
        routes = app.routes
        route_name = self._get_route_name(scope, routes)

        # Starlette magically redirects requests if the path matches a route name with a trailing slash
        # appended or removed. To not spam the transaction names list, we do the same here and put these
        # redirects all in the same "redirect trailing slashes" transaction name
        if not route_name and app.router.redirect_slashes and scope["path"] != "/":
            redirect_scope = dict(scope)
            if scope["path"].endswith("/"):
                redirect_scope["path"] = scope["path"][:-1]
                trim = True
            else:
                redirect_scope["path"] = scope["path"] + "/"
                trim = False

            route_name = self._get_route_name(redirect_scope, routes)
            if route_name is not None:
                # Undo the slash toggle on the resolved name: re-append the slash
                # that was stripped for matching, or drop the last character
                # when a slash was added for matching.
                route_name = (route_name + "/") if trim else route_name[:-1]
        return route_name

    def _get_route_name(self, scope, routes, route_name=None):
        """Walk ``routes`` and return the path of the first full match.

        For a matching ``Mount`` with child routes, the child routes are
        searched recursively and the paths are concatenated.

        Args:
            scope: ASGI scope dictionary to match against.
            routes: iterable of route objects exposing ``matches`` and ``path``.
            route_name: name carried over from the enclosing mount, if any.

        Returns:
            The (possibly concatenated) route path, or ``None`` when no route
            matches fully or a matching mount has no matching child route.
        """
        for route in routes:
            match, child_scope = route.matches(scope)
            if match == Match.FULL:
                route_name = route.path
                child_scope = {**scope, **child_scope}
                if isinstance(route, Mount) and route.routes:
                    child_route_name = self._get_route_name(child_scope, route.routes, route_name)
                    if child_route_name is None:
                        # Nothing inside the mount matched: discard the name.
                        route_name = None
                    else:
                        route_name += child_route_name
                return route_name
            elif match == Match.PARTIAL and route_name is None:
                # NOTE(review): this value is never returned -- it is either
                # overwritten by a later full match or discarded when the loop
                # ends. Confirm whether partial matches were meant to be reported.
                route_name = route.path
        return None
