elasticapm/utils/disttracing.py (203 lines of code) (raw):
# BSD 3-Clause License
#
# Copyright (c) 2019, Elasticsearch BV
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import binascii
import ctypes
import itertools
import random
import re
from typing import Dict, Optional
from elasticapm.conf import constants
from elasticapm.utils.logging import get_logger
# Module-level logger used by the trace-context helpers in this module
# (all parse/serialization failures are reported at debug level through it).
logger = get_logger("elasticapm.utils")
class TraceParent(object):
    """
    A parsed W3C trace-context ``traceparent`` header plus its ``tracestate``.

    ``tracestate`` holds the full header value. The elastic list-member
    (``es=key:value;key:value``) is additionally parsed into
    ``tracestate_dict`` and re-serialized into ``tracestate`` by
    :meth:`add_tracestate`.
    """

    __slots__ = ("version", "trace_id", "span_id", "trace_options", "tracestate", "tracestate_dict", "is_legacy")

    def __init__(
        self,
        version: int,
        trace_id: str,
        span_id: str,
        trace_options: "TracingOptions",
        tracestate: Optional[str] = None,
        is_legacy: bool = False,
    ) -> None:
        """
        :param version: traceparent version number (0xff is never produced by the parsers)
        :param trace_id: trace id as a hex string
        :param span_id: parent span id as a hex string
        :param trace_options: the trace-flags byte
        :param tracestate: raw tracestate header value, if any
        :param is_legacy: True when parsed from the legacy traceparent header name
        """
        self.version: int = version
        self.trace_id: str = trace_id
        self.span_id: str = span_id
        self.trace_options: TracingOptions = trace_options
        self.is_legacy: bool = is_legacy
        self.tracestate: Optional[str] = tracestate
        self.tracestate_dict = self._parse_tracestate(tracestate)

    def copy_from(
        self,
        version: Optional[int] = None,
        trace_id: Optional[str] = None,
        span_id: Optional[str] = None,
        trace_options: Optional["TracingOptions"] = None,
        tracestate: Optional[str] = None,
    ):
        """
        Return a new TraceParent, taking each field from the arguments when
        given and from ``self`` otherwise.

        Only ``None`` means "not given", so falsy values such as ``version=0``
        (a valid trace-context version) or ``tracestate=""`` are honoured.
        ``is_legacy`` is not carried over, and ``trace_options`` is the same
        object as ``self.trace_options`` unless a new one is passed.
        """
        return TraceParent(
            self.version if version is None else version,
            self.trace_id if trace_id is None else trace_id,
            self.span_id if span_id is None else span_id,
            self.trace_options if trace_options is None else trace_options,
            self.tracestate if tracestate is None else tracestate,
        )

    def to_string(self) -> str:
        """Serialize as ``vv-<trace_id>-<span_id>-ff`` (version and flags as two hex digits)."""
        return "{:02x}-{}-{}-{:02x}".format(self.version, self.trace_id, self.span_id, self.trace_options.asByte)

    def to_ascii(self) -> bytes:
        """Return :meth:`to_string` encoded as ASCII bytes."""
        return self.to_string().encode("ascii")

    def to_binary(self) -> bytes:
        """
        Serialize into the 29-byte binary layout understood by :meth:`from_binary`:
        version byte, field id 0 + trace id, field id 1 + span id, field id 2 + flags byte.

        Assumes ``trace_id``/``span_id`` are 32/16 hex characters; invalid hex
        makes ``bytes.fromhex`` raise ValueError.
        """
        return b"".join(
            [
                (self.version).to_bytes(1, byteorder="big"),
                (0).to_bytes(1, byteorder="big"),
                bytes.fromhex(self.trace_id),
                (1).to_bytes(1, byteorder="big"),
                bytes.fromhex(self.span_id),
                (2).to_bytes(1, byteorder="big"),
                (self.trace_options.asByte).to_bytes(1, byteorder="big"),
            ]
        )

    @classmethod
    def new(cls, transaction_id: str, is_sampled: bool) -> "TraceParent":
        """Start a new trace: random 128-bit trace id, ``transaction_id`` as the span id."""
        return cls(
            version=constants.TRACE_CONTEXT_VERSION,
            trace_id="%032x" % random.getrandbits(128),
            span_id=transaction_id,
            trace_options=TracingOptions(recorded=is_sampled),
        )

    @classmethod
    def from_string(
        cls, traceparent_string: str, tracestate_string: Optional[str] = None, is_legacy: bool = False
    ) -> Optional["TraceParent"]:
        """
        Parse a ``version-trace_id-span_id-flags`` traceparent string.

        Fields after the fourth are ignored. Returns None (logging at debug
        level) when there are fewer than four fields, when version or flags
        are not hex, or when the version is 0xff. ``trace_id`` and ``span_id``
        are taken as-is without validation.
        """
        try:
            parts = traceparent_string.split("-")
            version, trace_id, span_id, trace_flags = parts[:4]
        except ValueError:
            logger.debug("Invalid traceparent header format, value %s", traceparent_string)
            return None
        try:
            version = int(version, 16)
            if version == 255:
                # 0xff is reserved as an invalid version by the trace-context spec
                raise ValueError()
        except ValueError:
            logger.debug("Invalid version field, value %s", version)
            return None
        try:
            tracing_options = TracingOptions()
            tracing_options.asByte = int(trace_flags, 16)
        except ValueError:
            logger.debug("Invalid trace-options field, value %s", trace_flags)
            return None
        return TraceParent(version, trace_id, span_id, tracing_options, tracestate_string, is_legacy)

    @classmethod
    def from_headers(
        cls,
        headers: dict,
        header_name: str = constants.TRACEPARENT_HEADER_NAME,
        legacy_header_name: str = constants.TRACEPARENT_LEGACY_HEADER_NAME,
        tracestate_header_name: str = constants.TRACESTATE_HEADER_NAME,
    ) -> Optional["TraceParent"]:
        """
        Build a TraceParent from request headers.

        The standard header name wins over the legacy one. A result parsed
        from the legacy header is marked with ``is_legacy=True``. Returns None
        when neither header is present or the value cannot be parsed.
        """
        tracestate = cls.merge_duplicate_headers(headers, tracestate_header_name)
        if header_name in headers:
            return TraceParent.from_string(headers[header_name], tracestate, is_legacy=False)
        if legacy_header_name in headers:
            return TraceParent.from_string(headers[legacy_header_name], tracestate, is_legacy=True)
        return None

    @classmethod
    def from_binary(cls, data: bytes) -> Optional["TraceParent"]:
        """
        Parse the 29-byte binary format produced by :meth:`to_binary`.

        Returns None (logging at debug level) if the length or the field
        identifier bytes (0, 1, 2) are wrong.
        """
        if len(data) != 29:
            logger.debug("Invalid binary traceparent format, length is %d, should be 29, value %r", len(data), data)
            return None
        if (
            int.from_bytes(data[1:2], byteorder="big") != 0
            or int.from_bytes(data[18:19], byteorder="big") != 1
            or int.from_bytes(data[27:28], byteorder="big") != 2
        ):
            logger.debug("Invalid binary traceparent format, field identifiers not correct, value %r", data)
            return None
        version = int.from_bytes(data[0:1], byteorder="big")
        trace_id = str(binascii.hexlify(data[2:18]), encoding="ascii")
        span_id = str(binascii.hexlify(data[19:27]), encoding="ascii")
        # A single byte always fits the uint8 flags field, so this cannot fail.
        tracing_options = TracingOptions()
        tracing_options.asByte = int.from_bytes(data[28:29], byteorder="big")
        return TraceParent(version, trace_id, span_id, tracing_options)

    @classmethod
    def merge_duplicate_headers(cls, headers, key):
        """
        HTTP allows multiple values for the same header name. Most WSGI implementations
        merge these values using a comma as separator (this has been confirmed for wsgiref,
        werkzeug, gunicorn and uwsgi). Other implementations may use containers like
        multidict to store headers and have APIs to iterate over all values for a given key.

        This method is provided as a hook for framework integrations to provide their own
        TraceParent implementation. The implementation should return a single string. Multiple
        values for the same key should be merged using a comma as separator.

        :param headers: a dict-like header object, or an iterable of (name, value) pairs
        :param key: header name
        :return: a single string value or None (an empty string for pair iterables without a match)
        """
        # Lists, and iterables without a ``get`` method, are treated as sequences of (name, value) pairs.
        if isinstance(headers, list) or (not hasattr(headers, "get") and hasattr(headers, "__iter__")):
            return ",".join([item[1] for item in headers if item[0] == key])
        # this works for all known WSGI implementations
        return headers.get(key)

    def _parse_tracestate(self, tracestate: Optional[str]) -> Dict[str, str]:
        """
        Tracestate can contain data from any vendor, made distinct by vendor
        keys. Vendors are comma-separated. The elastic (es) tracestate data is
        made up of key:value pairs, separated by semicolons. It is meant to
        be parsed into a dict.

            tracestate: es=key:value;key:value... , othervendor=<opaque>

        Per https://w3c.github.io/trace-context/#tracestate-header-field-values
        there can be optional whitespace (OWS) between the comma-separated
        list-members.

        Returns an empty dict when there is no ``es`` list-member, including
        when "es=" only occurs inside another vendor's member (e.g. "des=1").
        """
        if not tracestate or "es=" not in tracestate:
            return {}
        for member in tracestate.split(","):
            member = member.strip()
            if member.startswith("es="):
                es_value = member[len("es="):]
                break
        else:
            # "es=" appeared only inside another member, so there is no elastic state.
            return {}
        ret = {}
        for keyval in es_value.split(";"):
            if not keyval:
                continue
            key, _, val = keyval.partition(":")
            ret[key] = val
        return ret

    def _set_tracestate(self) -> str:
        """
        Build the full tracestate header value from ``tracestate_dict``.

        Any existing ``es`` list-member is removed from ``self.tracestate``, and
        the freshly serialized one is appended at the end. Other members keep
        their order and text; empty members are dropped.

        Raises TraceStateFormatException if the elastic value exceeds 256 characters.
        """
        elastic_value = ";".join("{}:{}".format(k, v) for k, v in self.tracestate_dict.items())
        # No character validation needed, as we validate in `add_tracestate`. Just validate length.
        if len(elastic_value) > 256:
            logger.debug("Modifications to TraceState would violate length limits, ignoring.")
            raise TraceStateFormatException()
        elastic_state = "es={}".format(elastic_value)
        if not self.tracestate:
            return elastic_state
        # Split on list-member boundaries so that removing an "es=" member in the
        # middle cannot glue its neighbours together. No validation of the other
        # members is required, since we're downstream.
        members = [
            member
            for member in self.tracestate.split(",")
            if member.strip() and not member.strip().startswith("es=")
        ]
        # NOTE(review): the W3C spec asks for modified members to move to the front
        # (left) of the list; this keeps the existing append-at-end behavior.
        members.append(elastic_state)
        return ",".join(members)

    def add_tracestate(self, key, val) -> None:
        """
        Add key/value pair to the tracestate.

        We do most of the validation for valid characters here. We have to make
        sure none of the reserved separators for tracestate are used in our
        key/value pairs, and we also need to check that all characters are
        within the valid range. Checking here means we never have to re-check
        a pair once set, which saves time in the _set_tracestate() function.

        Invalid pairs, and pairs that would push the elastic value over the
        length limit, are ignored (logged at debug level); the previous state
        is left untouched.
        """
        key = str(key)
        val = str(val)
        for bad in (":", ";", ",", "="):
            if bad in key or bad in val:
                logger.debug("New tracestate key/val pair contains invalid character '%s', ignoring.", bad)
                return
        for c in itertools.chain(key, val):
            # Tracestate spec only allows for characters between ASCII 0x20 and 0x7E
            if ord(c) < 0x20 or ord(c) > 0x7E:
                logger.debug("Modifications to TraceState would introduce invalid character '%s', ignoring.", c)
                return
        snapshot = dict(self.tracestate_dict)
        # pop + re-insert moves an updated key to the end of the serialized value
        self.tracestate_dict.pop(key, None)
        self.tracestate_dict[key] = val
        try:
            self.tracestate = self._set_tracestate()
        except TraceStateFormatException:
            # Restore the exact previous contents (and key order) in place.
            self.tracestate_dict.clear()
            self.tracestate_dict.update(snapshot)
