#  BSD 3-Clause License
#
#  Copyright (c) 2019, Elasticsearch BV
#  All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  * Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
#  * Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
#  * Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#  DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#  OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import binascii
import ctypes
import itertools
import random
import re
from typing import Dict, Optional

from elasticapm.conf import constants
from elasticapm.utils.logging import get_logger

logger = get_logger("elasticapm.utils")


class TraceParent(object):
    __slots__ = ("version", "trace_id", "span_id", "trace_options", "tracestate", "tracestate_dict", "is_legacy")

    def __init__(
        self,
        version: int,
        trace_id: str,
        span_id: str,
        trace_options: "TracingOptions",
        tracestate: Optional[str] = None,
        is_legacy: bool = False,
    ) -> None:
        self.version: int = version
        self.trace_id: str = trace_id
        self.span_id: str = span_id
        self.trace_options: TracingOptions = trace_options
        self.is_legacy: bool = is_legacy
        self.tracestate: Optional[str] = tracestate
        self.tracestate_dict = self._parse_tracestate(tracestate)

    def copy_from(
        self,
        version: int = None,
        trace_id: str = None,
        span_id: str = None,
        trace_options: "TracingOptions" = None,
        tracestate: str = None,
    ):
        return TraceParent(
            version or self.version,
            trace_id or self.trace_id,
            span_id or self.span_id,
            trace_options or self.trace_options,
            tracestate or self.tracestate,
        )

    def to_string(self) -> str:
        return "{:02x}-{}-{}-{:02x}".format(self.version, self.trace_id, self.span_id, self.trace_options.asByte)

    def to_ascii(self) -> bytes:
        return self.to_string().encode("ascii")

    def to_binary(self) -> bytes:
        return b"".join(
            [
                (self.version).to_bytes(1, byteorder="big"),
                (0).to_bytes(1, byteorder="big"),
                bytes.fromhex(self.trace_id),
                (1).to_bytes(1, byteorder="big"),
                bytes.fromhex(self.span_id),
                (2).to_bytes(1, byteorder="big"),
                (self.trace_options.asByte).to_bytes(1, byteorder="big"),
            ]
        )

    @classmethod
    def new(cls, transaction_id: str, is_sampled: bool) -> "TraceParent":
        return cls(
            version=constants.TRACE_CONTEXT_VERSION,
            trace_id="%032x" % random.getrandbits(128),
            span_id=transaction_id,
            trace_options=TracingOptions(recorded=is_sampled),
        )

    @classmethod
    def from_string(
        cls, traceparent_string: str, tracestate_string: Optional[str] = None, is_legacy: bool = False
    ) -> Optional["TraceParent"]:
        try:
            parts = traceparent_string.split("-")
            version, trace_id, span_id, trace_flags = parts[:4]
        except ValueError:
            logger.debug("Invalid traceparent header format, value %s", traceparent_string)
            return
        try:
            version = int(version, 16)
            if version == 255:
                raise ValueError()
        except ValueError:
            logger.debug("Invalid version field, value %s", version)
            return
        try:
            tracing_options = TracingOptions()
            tracing_options.asByte = int(trace_flags, 16)
        except ValueError:
            logger.debug("Invalid trace-options field, value %s", trace_flags)
            return
        return TraceParent(version, trace_id, span_id, tracing_options, tracestate_string, is_legacy)

    @classmethod
    def from_headers(
        cls,
        headers: dict,
        header_name: str = constants.TRACEPARENT_HEADER_NAME,
        legacy_header_name: str = constants.TRACEPARENT_LEGACY_HEADER_NAME,
        tracestate_header_name: str = constants.TRACESTATE_HEADER_NAME,
    ) -> Optional["TraceParent"]:
        tracestate = cls.merge_duplicate_headers(headers, tracestate_header_name)
        if header_name in headers:
            return TraceParent.from_string(headers[header_name], tracestate, is_legacy=False)
        elif legacy_header_name in headers:
            return TraceParent.from_string(headers[legacy_header_name], tracestate, is_legacy=False)
        else:
            return None

    @classmethod
    def from_binary(cls, data: bytes) -> Optional["TraceParent"]:
        if len(data) != 29:
            logger.debug("Invalid binary traceparent format, length is %d, should be 29, value %r", len(data), data)
            return
        if (
            int.from_bytes(data[1:2], byteorder="big") != 0
            or int.from_bytes(data[18:19], byteorder="big") != 1
            or int.from_bytes(data[27:28], byteorder="big") != 2
        ):
            logger.debug("Invalid binary traceparent format, field identifiers not correct, value %r", data)
            return
        version = int.from_bytes(data[0:1], byteorder="big")
        trace_id = str(binascii.hexlify(data[2:18]), encoding="ascii")
        span_id = str(binascii.hexlify(data[19:27]), encoding="ascii")
        try:
            tracing_options = TracingOptions()
            tracing_options.asByte = int.from_bytes(data[28:29], byteorder="big")
        except ValueError:
            logger.debug("Invalid trace-options field, value %r", data[28:29])
            return
        return TraceParent(version, trace_id, span_id, tracing_options)

    @classmethod
    def merge_duplicate_headers(cls, headers, key):
        """
        HTTP allows multiple values for the same header name. Most WSGI implementations
        merge these values using a comma as separator (this has been confirmed for wsgiref,
        werkzeug, gunicorn and uwsgi). Other implementations may use containers like
        multidict to store headers and have APIs to iterate over all values for a given key.

        This method is provided as a hook for framework integrations to provide their own
        TraceParent implementation. The implementation should return a single string. Multiple
        values for the same key should be merged using a comma as separator.

        :param headers: a dict-like header object
        :param key: header name
        :return: a single string value or None
        """
        # this works for all known WSGI implementations
        if isinstance(headers, list):
            return ",".join([item[1] for item in headers if item[0] == key])
        elif not hasattr(headers, "get") and hasattr(headers, "__iter__"):
            return ",".join([item[1] for item in headers if item[0] == key])
        return headers.get(key)

    def _parse_tracestate(self, tracestate) -> Dict[str, str]:
        """
        Tracestate can contain data from any vendor, made distinct by vendor
        keys. Vendors are comma-separated. The elastic (es) tracestate data is
        made up of key:value pairs, separated by semicolons. It is meant to
        be parsed into a dict.

            tracestate: es=key:value;key:value... , othervendor=<opaque>

        Per https://w3c.github.io/trace-context/#tracestate-header-field-values
        there can be optional whitespace (OWS) between the comma-separated
        list-members.
        """
        if not tracestate:
            return {}
        if "es=" not in tracestate:
            return {}

        ret = {}
        try:
            state = re.search(r"(?:,|^)\s*es=([^,]*?)\s*(?:,|$)", tracestate).group(1).split(";")
        except IndexError:
            return {}
        for keyval in state:
            if not keyval:
                continue
            key, _, val = keyval.partition(":")
            ret[key] = val

        return ret

    def _set_tracestate(self):
        elastic_value = ";".join(["{}:{}".format(k, v) for k, v in self.tracestate_dict.items()])
        # No character validation needed, as we validate in `add_tracestate`. Just validate length.
        if len(elastic_value) > 256:
            logger.debug("Modifications to TraceState would violate length limits, ignoring.")
            raise TraceStateFormatException()
        elastic_state = "es={}".format(elastic_value)
        if not self.tracestate:
            return elastic_state
        else:
            # Remove es=<stuff> from the tracestate, and add the new es state to the end
            otherstate = re.sub(r"(?:,|^)\s*es=([^,]*?)\s*(?:,|$)", "", self.tracestate)
            otherstate = otherstate.lstrip(",")  # in case `es=` was the first entry
            otherstate = re.sub(r",,", ",", otherstate)  # remove potential double commas
            # No validation of `otherstate` required, since we're downstream. We only need to check `es=`
            # since we introduced it, and that validation has already been done at this point.
            if otherstate:
                return "{},{}".format(otherstate.rstrip(","), elastic_state)
            else:
                return elastic_state

    def add_tracestate(self, key, val) -> None:
        """
        Add key/value pair to the tracestate.

        We do most of the validation for valid characters here. We have to make
        sure none of the reserved separators for tracestate are used in our
        key/value pairs, and we also need to check that all characters are
        within the valid range. Checking here means we never have to re-check
        a pair once set, which saves time in the _set_tracestate() function.
        """
        key = str(key)
        val = str(val)
        for bad in (":", ";", ",", "="):
            if bad in key or bad in val:
                logger.debug("New tracestate key/val pair contains invalid character '{}', ignoring.".format(bad))
                return
        for c in itertools.chain(key, val):
            # Tracestate spec only allows for characters between ASCII 0x20 and 0x7E
            if ord(c) < 0x20 or ord(c) > 0x7E:
                logger.debug("Modifications to TraceState would introduce invalid character '{}', ignoring.".format(c))
                return

        oldval = self.tracestate_dict.pop(key, None)
        self.tracestate_dict[key] = val
        try:
            self.tracestate = self._set_tracestate()
        except TraceStateFormatException:
            if oldval is not None:
                self.tracestate_dict[key] = oldval
            else:
                self.tracestate_dict.pop(key)


class TracingOptions_bits(ctypes.LittleEndianStructure):
    _fields_ = [("recorded", ctypes.c_uint8, 1)]


class TracingOptions(ctypes.Union):
    _anonymous_ = ("bit",)
    _fields_ = [("bit", TracingOptions_bits), ("asByte", ctypes.c_uint8)]

    def __init__(self, **kwargs) -> None:
        super(TracingOptions, self).__init__()
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __eq__(self, other):
        return self.asByte == other.asByte


def trace_parent_from_string(traceparent_string, tracestate_string=None, is_legacy=False):
    """
    This is a wrapper function so we can add traceparent generation to the
    public API.
    """
    return TraceParent.from_string(traceparent_string, tracestate_string=tracestate_string, is_legacy=is_legacy)


def trace_parent_from_headers(headers):
    """
    This is a wrapper function so we can add traceparent generation to the
    public API.
    """
    return TraceParent.from_headers(headers)


class TraceStateFormatException(Exception):
    pass
