debug.py

# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to add more debug information to content streams."""

import asyncio
from collections.abc import AsyncIterable
import time

from absl import logging
from genai_processors import content_api
from genai_processors import processor

ProcessorPart = content_api.ProcessorPart


class TTFTSingleStream(processor.Processor):
  """Wraps a processor to provide performance messaging."""

  def __init__(self, message: str, p: processor.Processor):
    """Wraps a processor to provide performance messaging.

    Should only be used for processors that consume their entire input before
    producing output (such as non-streaming or unidirectional/single streaming
    model calls).

    The TTFT is estimated by first waiting until the input stream has been
    completely sent to the processor (the `start` time is then set). When the
    processor outputs its first token, the duration since `start` is reported.

    In a bidirectional streaming setup, the TTFT will not be reported at all.

    Args:
      message: header of the status chunk that will be returned. It is used to
        identify different calls to this wrapper.
      p: processor for which we need to compute the TTFT.
    """
    self._message = message
    self._p = p
    self._start = None
    self._ttft = None
    self._model_call_event = asyncio.Event()
    self._model_call_event.clear()

  def model_call_event(self) -> asyncio.Event:
    """Returns an event that is set when the wrapped processor has all parts.

    The event is set when the wrapped processor has all the input parts and is
    about to start generating the output. The event starts in a cleared state
    when the first part of the input stream is yielded. It is also cleared at
    the end of the wrapped processor, when all the output parts have been
    yielded. By default the event is unset; it is set only for a short time
    during the call.

    Returns:
      An event that is set when the model call is started, that is, when all
      the input parts have been sent to the wrapped processor.
    """
    return self._model_call_event

  def _ttft_processor(self) -> processor.Processor:

    @processor.processor_function
    async def log_on_close(
        content: AsyncIterable[ProcessorPart],
    ) -> AsyncIterable[ProcessorPart]:
      self._model_call_event.clear()
      async for part in content:
        yield part
      self._start = time.perf_counter()
      self._model_call_event.set()
      logging.info('ttft single stream start time: %s', self._start)

    @processor.processor_function
    async def log_on_first(
        content: AsyncIterable[ProcessorPart],
    ) -> AsyncIterable[ProcessorPart]:
      first_part = True
      async for part in content:
        if first_part and self._start is not None:
          duration = time.perf_counter() - self._start
          self._ttft = duration
          self._message += f' TTFT={duration:.2f} seconds'
          yield processor.status(ProcessorPart(self._message))
          first_part = False
        yield part

    return log_on_close + self._p + log_on_first

  def ttft(self) -> float | None:
    """Returns the TTFT of the wrapped processor.

    Returns:
      The TTFT of the wrapped processor, or None if the processor has not been
      called yet.
    """
    return self._ttft

  async def call(
      self, content: AsyncIterable[ProcessorPart]
  ) -> AsyncIterable[ProcessorPart]:
    async for chunk in self._ttft_processor()(content):
      yield chunk


def debug_string(part: ProcessorPart) -> str:
  """Returns a one-line, human-readable description of a part."""
  return f'{part} role {part.role} substream {part.substream_name}'


def log_stream(message: str) -> processor.Processor:
  """Returns a processor that logs every part of a stream."""

  @processor.processor_function
  async def p(
      content: AsyncIterable[ProcessorPart],
  ) -> AsyncIterable[ProcessorPart]:
    async for part in content:
      logging.info('%s: %s', message, debug_string(part))
      yield part
    logging.info('%s: done', message)

  return p


def print_stream(message: str) -> processor.Processor:
  """Returns a processor that prints every part of a stream."""

  @processor.processor_function
  async def p(
      content: AsyncIterable[ProcessorPart],
  ) -> AsyncIterable[ProcessorPart]:
    async for part in content:
      print(f'{message}: {debug_string(part)}')
      yield part

  return p


def log_queue(
    message: str, queue: asyncio.Queue[ProcessorPart | None]
) -> asyncio.Queue[ProcessorPart | None]:
  """Returns a queue that mirrors `queue` while logging every part."""
  output_queue = asyncio.Queue()

  async def log_and_output():
    while (part := await queue.get()) is not None:
      queue.task_done()
      logging.info('%s: %s', message, debug_string(part))
      output_queue.put_nowait(part)
    output_queue.put_nowait(None)

  processor.create_task(log_and_output())
  return output_queue