aws_advanced_python_wrapper/utils/telemetry/xray_telemetry.py (82 lines of code) (raw):

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING from aws_xray_sdk.core import xray_recorder if TYPE_CHECKING: from opentelemetry.util.types import AttributeValue from aws_xray_sdk.core.models.segment import Segment from aws_xray_sdk.core.models.subsegment import Subsegment from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.telemetry.telemetry import ( TelemetryConst, TelemetryContext, TelemetryCounter, TelemetryFactory, TelemetryTraceLevel) logger = Logger(__name__) class XRayTelemetryContext(TelemetryContext): def __init__(self, name: str, trace_level: TelemetryTraceLevel): self._name: str = name self._trace_entity: Segment | Subsegment if trace_level in [TelemetryTraceLevel.FORCE_TOP_LEVEL, TelemetryTraceLevel.TOP_LEVEL]: self._trace_entity = xray_recorder.begin_segment(self._name) self.is_segment = True logger.debug("XRayTelemetryContext.TraceID", self._name, self._trace_entity.trace_id) elif trace_level == TelemetryTraceLevel.NESTED: self._trace_entity = xray_recorder.begin_subsegment(self._name) # type: ignore self.set_attribute(TelemetryConst.TRACE_NAME_ANNOTATION, self._name) self.is_segment = False elif trace_level == TelemetryTraceLevel.NO_TRACE: pass def set_success(self, success: bool): if self._trace_entity is not None: self._trace_entity.error = not success def set_attribute(self, key: str, value: AttributeValue): if self._trace_entity is not None: self._trace_entity.put_annotation(key, value) def set_exception(self, exception: Exception): if self._trace_entity is not None and exception is not None: self._trace_entity.put_annotation("exception_type", exception.__class__.__name__) self._trace_entity.put_annotation("exception_message", str(exception)) def get_name(self): return self._name def close_context(self, end_time=None): if self._trace_entity is not None: if self.is_segment: xray_recorder.end_segment(end_time) else: xray_recorder.end_subsegment(end_time) def post_copy(context: XRayTelemetryContext, trace_level: TelemetryTraceLevel): if trace_level == TelemetryTraceLevel.NO_TRACE: return if trace_level in [TelemetryTraceLevel.FORCE_TOP_LEVEL, TelemetryTraceLevel.TOP_LEVEL]: with ThreadPoolExecutor() as executor: future = executor.submit(_clone_and_close_context, context, trace_level) future.result() else: _clone_and_close_context(context, trace_level) def _clone_and_close_context(context: XRayTelemetryContext, trace_level: TelemetryTraceLevel) -> XRayTelemetryContext: clone = XRayTelemetryContext(TelemetryConst.COPY_TRACE_NAME_PREFIX + context.get_name(), trace_level) clone._trace_entity.start_time = context._trace_entity.start_time for key in context._trace_entity.annotations.items(): value = context._trace_entity.annotations[key] if key != TelemetryConst.TRACE_NAME_ANNOTATION and value is not None: clone.set_attribute(key, value) if context.is_segment and context._trace_entity.error: clone._trace_entity.add_error_flag() clone.set_attribute(TelemetryConst.SOURCE_TRACE_ANNOTATION, str(context._trace_entity.trace_id)) if context._trace_entity.parent_id is not None: if trace_level == TelemetryTraceLevel.NESTED: clone._trace_entity.parent_id = context._trace_entity.parent_id clone.close_context(context._trace_entity.end_time) return clone class XRayTelemetryFactory(TelemetryFactory): def open_telemetry_context(self, name: str, trace_level: TelemetryTraceLevel) -> TelemetryContext: return XRayTelemetryContext(name, trace_level) def post_copy(self, context: TelemetryContext, trace_level: TelemetryTraceLevel): if isinstance(context, XRayTelemetryContext): post_copy(context, trace_level) else: raise RuntimeError(Messages.get_formatted("XRayTelemetryFactory.WrongParameterType", type(context))) def create_counter(self, name: str) -> TelemetryCounter: raise RuntimeError(Messages.get_formatted("XRayTelemetryFactory.MetricsNotSupported")) def create_gauge(self, name: str, callback): raise RuntimeError(Messages.get_formatted("XRayTelemetryFactory.MetricsNotSupported"))