aws_advanced_python_wrapper/utils/telemetry/open_telemetry.py (129 lines of code) (raw):

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""OpenTelemetry-backed implementations of the wrapper's telemetry interfaces."""

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union

if TYPE_CHECKING:
    from opentelemetry.util.types import AttributeValue

from opentelemetry import context as context_api
from opentelemetry import trace
from opentelemetry import trace as trace_api
from opentelemetry.metrics import (CallbackOptions, Meter, Observation,
                                   get_meter)
from opentelemetry.sdk.trace import ReadableSpan, Span, StatusCode, Tracer

from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.telemetry.telemetry import (
    TelemetryConst, TelemetryContext, TelemetryCounter, TelemetryFactory,
    TelemetryGauge, TelemetryTraceLevel)

logger = Logger(__name__)

INSTRUMENTATION_NAME = "aws-advanced-python-wrapper"


class OpenTelemetryContext(TelemetryContext):
    """Telemetry context that wraps a single OpenTelemetry span.

    Unless the trace level is ``NO_TRACE``, constructing the context starts a
    span and attaches it to the current OpenTelemetry context;
    :meth:`close_context` detaches it again and ends the span.
    """

    def __init__(self,
                 tracer: Tracer,
                 name: str,
                 trace_level: TelemetryTraceLevel,
                 start_time: Optional[int] = None,
                 link_span: Optional[Span] = None):
        """Start a span according to ``trace_level``.

        :param tracer: tracer used to start the span.
        :param name: span name; also returned by :meth:`get_name`.
        :param trace_level: requested level. ``NESTED`` is promoted to
            ``TOP_LEVEL`` when there is no valid current span.
        :param start_time: optional start time forwarded to ``start_span``.
        :param link_span: optional span to link to. For top-level spans it
            takes precedence over linking to the current span.
        """
        self._name = name
        self._tracer = tracer
        # Defined up front so that every path (including an unexpected trace
        # level) leaves the attributes in a usable "no span" state.
        self._span: Optional[Span] = None
        self._token: Optional[object] = None

        current_span: Span = trace.get_current_span()  # type: ignore
        is_root = (current_span is None or current_span == trace.INVALID_SPAN)
        if is_root and trace_level == TelemetryTraceLevel.NESTED:
            # Nothing to nest under, so the span becomes a top-level one.
            trace_level = TelemetryTraceLevel.TOP_LEVEL

        links: Sequence[trace_api.Link] = ()
        if trace_level in [TelemetryTraceLevel.FORCE_TOP_LEVEL, TelemetryTraceLevel.TOP_LEVEL]:
            if link_span is not None:
                links = [trace.Link(link_span.get_span_context())]
            elif not is_root:
                links = [trace.Link(current_span.get_span_context())]

            # A fresh, empty Context gives the new span no parent.
            self._span = self._tracer.start_span(
                self._name, context=context_api.Context(), links=links, start_time=start_time)  # type: ignore
            if not is_root:
                self.set_attribute(TelemetryConst.TRACE_NAME_ANNOTATION, self._name)

            ctx = trace.set_span_in_context(self._span)  # type: ignore
            self._token = context_api.attach(ctx)
            logger.debug("OpenTelemetryContext.TelemetryTraceID", self._name, self._span.get_span_context().trace_id)  # type: ignore
        elif trace_level == TelemetryTraceLevel.NESTED:
            if link_span is not None:
                links = [trace.Link(link_span.get_span_context())]

            # No explicit context: the span is started in the current context.
            self._span = self._tracer.start_span(self._name, links=links, start_time=start_time)  # type: ignore
            ctx = trace.set_span_in_context(self._span)  # type: ignore
            self._token = context_api.attach(ctx)
            self.set_attribute(TelemetryConst.TRACE_NAME_ANNOTATION, self._name)
        # TelemetryTraceLevel.NO_TRACE (or any other level): no span is started
        # and self._span stays None.

    def set_success(self, success: bool):
        """Set the span status to ``OK`` or ``ERROR``; no-op without a span."""
        if self._span is not None:
            self._span.set_status(StatusCode.OK if success else StatusCode.ERROR)

    def set_attribute(self, key: str, value: AttributeValue):
        """Set an attribute on the span; no-op without a span."""
        if self._span is not None:
            self._span.set_attribute(key, value)

    def set_exception(self, exception: Exception):
        """Record ``exception`` on the span (type, message and exception event).

        No-op when there is no span or ``exception`` is ``None``.
        """
        if self._span is not None and exception is not None:
            self._span.set_attribute(TelemetryConst.EXCEPTION_TYPE_ANNOTATION, exception.__class__.__name__)
            self._span.set_attribute(TelemetryConst.EXCEPTION_MESSAGE_ANNOTATION, str(exception))
            self._span.record_exception(exception)

    def get_name(self) -> str:
        """Return the name this context was created with."""
        return self._name

    def close_context(self):
        """Detach the span from the current context and end it."""
        self._detach_token()
        # The span reference is kept: post_copy() reads the ended span's
        # attributes, status and timestamps through the ``span`` property.
        if self._span is not None:
            self._span.end()

    def _detach_token(self) -> None:
        """Detach the attached context token, at most once."""
        token, self._token = self._token, None
        if token is not None:
            context_api.detach(token)

    @property
    def tracer(self) -> Tracer:
        """Tracer this context was created with."""
        return self._tracer

    @property
    def span(self) -> Optional[Span]:
        """Underlying span, or ``None`` when no span was started."""
        return self._span


