#  BSD 3-Clause License
#
#  Copyright (c) 2019, Elasticsearch BV
#  All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  * Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
#  * Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
#  * Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#  DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#  OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import base64
import datetime
import functools
import json
import os
import platform
import re
import time
import urllib
import warnings
from typing import Optional, TypeVar
from urllib.parse import urlencode

import elasticapm
from elasticapm.base import Client, get_client
from elasticapm.conf import constants
from elasticapm.traces import execution_context
from elasticapm.utils import encoding, get_name_from_func, nested_key
from elasticapm.utils.disttracing import TraceParent
from elasticapm.utils.logging import get_logger

# Values of `_lambda_transaction.source` whose events are HTTP requests; for these,
# request/response context and HTTP-style transaction results are recorded.
SERVERLESS_HTTP_REQUEST = ("api", "elb")

logger = get_logger("elasticapm.serverless")

# Module-level state; these live for the lifetime of the Python process and are
# mutated through `global` statements below.
# True until the first transaction is started; reported as faas.coldstart.
COLD_START = True
# Set once capture_serverless has called elasticapm.instrument().
INSTRUMENTED = False
# Cleared after the Lambda extension answers partial-transaction registration
# with an HTTP 4xx/5xx error, so registration is not attempted again.
REGISTER_PARTIAL_TRANSACTIONS = True

# Placeholder type for the Lambda context object (second handler argument);
# its `aws_request_id` and `invoked_function_arn` attributes are read below.
_AWSLambdaContextT = TypeVar("_AWSLambdaContextT")


def capture_serverless(func: Optional[callable] = None, **kwargs) -> callable:
    """
    Decorator for instrumenting AWS Lambda functions.

    Example usage:

        from elasticapm import capture_serverless

        @capture_serverless
        def handler(event, context):
            return {"statusCode": 200, "body": "Success!"}

    Both ``@capture_serverless`` and ``@capture_serverless()`` are accepted.
    Keyword arguments (pending deprecation) go through ``prep_kwargs`` and are
    used to build a ``Client`` when ``get_client()`` returns none; a ``name``
    keyword is removed first and forwarded to the transaction context manager.

    Please note that when using the APM Layer and setting AWS_LAMBDA_EXEC_WRAPPER this is not required and
    the handler would be instrumented automatically.
    """
    global INSTRUMENTED

    if not func:
        # Invoked as `@capture_serverless()`: return a decorator bound to the
        # same keyword arguments, which will receive the handler next.
        return functools.partial(capture_serverless, **kwargs)

    if kwargs:
        warnings.warn(
            PendingDeprecationWarning(
                "Passing keyword arguments to capture_serverless will be deprecated in the next major release."
            )
        )

    name = kwargs.pop("name", None)
    client_kwargs = prep_kwargs(kwargs)

    # Reuse an existing client when there is one
    client = get_client() or Client(**client_kwargs)

    def tracing_enabled():
        # Re-read on every call so the wrapper follows the client's current config
        cfg = client.config
        return not cfg.debug and cfg.instrument and cfg.enabled

    if tracing_enabled() and not INSTRUMENTED:
        # Only instrument once, guarded by the module-level INSTRUMENTED flag
        elasticapm.instrument()
        INSTRUMENTED = True

    @functools.wraps(func)
    def decorated(*args, **kwds):
        # The handler is called as handler(event, context); keep both for the request context
        event, context = (args[0], args[1]) if len(args) >= 2 else ({}, {})
        if not tracing_enabled():
            return func(*args, **kwds)
        with _lambda_transaction(func, name, client, event, context) as sls:
            sls.response = func(*args, **kwds)
            return sls.response

    return decorated


