# debug.py
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to add more debug information to content streams."""
import asyncio
from collections.abc import AsyncIterable
import time
from absl import logging
from genai_processors import content_api
from genai_processors import processor
ProcessorPart = content_api.ProcessorPart
class TTFTSingleStream(processor.Processor):
  """Wraps a processor to provide performance messaging."""

  def __init__(self, message: str, p: processor.Processor):
    """Wraps a processor to provide TTFT performance messaging.

    Should only be used for processors that consume their entire input before
    producing output (such as non-streaming or unidirectional/single streaming
    model calls). The TTFT is estimated by first waiting until the input
    stream has been completely sent to the processor (`start` time is then
    set). When the processor outputs its first token, the duration from
    `start` is then reported.

    In a bidirectional streaming setup, the TTFT will not be reported at all.

    Args:
      message: header of the status chunk that will be returned. It is used to
        identify different calls to this function.
      p: processor for which we need to compute the TTFT.
    """
    self._message = message
    self._p = p
    self._start: float | None = None
    self._ttft: float | None = None
    self._model_call_event = asyncio.Event()
    self._model_call_event.clear()

  def model_call_event(self) -> asyncio.Event:
    """Returns an event that is set when the wrapped processor has all parts.

    The event is set when the wrapped processor has all the input parts and
    is about to start generating the output.

    The event starts in a cleared state when the first part of the input
    stream is yielded. It is also cleared at the end of the wrapped
    processor, when all the output parts have been yielded.

    Its default value is unset and this event is set only for a short time
    during the call.

    Returns:
      An event that is set when the model call is started, that is when all
      the input parts have been sent to the wrapped processor.
    """
    return self._model_call_event

  def _ttft_processor(self) -> processor.Processor:
    """Builds the chain `input-logger + wrapped processor + output-logger`."""

    @processor.processor_function
    async def log_on_close(
        content: AsyncIterable[ProcessorPart],
    ) -> AsyncIterable[ProcessorPart]:
      # Pass-through on the input side: once the input stream is exhausted,
      # the wrapped processor has received everything, so record the start
      # time and signal the model-call event.
      self._model_call_event.clear()
      async for part in content:
        yield part
      self._start = time.perf_counter()
      self._model_call_event.set()
      logging.info('ttft single stream start time: %s', self._start)

    @processor.processor_function
    async def log_on_first(
        content: AsyncIterable[ProcessorPart],
    ) -> AsyncIterable[ProcessorPart]:
      # Pass-through on the output side: when the first part arrives after
      # the start time has been recorded, report the elapsed TTFT as a
      # status part.
      first_part = True
      async for part in content:
        if first_part and self._start is not None:
          duration = time.perf_counter() - self._start
          self._ttft = duration
          # Build the status text locally: appending to `self._message`
          # (as with `+=`) would make the TTFT suffix accumulate across
          # successive calls to this processor.
          yield processor.status(
              ProcessorPart(f'{self._message} TTFT={duration:.2f} seconds')
          )
          first_part = False
        yield part

    return log_on_close + self._p + log_on_first

  def ttft(self) -> float | None:
    """Returns the TTFT of the wrapped processor.

    Returns:
      the TTFT of the wrapped processor or None if the processor has not been
      called yet.
    """
    return self._ttft

  async def call(
      self, content: AsyncIterable[ProcessorPart]
  ) -> AsyncIterable[ProcessorPart]:
    async for chunk in self._ttft_processor()(content):
      yield chunk
def debug_string(part: ProcessorPart) -> str:
  """Formats a part together with its role and substream name for logs."""
  role = part.role
  substream = part.substream_name
  return f'{part} role {role} substream {substream}'
def log_stream(message: str) -> processor.Processor:
  """Return a pass-through processor that logs every part of a stream.

  Args:
    message: prefix attached to every log line, identifying this stream.

  Returns:
    A processor that yields its input unchanged while logging each part at
    INFO level, followed by a final '<message>: done' marker.
  """

  @processor.processor_function
  async def p(
      content: AsyncIterable[ProcessorPart],
  ) -> AsyncIterable[ProcessorPart]:
    async for chunk in content:
      logging.info('%s: %s', message, debug_string(chunk))
      yield chunk
    logging.info('%s: done', message)

  return p
def print_stream(message: str) -> processor.Processor:
  """Return a pass-through processor that prints every part of a stream.

  Args:
    message: prefix attached to every printed line, identifying this stream.

  Returns:
    A processor that yields its input unchanged while printing each part,
    followed by a final '<message>: done' marker (mirroring `log_stream`).
  """

  @processor.processor_function
  async def p(
      content: AsyncIterable[ProcessorPart],
  ) -> AsyncIterable[ProcessorPart]:
    async for part in content:
      print(f'{message}: {debug_string(part)}')
      yield part
    # Print a completion marker, consistent with log_stream.
    print(f'{message}: done')

  return p
def log_queue(
    message: str, queue: asyncio.Queue[ProcessorPart | None]
) -> asyncio.Queue[ProcessorPart | None]:
  """Returns a new queue that mirrors `queue` while logging every part.

  A background task consumes `queue` (up to and including the terminating
  `None` sentinel), logs each part at INFO level, and forwards everything to
  the returned queue.

  Args:
    message: prefix attached to every log line, identifying this queue.
    queue: input queue of parts, terminated by a `None` sentinel.

  Returns:
    A new queue receiving every part from `queue`, including the final
    `None` sentinel.
  """
  output_queue: asyncio.Queue[ProcessorPart | None] = asyncio.Queue()

  async def log_and_output():
    while (part := await queue.get()) is not None:
      queue.task_done()
      logging.info('%s: %s', message, debug_string(part))
      output_queue.put_nowait(part)
    # Mark the terminating None sentinel as processed too; otherwise
    # queue.join() would never complete.
    queue.task_done()
    output_queue.put_nowait(None)

  processor.create_task(log_and_output())
  return output_queue