core/rate_limit_audio.py (101 lines of code) (raw):

# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Rate limiter for audio output."""

import asyncio
# NOTE(review): the stdlib `logging` import below is shadowed by the absl
# import further down; it is kept only so that no existing import is dropped.
import logging
import math
import time
from typing import AsyncIterable, Iterable, Optional

from absl import logging
from genai_processors import content_api
from genai_processors import context as context_lib
from genai_processors import processor

ProcessorPart = content_api.ProcessorPart

# Maximum audio chunk/part duration in seconds.
MAX_AUDIO_PART_SEC = 0.05

# Buffer in seconds to avoid the audio being cut off.
IN_FLIGHT_AUDIO_BUFFER_SEC = 0.05

# Audio is handled as 16-bit samples throughout this module.
_BYTES_PER_SAMPLE = 2


def _audio_duration(audio_data: bytes, sample_rate: int) -> float:
  """Returns the duration of the audio data in seconds.

  Args:
    audio_data: Raw audio bytes, 2 bytes per sample (16 bits).
    sample_rate: Number of samples per second.

  Returns:
    The duration in seconds, i.e. the number of samples over the sample rate.
  """
  return len(audio_data) / (_BYTES_PER_SAMPLE * sample_rate)


def _inline_audio_data(part: ProcessorPart) -> Optional[bytes]:
  """Returns the inline data bytes of `part`, or None if it has none."""
  inline_data = part.part.inline_data
  if inline_data is None:
    return None
  return inline_data.data


def split_audio(
    audio_data: bytes,
    sample_rate: int,
    max_duration_sec: float = MAX_AUDIO_PART_SEC,
) -> Iterable[bytes]:
  """Splits audio data into chunks of at most max_duration_sec.

  Chunks always contain a whole number of 16-bit samples (except possibly the
  last one, if `audio_data` itself has an odd length), so a sample is never
  cut in half at a chunk boundary.

  Args:
    audio_data: Raw audio bytes, 2 bytes per sample (16 bits).
    sample_rate: Number of samples per second. Must be positive.
    max_duration_sec: Maximum duration of a chunk in seconds. Must be
      positive. Durations shorter than one sample yield one-sample chunks.

  Yields:
    Consecutive, non-empty slices of `audio_data`.

  Raises:
    ValueError: If `sample_rate` or `max_duration_sec` is not positive. As this
      is a generator, the error is raised when iteration starts.
  """
  if sample_rate <= 0 or max_duration_sec <= 0:
    raise ValueError(
        "sample_rate and max_duration_sec must be positive, got"
        f" sample_rate={sample_rate}, max_duration_sec={max_duration_sec}"
    )
  # Round down to whole samples first and only then convert to bytes, so the
  # chunk size is always a multiple of the sample size and never zero.
  samples_per_chunk = max(1, math.floor(max_duration_sec * sample_rate))
  chunk_bytes = samples_per_chunk * _BYTES_PER_SAMPLE
  for start in range(0, len(audio_data), chunk_bytes):
    yield audio_data[start : start + chunk_bytes]


