azext_edge/edge/providers/stats.py (194 lines of code) (raw):

# coding=utf-8 # ---------------------------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License file in the project root for license information. # ---------------------------------------------------------------------------------------------- import binascii import json from datetime import datetime from typing import TYPE_CHECKING, List, Optional, Tuple, Union from azure.cli.core.azclierror import ResourceNotFoundError from knack.log import get_logger from rich.console import Console from ..common import AIO_BROKER_DIAGNOSTICS_SERVICE, PROTOBUF_SERVICE_API_PORT, PodState from ..util import get_timestamp_now_utc from .base import V1Pod, get_namespaced_pods_by_prefix, portforward_socket logger = get_logger(__name__) console = Console(highlight=True) if TYPE_CHECKING: # pylint: disable=no-name-in-module from socket import socket from zipfile import ZipInfo from opentelemetry.proto.trace.v1.trace_pb2 import TracesData def _preprocess_stats( namespace: Optional[str] = None, diag_service_pod_prefix: str = AIO_BROKER_DIAGNOSTICS_SERVICE ) -> Tuple[str, V1Pod]: if not namespace: from .base import DEFAULT_NAMESPACE namespace = DEFAULT_NAMESPACE target_pods = get_namespaced_pods_by_prefix(prefix=diag_service_pod_prefix, namespace=namespace) if not target_pods: raise ResourceNotFoundError( f"Diagnostics service pod '{diag_service_pod_prefix}' does not exist in namespace '{namespace}'." ) for pod in target_pods: if pod.status.phase.lower() == PodState.running.value: return namespace, pod raise ResourceNotFoundError( f"No diagnostics service pod '{diag_service_pod_prefix}' in phase " f"'{PodState.running.value}' detected in namespace '{namespace}'." ) def get_traces( namespace: Optional[str] = None, diag_service_pod_prefix: str = AIO_BROKER_DIAGNOSTICS_SERVICE, pod_protobuf_port: int = PROTOBUF_SERVICE_API_PORT, trace_ids: Optional[List[str]] = None, trace_dir: Optional[str] = None, ) -> Union[List["TracesData"], List[Tuple["ZipInfo", str]], None]: """ trace_ids: List[str] hex representation of trace Ids. """ if not any([trace_ids, trace_dir]): raise ValueError("At least trace_ids or trace_dir is required.") from zipfile import ZIP_DEFLATED, ZipFile, ZipInfo from google.protobuf.json_format import MessageToDict from rich.progress import MofNCompleteColumn, Progress from ..util import normalize_dir # pylint: disable=no-name-in-module from .proto.diagnostics_service_pb2 import Request, Response, TraceRetrievalInfo namespace, diagnostic_pod = _preprocess_stats(namespace=namespace, diag_service_pod_prefix=diag_service_pod_prefix) for_support_bundle = False trace_ids = trace_ids or [] if trace_ids: if trace_ids[0] == "!support_bundle!": trace_ids.pop() for_support_bundle = True trace_ids = [binascii.unhexlify(t) for t in trace_ids] with Progress( *Progress.get_default_columns(), MofNCompleteColumn(), transient=False, disable=bool(trace_ids) or for_support_bundle, ) as progress: with portforward_socket( namespace=namespace, pod_name=diagnostic_pod.metadata.name, pod_port=pod_protobuf_port ) as socket: request = Request(get_traces=TraceRetrievalInfo(trace_ids=trace_ids)) serialized_request = request.SerializeToString() request_len_b = len(serialized_request).to_bytes(4, byteorder="big") socket.sendall(request_len_b) socket.sendall(serialized_request) traces: List[dict] = [] if trace_dir: normalized_dir_path = normalize_dir(dir_path=trace_dir) normalized_dir_path = normalized_dir_path.joinpath( f"broker_traces_{get_timestamp_now_utc(format='%Y%m%dT%H%M%S')}.zip" ) # pylint: disable=consider-using-with myzip = ZipFile(file=str(normalized_dir_path), mode="w", compression=ZIP_DEFLATED) progress_set = False progress_task = None total_trace_count = 0 current_trace_count = 0 try: while True: if current_trace_count and current_trace_count >= total_trace_count: break rbytes = _fetch_bytes(socket, 4) response_size = int.from_bytes(rbytes, byteorder="big") response_bytes = _fetch_bytes(socket, response_size) if response_bytes == b"": logger.warning("TCP socket closed. Trace processing aborted.") return response = Response.FromString(response_bytes) current_trace_count = current_trace_count + 1 if not total_trace_count: total_trace_count = response.retrieved_trace.total_trace_count if total_trace_count == 0: logger.warning("No traces to fetch. Processing aborted.") break if not progress.disable and not progress_set: progress_task = progress.add_task( "[deep_sky_blue4]Gathering traces...", total=response.retrieved_trace.total_trace_count ) progress_set = True msg_dict = MessageToDict(message=response.retrieved_trace.trace, use_integers_for_enums=True) root_span, resource_name, timestamp = _determine_root_span(message_dict=msg_dict) if progress_set: progress.update(progress_task, advance=1) if not all([root_span, resource_name, timestamp]): logger.debug("Could not process root span. Skipping trace.") continue span_trace_id = root_span["traceId"] span_name = root_span["name"] if trace_ids: traces.append(msg_dict) if trace_dir or for_support_bundle: archive = f"{resource_name}.{span_name}.{span_trace_id}" pb_suffix = ".otlp.pb" tempo_suffix = ".tempo.json" datetime_tuple = tuple(timestamp.timetuple()) zinfo_pb = ZipInfo(filename=f"{archive}{pb_suffix}", date_time=datetime_tuple) # Fixed in Py 3.9 https://github.com/python/cpython/issues/70373 zinfo_pb.file_size = 0 zinfo_pb.compress_size = 0 zinfo_tempo = ZipInfo(filename=f"{archive}{tempo_suffix}", date_time=datetime_tuple) zinfo_tempo.file_size = 0 zinfo_tempo.compress_size = 0 otlp_format_pair = (zinfo_pb, response.retrieved_trace.trace.SerializeToString()) tempo_format_pair = ( zinfo_tempo, json.dumps(_convert_otlp_to_tempo(msg_dict), sort_keys=True), ) if for_support_bundle: traces.append(otlp_format_pair) traces.append(tempo_format_pair) continue # Original OTLP myzip.writestr( zinfo_or_arcname=otlp_format_pair[0], data=otlp_format_pair[1], ) # Tempo myzip.writestr( zinfo_or_arcname=tempo_format_pair[0], data=tempo_format_pair[1], ) if traces: return traces finally: if trace_dir: myzip.close() def _determine_root_span(message_dict: dict) -> Tuple[str, str, Union[datetime, None]]: """ Attempts to determine root span, and normalizes traceId, spanId and parentSpanId to hex. """ import base64 root_span = None resource_name = None timestamp = None for resource_span in message_dict.get("resourceSpans", []): for scope_span in resource_span.get("scopeSpans", []): for span in scope_span.get("spans", []): if "traceId" in span: span["traceId"] = base64.b64decode(span["traceId"]).hex() if "spanId" in span: span["spanId"] = base64.b64decode(span["spanId"]).hex() if "parentSpanId" in span: span["parentSpanId"] = base64.b64decode(span["parentSpanId"]).hex() else: root_span = span if "startTimeUnixNano" in root_span: timestamp_unix_nano = root_span["startTimeUnixNano"] timestamp = datetime.utcfromtimestamp(float(timestamp_unix_nano) / 1e9) # determine resource name resource = resource_span.get("resource", {}) attributes = resource.get("attributes", []) for a in attributes: if a["key"] == "service.name": resource_name = a["value"].get("stringValue", "unknown") return root_span, resource_name, timestamp def _convert_otlp_to_tempo(message_dict: dict) -> dict: """ Convert OTLP payload to Grafana Tempo. """ from copy import deepcopy new_dict = deepcopy(message_dict) new_dict["batches"] = new_dict.pop("resourceSpans") for batch in new_dict.get("batches", []): batch["instrumentationLibrarySpans"] = batch.pop("scopeSpans", []) for inst_lib_span in batch.get("instrumentationLibrarySpans", []): inst_lib_span["instrumentationLibrary"] = inst_lib_span.pop("scope", {}) return new_dict def _fetch_bytes(socket: "socket", size: int) -> bytes: result_bytes = socket.recv(size) if result_bytes == b"": return result_bytes result_bytes_len = len(result_bytes) while result_bytes_len < size: remaining_bytes_size = size - result_bytes_len interm_bytes = socket.recv(remaining_bytes_size) if interm_bytes == b"": break result_bytes += interm_bytes result_bytes_len += len(interm_bytes) return result_bytes