class TracingOptions_bits(ctypes.LittleEndianStructure):
    """
    Bit-level view of the one-byte trace flags.

    Being little-endian, the first (and only) 1-bit field occupies the
    least significant bit: ``recorded``, the sampled flag.
    """

    _fields_ = [
        ("recorded", ctypes.c_uint8, 1),
    ]
class TracingOptions(ctypes.Union):
    """
    The one-byte ``trace-flags`` field of a traceparent.

    The raw byte is available as ``asByte``. The individual flag bits
    (currently only ``recorded``) are reachable directly through the
    anonymous ``bit`` structure, which overlays the same byte.
    """

    _anonymous_ = ("bit",)
    _fields_ = [("bit", TracingOptions_bits), ("asByte", ctypes.c_uint8)]

    def __init__(self, **kwargs) -> None:
        """Create all-zero flags, then set each given field (e.g. ``recorded=True``)."""
        super(TracingOptions, self).__init__()
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __eq__(self, other):
        # Flags are equal when their raw bytes match. For unrelated types, defer
        # to Python's default comparison instead of raising AttributeError
        # (e.g. ``options == None`` is now simply False).
        if not isinstance(other, TracingOptions):
            return NotImplemented
        return self.asByte == other.asByte

    def __repr__(self) -> str:
        return "TracingOptions(asByte=0x{:02x}, recorded={})".format(self.asByte, self.recorded)
def trace_parent_from_string(traceparent_string, tracestate_string=None, is_legacy=False):
    """
    Public-API entry point for building a TraceParent from a ``traceparent``
    header value and an optional ``tracestate`` value.

    Thin wrapper around ``TraceParent.from_string``; returns None when the
    value cannot be parsed.
    """
    parse = TraceParent.from_string
    return parse(
        traceparent_string,
        tracestate_string=tracestate_string,
        is_legacy=is_legacy,
    )
def trace_parent_from_headers(headers):
    """
    Public-API entry point for building a TraceParent from a collection of
    HTTP headers, using the default traceparent/tracestate header names.

    Thin wrapper around ``TraceParent.from_headers``.
    """
    parsed = TraceParent.from_headers(headers)
    return parsed
class TraceStateFormatException(Exception):
    """Signals that an updated tracestate value would violate the length limit."""