def prep_kwargs(kwargs=None):
    """
    Fill in Client keyword arguments for running inside AWS Lambda.

    The given dict is updated in place and returned; a new dict is created
    when ``kwargs`` is None. Always sets ``metrics_interval="0ms"``,
    ``breakdown_metrics=False``, ``central_config=False``,
    ``cloud_provider="none"`` and ``framework_name="AWS Lambda"``.

    ``metrics_sets`` (to ``[]``), ``service_name`` (from
    ``AWS_LAMBDA_FUNCTION_NAME``) and ``service_version`` (from
    ``AWS_LAMBDA_FUNCTION_VERSION``) are set only when neither the caller nor
    the matching ``ELASTIC_APM_*`` environment variable provides them.

    :param kwargs: Client keyword arguments, or None
    :return: the updated kwargs dict
    :raises KeyError: if ``service_name`` has to be defaulted and
        ``AWS_LAMBDA_FUNCTION_NAME`` is not set
    """
    settings = kwargs if kwargs is not None else {}
    environ = os.environ

    def unset(key, env_var):
        # Neither passed explicitly nor configured through the environment
        return key not in settings and env_var not in environ

    # Disable all background threads except for transport
    settings["metrics_interval"] = "0ms"
    settings["breakdown_metrics"] = False
    if unset("metrics_sets", "ELASTIC_APM_METRICS_SETS"):
        settings["metrics_sets"] = []
    settings.update(central_config=False, cloud_provider="none", framework_name="AWS Lambda")
    if unset("service_name", "ELASTIC_APM_SERVICE_NAME"):
        settings["service_name"] = environ["AWS_LAMBDA_FUNCTION_NAME"]
    if unset("service_version", "ELASTIC_APM_SERVICE_VERSION"):
        settings["service_version"] = environ.get("AWS_LAMBDA_FUNCTION_VERSION")
    return settings


def should_normalize_headers(event: dict) -> bool:
    """
    Decide whether the event's header names should be lowercased before use.

    Returns True when ``requestContext`` has an ``elb`` or ``requestId`` key
    but no ``http`` key (i.e. not the payload format that carries
    ``requestContext.http``).

    Even if the documentation says that headers are lowercased it's not always the case for format version 1.0
    https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html
    """
    ctx = event.get("requestContext", {})
    if "http" in ctx:
        return False
    return "elb" in ctx or "requestId" in ctx