def post_copy(context: OpenTelemetryContext, trace_level: TelemetryTraceLevel):
    """Emit a copy of ``context``'s span at ``trace_level``.

    Does nothing when ``trace_level`` is ``NO_TRACE``.

    :raises RuntimeError: if ``context.span`` is not a ``ReadableSpan``.
    """
    if trace_level == TelemetryTraceLevel.NO_TRACE:
        return

    _clone_and_close_context(context, trace_level)


def _clone_and_close_context(context: OpenTelemetryContext, trace_level: TelemetryTraceLevel) -> OpenTelemetryContext:
    """Create, populate and end a copy of ``context``'s span.

    The copy is named with ``TelemetryConst.COPY_TRACE_NAME_PREFIX``, linked to
    the source span, given the source span's start time, attributes and status,
    annotated with the source trace id, and ended at the source span's end time.

    :param context: context whose span is copied.
    :param trace_level: trace level used to create the copy.
    :return: the (already ended) copy.
    :raises RuntimeError: if ``context.span`` is not a ``ReadableSpan``.
    """
    if not isinstance(context.span, ReadableSpan):
        raise RuntimeError(Messages.get("OpenTelemetry.InvalidContext"))

    clone = OpenTelemetryContext(
        context.tracer,
        TelemetryConst.COPY_TRACE_NAME_PREFIX + context.get_name(),
        trace_level,
        context.span.start_time,
        context.span)

    for key, value in context.span.attributes.items():  # type: ignore
        clone.set_attribute(key, value)
    clone.span.set_status(context.span.status)  # type: ignore
    clone.set_attribute(TelemetryConst.SOURCE_TRACE_ANNOTATION, str(context.span.get_span_context().trace_id))

    # End explicitly (rather than via close_context) to reuse the source span's
    # end time, then undo the attach done by the constructor so the finished
    # copy does not remain the current span for later telemetry.
    clone.span.end(context.span.end_time)  # type: ignore
    clone._detach_token()

    return clone


class OpenTelemetryCounter(TelemetryCounter):
    """Counter backed by an OpenTelemetry up-down counter instrument."""

    def __init__(self, meter: Meter, name: str):
        """Create the up-down counter ``name`` (unit ``"1"``) on ``meter``."""
        self._meter: Meter = meter
        self._name: str = name
        self._counter = meter.create_up_down_counter(self._name, unit="1")

    def add(self, value):
        """Add ``value`` to the counter."""
        self._counter.add(value)

    def inc(self):
        """Add one to the counter."""
        self._counter.add(1)

    def get_name(self):
        """Return the instrument name."""
        return self._name


class OpenTelemetryGauge(TelemetryGauge):
    """Gauge backed by an OpenTelemetry observable up-down counter."""

    def __init__(self, meter: Meter, name: str, callback: Callable[[], Union[float, int]]):
        """Register an observable instrument that reports ``callback()``.

        :param meter: meter used to create the instrument.
        :param name: instrument name.
        :param callback: zero-argument callable returning the current value.
        """
        self._meter: Meter = meter
        self.name: str = name
        # Store the callback before registering the instrument: once
        # registered, _callback_observation may be invoked and reads
        # self._callback.
        self._callback: Callable[[], Union[float, int]] = callback
        self._counter = meter.create_observable_up_down_counter(
            name, callbacks=[self._callback_observation], unit="1")

    def get_name(self):
        """Return the instrument name."""
        return self.name

    def _callback_observation(self, options: CallbackOptions):
        """Yield a single ``Observation`` holding the callback's current value."""
        value: Union[int, float] = self._callback()
        yield Observation(value)


class OpenTelemetryFactory(TelemetryFactory):
    """Factory producing OpenTelemetry-backed contexts, counters and gauges."""

    def open_telemetry_context(self, name: str, trace_level: TelemetryTraceLevel) -> TelemetryContext:
        """Open a context named ``name`` using the ``INSTRUMENTATION_NAME`` tracer."""
        return OpenTelemetryContext(trace.get_tracer(INSTRUMENTATION_NAME), name, trace_level)  # type: ignore

    def post_copy(self, context: TelemetryContext, trace_level: TelemetryTraceLevel):
        """Emit a copy of ``context`` at ``trace_level``.

        :raises RuntimeError: if ``context`` is not an ``OpenTelemetryContext``.
        """
        if not isinstance(context, OpenTelemetryContext):
            raise RuntimeError(Messages.get_formatted("OpenTelemetryFactory.WrongParameterType", type(context)))

        post_copy(context, trace_level)

    def create_counter(self, name: str) -> TelemetryCounter:
        """Create a counter on the ``INSTRUMENTATION_NAME`` meter."""
        return OpenTelemetryCounter(get_meter(INSTRUMENTATION_NAME), name)

    def create_gauge(self, name: str, callback: Callable[[], Union[float, int]]):
        """Create a gauge on the ``INSTRUMENTATION_NAME`` meter reporting ``callback()``."""
        return OpenTelemetryGauge(get_meter(INSTRUMENTATION_NAME), name, callback)