class RateLimitAudio(processor.Processor):
  """Splits and rate-limits the input audio parts for streaming audio output.

  Gemini API clients are expected to play streaming audio content to the user
  in its natural playback speed. As all audio parts are streamed at once, the
  client needs to stop playing back the audio when the user interrupts it.

  This processor does three things to address that:

   * Inline audio parts longer than `2 * MAX_AUDIO_PART_SEC` are split into
     sub-parts of no more than `MAX_AUDIO_PART_SEC`. (Audio parts without
     inline data are left alone, and count as "other parts" for the purposes
     of this processor.)
   * Parts are yielded from this processor at the rate of their natural
     playing speed, to put a reasonably tight limit on the amount of audio
     buffered beyond the agent's control.
   * Other parts are passed through unchanged. Debug/status parts are passed
     through as soon as possible, overtaking audio if needed.
  """

  def __init__(self, sample_rate: int, delay_other_parts: bool = True):
    """Initializes the rate limiter.

    Args:
      sample_rate: The sample rate of the audio. A typical value is 24000
        (24KHz)
      delay_other_parts: If true, other parts will be delayed until the audio
        is played out. If false, other parts will be passed through as soon as
        possible, overtaking audio if needed.
    """
    self._sample_rate = sample_rate
    self._delay_other_parts = delay_other_parts

  async def call(
      self, content: AsyncIterable[ProcessorPart]
  ) -> AsyncIterable[ProcessorPart]:
    """Rate limits audio output.

    Two helper tasks are run: one reads `content` and fills an internal audio
    queue (splitting long audio on the way), the other drains that queue at
    playback speed into a small output queue which this generator yields from.

    Args:
      content: The input stream of parts.

    Yields:
      The input parts, with long inline audio split into sub-parts and audio
      paced at its playback speed.
    """
    # Most inputs queue here. When full, the fast-tracking of status/debug
    # chunks starts to block, so let's be generous with the queue size.
    audio_queue = asyncio.Queue[Optional[ProcessorPart]](10_000)
    # Delays in outputting from this queue distort the time estimations for
    # audio sub-chunks, so let's bound its size tightly.
    output_queue = asyncio.Queue[Optional[ProcessorPart]](3)

    async def consume_content():
      """Reads the input and routes each part to the right queue."""
      async for part in content:
        if content_api.is_audio(part.mimetype):
          # Split the audio into small parts so that when we interrupt between
          # them, we don't have to wait too long before interrupting.
          if (
              part.part.inline_data is not None
              and _audio_duration(part.bytes, self._sample_rate)
              > 2 * MAX_AUDIO_PART_SEC
          ):
            for sub_part in split_audio(
                part.part.inline_data.data, self._sample_rate
            ):
              # NOTE(review): sub-parts only carry over the mimetype of the
              # original part - confirm nothing else needs to be preserved.
              # `put` only blocks when the queue is full; `put_nowait` would
              # raise QueueFull on a very long audio part instead.
              await audio_queue.put(
                  ProcessorPart(sub_part, mimetype=part.mimetype)
              )
          else:
            await audio_queue.put(part)
        elif part.get_metadata("interrupted"):
          logging.debug(
              "%s - Interrupted - flush audio queue", time.perf_counter()
          )
          # Flush the audio queue - stop rate limiting audio asap.
          while not audio_queue.empty():
            audio_queue.get_nowait()
          # The queue was just emptied, so this cannot overflow.
          audio_queue.put_nowait(part)
        elif (
            not self._delay_other_parts
            or part.substream_name in context_lib.get_reserved_substreams()
        ):
          await output_queue.put(part)
          await asyncio.sleep(0)  # Allow `yield` from output_queue to run
        else:
          await audio_queue.put(part)
      # End-of-stream marker for consume_audio.
      await audio_queue.put(None)

    async def consume_audio():
      """Moves parts to the output queue, pacing audio at playback speed."""
      # Estimated time at which the next audio part would start playing.
      start_playing_time = self._perf_counter() - 3600  # 1h back.
      # Compare against the None sentinel explicitly rather than relying on
      # the truthiness of a part.
      while (part := await audio_queue.get()) is not None:
        audio_data = (
            _inline_audio_data(part)
            if content_api.is_audio(part.mimetype)
            else None
        )
        if audio_data is not None:
          # Let audio be emitted up to IN_FLIGHT_AUDIO_BUFFER_SEC ahead of its
          # estimated playback time to avoid the audio being cut off.
          start_playing_time = max(
              self._perf_counter() - IN_FLIGHT_AUDIO_BUFFER_SEC,
              start_playing_time,
          )
          sleep_sec = max(0, start_playing_time - self._perf_counter())
          if sleep_sec > 1e-3:
            await self._asyncio_sleep(sleep_sec)
          await output_queue.put(part)
          await asyncio.sleep(0)  # Allow `yield` from output_queue to run
          start_playing_time += _audio_duration(audio_data, self._sample_rate)
        else:
          # Wait for the audio to be played out before passing on to the next
          # non-audio part. Audio parts without inline data have no known
          # duration and are handled here as well.
          await self._asyncio_sleep(
              max(0, start_playing_time - self._perf_counter())
          )
          await output_queue.put(part)
      # End-of-stream marker for the output loop below.
      await output_queue.put(None)

    consume_audio_task = processor.create_task(consume_audio())
    consume_content_task = processor.create_task(consume_content())

    try:
      while (part := await output_queue.get()) is not None:
        yield part
    finally:
      # Also runs when the consumer stops iterating early or an error is
      # raised, so the helper tasks are never left running.
      consume_content_task.cancel()
      consume_audio_task.cancel()

  # The following wrappers allow unit-tests to mock out walltime.
  def _perf_counter(self):
    """Returns the current value of the performance counter in seconds."""
    return time.perf_counter()

  async def _asyncio_sleep(self, delay: float) -> None:
    """Sleeps for `delay` seconds without blocking the event loop."""
    await asyncio.sleep(delay)