class _lambda_transaction(object):
    """
    Context manager for creating transactions around AWS Lambda functions.

    Begins and ends a single transaction, waiting for the transport to flush
    before releasing the context.

    On entry the trigger source ("api", "elb", "s3", "sns", "sqs" or "other")
    is derived from the event, the transaction is started and named, and
    request/faas/cloud/service/message context is attached. On exit the
    response (when it is a dict) and any exception are recorded, the
    transaction is ended and the transport is flushed.
    """

    def __init__(
        self, func: callable, name: Optional[str], client: Client, event: dict, context: _AWSLambdaContextT
    ) -> None:
        """
        :param func: the wrapped Lambda handler
        :param name: explicit name; when falsy it is derived from ``func`` via get_name_from_func
        :param client: the elasticapm Client that records the transaction
        :param event: the Lambda event (first handler argument)
        :param context: the Lambda context object (second handler argument); its
            ``aws_request_id`` and ``invoked_function_arn`` attributes are read
        """
        self.func = func
        self.name = name or get_name_from_func(func)
        self.event = event
        self.context = context
        self.response = None
        self.client = client

    def __enter__(self):
        """
        Transaction setup.

        Classifies the trigger, chooses the transaction type and name, starts
        the transaction (continuing a ``traceparent`` found in the event
        headers and linking ``traceparent`` message attributes of up to 1000
        records), attaches context and registers the partial transaction.
        """
        if not isinstance(self.event, dict):
            # When `event` is not a dict, it's likely the output of another AWS
            # service like Step Functions, and is unlikely to be standardized
            # in any way. We just have to rely on our defaults in this case.
            self.event = {}

        headers = self.event.get("headers") or {}
        if headers and should_normalize_headers(self.event):
            normalized_headers = {k.lower(): v for k, v in headers.items()}
        else:
            normalized_headers = headers
        trace_parent = TraceParent.from_headers(normalized_headers)

        # Only the first transaction started in this process is a cold start
        global COLD_START
        cold_start = COLD_START
        COLD_START = False

        self.source = "other"
        transaction_type = "request"
        transaction_name = os.environ.get("AWS_LAMBDA_FUNCTION_NAME", self.name)

        self.httpmethod = (
            nested_key(self.event, "requestContext", "httpMethod")
            or nested_key(self.event, "requestContext", "http", "method")
            or nested_key(self.event, "httpMethod")
        )

        if self.httpmethod:  # http request
            if nested_key(self.event, "requestContext", "elb"):
                self.source = "elb"
                resource = "unknown route"
            elif nested_key(self.event, "requestContext", "httpMethod"):
                self.source = "api"
                # API v1
                resource = "/{}{}".format(
                    nested_key(self.event, "requestContext", "stage"),
                    nested_key(self.event, "requestContext", "resourcePath"),
                )
            else:
                self.source = "api"
                # API v2: a routeKey starting with "$" is kept whole; otherwise
                # only the part after the first space is used ("GET /x" -> "/x")
                route_key = nested_key(self.event, "requestContext", "routeKey")
                route_key = f"/{route_key}" if route_key.startswith("$") else route_key.split(" ", 1)[-1]
                resource = "/{}{}".format(
                    nested_key(self.event, "requestContext", "stage"),
                    route_key,
                )
            transaction_name = "{} {}".format(self.httpmethod, resource)
        elif "Records" in self.event and len(self.event["Records"]) == 1:
            # Only single-record events are attributed to a specific S3/SNS/SQS source
            record = self.event["Records"][0]
            if record.get("eventSource") == "aws:s3":  # S3
                self.source = "s3"
                transaction_name = "{} {}".format(record["eventName"], record["s3"]["bucket"]["name"])
            elif record.get("EventSource") == "aws:sns":  # SNS
                self.source = "sns"
                transaction_type = "messaging"
                transaction_name = "RECEIVE {}".format(record["Sns"]["TopicArn"].split(":")[5])
            elif record.get("eventSource") == "aws:sqs":  # SQS
                self.source = "sqs"
                transaction_type = "messaging"
                transaction_name = "RECEIVE {}".format(record["eventSourceARN"].split(":")[5])

        if "Records" in self.event:
            # Span links to the trace context carried in record message attributes (at most 1000 records)
            links = [
                TraceParent.from_string(record["messageAttributes"]["traceparent"]["stringValue"])
                for record in self.event["Records"][:1000]
                if "messageAttributes" in record and "traceparent" in record["messageAttributes"]
            ]
        else:
            links = []

        self.client.begin_transaction(transaction_type, trace_parent=trace_parent, links=links)
        elasticapm.set_transaction_name(transaction_name, override=False)
        if self.source in SERVERLESS_HTTP_REQUEST:
            elasticapm.set_context(
                lambda: get_data_from_request(
                    self.event,
                    capture_body=self.client.config.capture_body in ("transactions", "all"),
                    capture_headers=self.client.config.capture_headers,
                ),
                "request",
            )
        self.set_metadata_and_context(cold_start)
        self.send_partial_transaction()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """
        Transaction teardown.

        Records response context and the HTTP result/outcome derived from a
        ``statusCode`` key, captures an exception raised by the handler, ends
        the transaction and flushes the transport. Returns None, so a handler
        exception keeps propagating.
        """
        if self.response and isinstance(self.response, dict):
            elasticapm.set_context(
                lambda: get_data_from_response(self.response, capture_headers=self.client.config.capture_headers),
                "response",
            )
            if "statusCode" in self.response:
                try:
                    status_class = int(self.response["statusCode"]) // 100
                except (TypeError, ValueError):
                    # int() raises TypeError for values like None or lists and ValueError
                    # for non-numeric strings. Neither may escape __exit__, or the
                    # transaction would never be ended and flushed.
                    logger.warning("Lambda function's statusCode was not formed as an int. Assuming 5xx result.")
                    elasticapm.set_transaction_result("HTTP 5xx", override=False)
                    elasticapm.set_transaction_outcome(outcome="failure", override=False)
                else:
                    result = "HTTP {}xx".format(status_class)
                    elasticapm.set_transaction_result(result, override=False)
                    if result == "HTTP 5xx":
                        elasticapm.set_transaction_outcome(outcome="failure", override=False)
        if exc_val:
            self.client.capture_exception(exc_info=(exc_type, exc_val, exc_tb), handled=False)
            if self.source in SERVERLESS_HTTP_REQUEST:
                elasticapm.set_transaction_result("HTTP 5xx", override=False)
                elasticapm.set_transaction_outcome(http_status_code=500, override=False)
                elasticapm.set_context({"status_code": 500}, "response")
            else:
                elasticapm.set_transaction_result("failure", override=False)
                elasticapm.set_transaction_outcome(outcome="failure", override=False)

        self.client.end_transaction()
        # Collect any custom+prometheus metrics if enabled
        self.client.metrics.collect()

        try:
            logger.debug("Flushing elasticapm data")
            self.client._transport.flush()
            logger.debug("Flush complete")
        except ValueError:
            logger.warning("Flush timed out")

    def set_metadata_and_context(self, coldstart: bool) -> None:
        """
        Process the metadata and context fields for this request.

        Builds the faas, cloud, service and (for SQS/SNS) message context from
        the event and the Lambda context object, sets them on the current
        transaction, and adds service/cloud metadata to the client.

        :param coldstart: reported as ``faas.coldstart``
        """
        metadata = {}
        cloud_context = {"origin": {"provider": "aws"}}
        service_context = {}
        message_context = {}

        faas = {}
        faas["coldstart"] = coldstart
        faas["trigger"] = {"type": "other"}
        faas["execution"] = self.context.aws_request_id
        arn = self.context.invoked_function_arn
        # Keep only the first 7 ARN segments (drops anything after the function name)
        if len(arn.split(":")) > 7:
            arn = ":".join(arn.split(":")[:7])
        faas["id"] = arn
        faas["name"] = os.environ.get("AWS_LAMBDA_FUNCTION_NAME")
        faas["version"] = os.environ.get("AWS_LAMBDA_FUNCTION_VERSION")

        if self.source == "api":
            faas["trigger"]["type"] = "http"
            faas["trigger"]["request_id"] = self.event["requestContext"]["requestId"]
            service_context["origin"] = {"name": self.event["requestContext"]["domainName"]}
            service_context["origin"]["id"] = self.event["requestContext"]["apiId"]
            service_context["origin"]["version"] = self.event.get("version", "1.0")
            cloud_context["origin"] = {}
            cloud_context["origin"]["service"] = {"name": "api gateway"}
            if ".lambda-url." in self.event["requestContext"]["domainName"]:
                cloud_context["origin"]["service"]["name"] = "lambda url"
            cloud_context["origin"]["account"] = {"id": self.event["requestContext"]["accountId"]}
            cloud_context["origin"]["provider"] = "aws"
        elif self.source == "elb":
            elb_target_group_arn = self.event["requestContext"]["elb"]["targetGroupArn"]
            faas["trigger"]["type"] = "http"
            service_context["origin"] = {"name": elb_target_group_arn.split(":")[5].split("/")[1]}
            service_context["origin"]["id"] = elb_target_group_arn
            cloud_context["origin"] = {}
            cloud_context["origin"]["service"] = {"name": "elb"}
            cloud_context["origin"]["account"] = {"id": elb_target_group_arn.split(":")[4]}
            cloud_context["origin"]["region"] = elb_target_group_arn.split(":")[3]
            cloud_context["origin"]["provider"] = "aws"
        elif self.source == "sqs":
            record = self.event["Records"][0]
            faas["trigger"]["type"] = "pubsub"
            faas["trigger"]["request_id"] = record["messageId"]
            service_context["origin"] = {}
            service_context["origin"]["name"] = record["eventSourceARN"].split(":")[5]
            service_context["origin"]["id"] = record["eventSourceARN"]
            cloud_context["origin"] = {}
            cloud_context["origin"]["service"] = {"name": "sqs"}
            cloud_context["origin"]["region"] = record["awsRegion"]
            cloud_context["origin"]["account"] = {"id": record["eventSourceARN"].split(":")[4]}
            cloud_context["origin"]["provider"] = "aws"
            message_context["queue"] = {"name": service_context["origin"]["name"]}
            # Tolerate records without an "attributes" mapping instead of raising KeyError
            attributes = record.get("attributes") or {}
            if "SentTimestamp" in attributes:
                # SentTimestamp is compared against time.time() in milliseconds
                message_context["age"] = {"ms": int((time.time() * 1000) - int(attributes["SentTimestamp"]))}
            if self.client.config.capture_body in ("transactions", "all") and "body" in record:
                message_context["body"] = record["body"]
            if self.client.config.capture_headers and record.get("messageAttributes"):
                headers = {}
                for k, v in record["messageAttributes"].items():
                    if v and v.get("stringValue"):
                        headers[k] = v.get("stringValue")
                if headers:
                    message_context["headers"] = headers
        elif self.source == "sns":
            record = self.event["Records"][0]
            faas["trigger"]["type"] = "pubsub"
            faas["trigger"]["request_id"] = record["Sns"]["TopicArn"]
            service_context["origin"] = {}
            service_context["origin"]["name"] = record["Sns"]["TopicArn"].split(":")[5]
            service_context["origin"]["id"] = record["Sns"]["TopicArn"]
            service_context["origin"]["version"] = record["EventVersion"]
            service_context["origin"]["service"] = {"name": "sns"}
            cloud_context["origin"] = {}
            cloud_context["origin"]["region"] = record["Sns"]["TopicArn"].split(":")[3]
            cloud_context["origin"]["account"] = {"id": record["Sns"]["TopicArn"].split(":")[4]}
            cloud_context["origin"]["provider"] = "aws"
            message_context["queue"] = {"name": service_context["origin"]["name"]}
            if "Timestamp" in record["Sns"]:
                # The timestamp carries a "Z" (UTC) suffix. Compare it against an aware
                # UTC "now" rather than naive local time, so the age does not depend
                # on the process's TZ setting.
                sent_at = datetime.datetime.strptime(
                    record["Sns"]["Timestamp"], r"%Y-%m-%dT%H:%M:%S.%fZ"
                ).replace(tzinfo=datetime.timezone.utc)
                age = datetime.datetime.now(datetime.timezone.utc) - sent_at
                message_context["age"] = {"ms": int(age.total_seconds() * 1000)}
            if self.client.config.capture_body in ("transactions", "all") and "Message" in record["Sns"]:
                message_context["body"] = record["Sns"]["Message"]
            if self.client.config.capture_headers and record["Sns"].get("MessageAttributes"):
                headers = {}
                for k, v in record["Sns"]["MessageAttributes"].items():
                    if v and v.get("Type") == "String":
                        headers[k] = v.get("Value")
                if headers:
                    message_context["headers"] = headers
        elif self.source == "s3":
            record = self.event["Records"][0]
            faas["trigger"]["type"] = "datasource"
            faas["trigger"]["request_id"] = record["responseElements"]["x-amz-request-id"]
            service_context["origin"] = {}
            service_context["origin"]["name"] = record["s3"]["bucket"]["name"]
            service_context["origin"]["id"] = record["s3"]["bucket"]["arn"]
            service_context["origin"]["version"] = record["eventVersion"]
            cloud_context["origin"] = {}
            cloud_context["origin"]["service"] = {"name": "s3"}
            cloud_context["origin"]["region"] = record["awsRegion"]
            cloud_context["origin"]["provider"] = "aws"

        metadata["service"] = {}
        metadata["service"]["name"] = self.client.config.service_name
        metadata["service"]["framework"] = {"name": "AWS Lambda"}
        metadata["service"]["runtime"] = {
            "name": os.environ.get("AWS_EXECUTION_ENV"),
            "version": platform.python_version(),
        }
        metadata["service"]["version"] = self.client.config.service_version
        metadata["service"]["node"] = {"configured_name": os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME")}
        # This is the one piece of metadata that requires deep merging. We add it manually
        # here to avoid having to deep merge in _transport.add_metadata()
        node = nested_key(self.client.get_service_info(), "node")
        if node:
            metadata["service"]["node"] = node

        metadata["cloud"] = {}
        metadata["cloud"]["provider"] = "aws"
        metadata["cloud"]["region"] = os.environ.get("AWS_REGION")
        metadata["cloud"]["service"] = {"name": "lambda"}
        metadata["cloud"]["account"] = {"id": arn.split(":")[4]}

        elasticapm.set_context(cloud_context, "cloud")
        elasticapm.set_context(service_context, "service")
        # faas doesn't actually belong in context, but we handle this in to_dict
        elasticapm.set_context(faas, "faas")
        if message_context:
            elasticapm.set_context(message_context, "message")
        self.client.add_extra_metadata(metadata)

    def send_partial_transaction(self) -> None:
        """
        We synchronously send the (partial) transaction to the Lambda Extension
        so that the transaction can be reported even if the lambda runtime times
        out before we can report the transaction.

        Only attempted while REGISTER_PARTIAL_TRANSACTIONS is set,
        ELASTIC_APM_LAMBDA_APM_SERVER is set, and server_url points at
        localhost/127.0.0.1. An HTTP 4xx/5xx error disables further attempts;
        other errors are logged.

        This is pretty specific to the HTTP transport. If we ever add other
        transports, we will need to clean this up.
        """
        global REGISTER_PARTIAL_TRANSACTIONS
        if (
            REGISTER_PARTIAL_TRANSACTIONS
            and os.environ.get("ELASTIC_APM_LAMBDA_APM_SERVER")
            and ("localhost" in self.client.config.server_url or "127.0.0.1" in self.client.config.server_url)
        ):
            transaction = execution_context.get_transaction()
            transport = self.client._transport
            logger.debug("Sending partial transaction and early metadata to the lambda extension...")
            data = transport._json_serializer({"metadata": self.client.build_metadata()}) + "\n"
            data += transport._json_serializer({"transaction": transaction.to_dict()})
            partial_transaction_url = urllib.parse.urljoin(
                self.client.config.server_url
                if self.client.config.server_url.endswith("/")
                else self.client.config.server_url + "/",
                "register/transaction",
            )
            try:
                transaction.pause_sampling = True
                transport.send(
                    data,
                    custom_url=partial_transaction_url,
                    custom_headers={
                        "x-elastic-aws-request-id": self.context.aws_request_id,
                        "Content-Type": "application/vnd.elastic.apm.transaction+ndjson",
                    },
                )
            except Exception as e:
                # "[45]" rather than "[4,5]": a comma inside a character class is a literal
                if re.match(r"HTTP [45]\d\d", str(e)):
                    REGISTER_PARTIAL_TRANSACTIONS = False
                    logger.info(
                        "APM Lambda Extension does not support partial transactions. "
                        "Disabling partial transaction registration."
                    )
                else:
                    logger.warning("Failed to send partial transaction to APM Lambda Extension", exc_info=True)
            finally:
                transaction.pause_sampling = False


def get_data_from_request(event: dict, capture_body: bool = False, capture_headers: bool = True) -> dict:
    """
    Capture context data from API gateway event.

    :param event: the Lambda event
    :param capture_body: when true, the body is included (base64-decoded when
        ``isBase64Encoded`` is set, otherwise JSON-parsed if possible); when
        false, a non-None body is reported as "[REDACTED]"
    :param capture_headers: include ``event["headers"]`` when present
    :return: dict with optional ``headers``; for events with an HTTP method
        also ``method``, optional ``body`` and ``url``
    """
    data = {}
    if capture_headers and "headers" in event:
        data["headers"] = event["headers"]

    http_method = (
        nested_key(event, "requestContext", "httpMethod")
        or nested_key(event, "requestContext", "http", "method")
        or nested_key(event, "httpMethod")
    )
    if not http_method:
        # Not API Gateway
        return data

    data["method"] = http_method
    if http_method in constants.HTTP_WITH_BODY and "body" in event:
        payload = event["body"]
        if not capture_body:
            if payload is not None:
                data["body"] = "[REDACTED]"
        else:
            if event.get("isBase64Encoded"):
                payload = base64.b64decode(payload)
            else:
                # Best effort: keep the raw body when it is not valid JSON
                try:
                    payload = json.loads(payload)
                except Exception:
                    pass
            if payload is not None:
                data["body"] = payload

    data["url"] = get_url_dict(event)
    return data


def get_data_from_response(response: dict, capture_headers: bool = True) -> dict:
    """
    Capture response data from lambda return.

    :param response: the dict returned by the Lambda handler
    :param capture_headers: include ``response["headers"]`` when present
    :return: dict with ``status_code`` (when ``statusCode`` is present; 500 if
        it cannot be converted to an int) and optionally ``headers``
    """
    result = {}

    if "statusCode" in response:
        try:
            result["status_code"] = int(response["statusCode"])
        except (TypeError, ValueError):
            # statusCode wasn't formed as an int (ValueError for e.g. "abc",
            # TypeError for e.g. None or a list).
            # we don't log here, as we will have already logged at transaction.result handling
            result["status_code"] = 500

    if capture_headers and "headers" in response:
        result["headers"] = response["headers"]
    return result


def get_url_dict(event: dict) -> dict:
    """
    Reconstruct URL from API Gateway (or ELB) event.

    :param event: the Lambda event; reads ``headers``, ``requestContext.stage``,
        ``path``/``rawPath`` and ``rawQueryString``/``queryStringParameters``
    :return: dict with ``full``, ``protocol``, ``hostname`` and ``pathname``,
        plus ``port`` and ``search`` when available
    """
    headers = event.get("headers") or {}
    # Header names are looked up in both capitalized and lowercased form
    protocol = headers.get("X-Forwarded-Proto", headers.get("x-forwarded-proto", "https"))
    host = headers.get("Host", headers.get("host", ""))
    stage = nested_key(event, "requestContext", "stage") or ""
    raw_path = event.get("rawPath", "")
    if stage:
        stage = "/" + stage
        # Strip the stage only when it is the leading path segment. str.split
        # would also cut at later occurrences, e.g. "/dev/devices" -> "ices".
        if raw_path == stage or raw_path.startswith(stage + "/"):
            raw_path = raw_path[len(stage):]

    path = event.get("path", raw_path)
    port = headers.get("X-Forwarded-Port", headers.get("x-forwarded-port"))
    query = ""
    if "rawQueryString" in event:
        # rawQueryString has no leading "?"; add it (when non-empty) so the full
        # URL is well-formed and "search" matches the other branches below.
        if event["rawQueryString"]:
            query = "?" + event["rawQueryString"]
    elif event.get("queryStringParameters"):
        if stage:  # api requires parameters encoding to build correct url
            query = "?" + urlencode(event["queryStringParameters"])
        else:  # for elb we do not have the stage
            query = "?" + "&".join(["{}={}".format(k, v) for k, v in event["queryStringParameters"].items()])

    url = protocol + "://" + host + stage + path + query

    url_dict = {
        "full": encoding.keyword_field(url),
        "protocol": protocol,
        "hostname": encoding.keyword_field(host),
        "pathname": encoding.keyword_field(stage + path),
    }

    if port:
        url_dict["port"] = port
    if query:
        url_dict["search"] = encoding.keyword_field(query)
    return url_dict
