core/live_model.py

# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wraps the Gemini Live API into a Processor.

## Example usage

```py
p = LiveProcessor(
    api_key=API_KEY,
    model_name="gemini-2.0-flash-live-001",
)
```

Given a stream of content `input_stream`, you can use the processor to
generate a stream of content as follows:

```py
input_stream = processor.stream_content(
    [
        processor.ProcessorPart('hello'),
        processor.ProcessorPart('world'),
    ]
)
async for part in p(input_stream):
  ...  # Do something with part.
```

The Live processor only considers parts with the substream name "realtime"
(sent to the realtime input methods) or parts with the default substream name
(sent to the client content method) as input.
"""

import asyncio
from collections.abc import AsyncIterable
import re
import time
from typing import Iterable, Optional

from absl import logging
from genai_processors import content_api
from genai_processors import processor
from google.genai import client
from google.genai import types as genai_types

PartTypes = content_api.ProcessorPartTypes
ProcessorPart = content_api.ProcessorPart


def to_parts(
    msg: genai_types.LiveServerMessage,
) -> Iterable[content_api.ProcessorPart]:
  """Converts a LiveServerMessage to a stream of ProcessorParts."""
  if msg.server_content:
    metadata = msg.server_content.to_json_dict()
    if "model_turn" in metadata:
      del metadata["model_turn"]
    if msg.server_content.model_turn:
      for part in msg.server_content.model_turn.parts:
        yield content_api.ProcessorPart(
            value=part,
            role=msg.server_content.model_turn.role,
        )
    for k, v in metadata.items():
      value = ""
      if k in ("input_transcription", "output_transcription"):
        if "text" in v:
          value = v["text"]
        yield content_api.ProcessorPart(
            value=value,
            role="MODEL",
            substream_name=k,
        )
      else:
        yield content_api.ProcessorPart(
            value="",
            role="MODEL",
            metadata={k: v},
        )
  if msg.tool_call:
    function_calls = msg.tool_call.function_calls
    for function_call in function_calls:
      yield content_api.ProcessorPart.from_function_call(
          name=function_call.name,
          args=function_call.args,
          role="MODEL",
          metadata={"id": function_call.id},
      )
  if msg.tool_call_cancellation and msg.tool_call_cancellation.ids:
    for function_call_id in msg.tool_call_cancellation.ids:
      yield content_api.ProcessorPart.from_tool_cancellation(
          function_call_id=function_call_id,
      )
  if msg.usage_metadata:
    yield content_api.ProcessorPart(
        value="",
        role="MODEL",
        metadata={"usage_metadata": msg.usage_metadata.to_json_dict()},
    )
  if msg.go_away:
    yield content_api.ProcessorPart(
        value="",
        role="MODEL",
        metadata={"go_away": msg.go_away.to_json_dict()},
    )
  if msg.session_resumption_update:
    yield content_api.ProcessorPart(
        value="",
        role="MODEL",
        metadata={
            "session_resumption_update": (
                msg.session_resumption_update.to_json_dict()
            )
        },
    )
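

# A minimal sketch of what `to_parts` yields for a text-only model turn.
# Illustrative only: the `LiveServerMessage` below is hand-built from the
# `google.genai` types imported above.
#
#   msg = genai_types.LiveServerMessage(
#       server_content=genai_types.LiveServerContent(
#           model_turn=genai_types.Content(
#               parts=[genai_types.Part(text="hi")], role="model"
#           ),
#           turn_complete=True,
#       )
#   )
#   parts = list(to_parts(msg))
#   # -> a ProcessorPart with text "hi" and role "model", followed by a
#   #    metadata-only part carrying {"turn_complete": True}.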


class LiveProcessor(processor.Processor):
  """Gemini Live API Processor to generate realtime content.

  The realtime content captured via mic and camera should be passed to the
  processor with the `realtime` substream name. The default substream is used
  for standard content input.

  An image sent on the default substream will be processed by the model as an
  ad-hoc user input, not as a realtime input captured from realtime devices.
  This lets the user send an image to the model and ask a question about it,
  for example "What is this?", independently of the video stream being sent
  to the model on the `realtime` substream.
  """

  def __init__(
      self,
      api_key: str,
      model_name: str,
      realtime_config: Optional[genai_types.LiveConnectConfigOrDict] = None,
      debug_config: client.DebugConfig | None = None,
      http_options: (
          genai_types.HttpOptions | genai_types.HttpOptionsDict | None
      ) = None,
  ):
    """Initializes the Live Processor.

    The resulting `Processor` calls the Genai API in a realtime (aka live)
    fashion.

    Args:
      api_key: The [API key](https://ai.google.dev/gemini-api/docs/api-key)
        to use for authentication. Applies to the Gemini Developer API only.
      model_name: The name of the model to use. See
        https://ai.google.dev/gemini-api/docs/models for a list of available
        models. Only use models with `-live-` in their name.
      realtime_config: The configuration for generating realtime content.
      debug_config: Config settings that control network behavior of the
        client. This is typically used when running test code.
      http_options: Http options to use for the client. These options will be
        applied to all requests made by the client. Example usage:
        `client = genai.Client(http_options=types.HttpOptions(api_version='v1'))`.
    """
    self._client = client.Client(
        api_key=api_key,
        debug_config=debug_config,
        http_options=http_options,
    )
    self._model_name = model_name
    self._realtime_config = realtime_config

  async def call(
      self, content: AsyncIterable[ProcessorPart]
  ) -> AsyncIterable[ProcessorPart]:
    output_queue = asyncio.Queue[Optional[ProcessorPart]](maxsize=1_000)
    async with self._client.aio.live.connect(
        model=self._model_name,
        config=self._realtime_config,
    ) as session:

      async def consume_content():
        async for chunk_part in content:
          if chunk_part.part.function_response:
            logging.debug(
                "%s - Live Processor: sending tool response: %s",
                time.perf_counter(),
                chunk_part,
            )
            await session.send_tool_response(
                function_responses=chunk_part.part.function_response
            )
          elif (
              chunk_part.substream_name == "realtime"
              and chunk_part.get_metadata("audio_stream_end")
          ):
            logging.debug(
                "%s - Live Processor: sending realtime audio_stream_end",
                time.perf_counter(),
            )
            await session.send_realtime_input(audio_stream_end=True)
          elif (
              chunk_part.substream_name == "realtime"
              and chunk_part.part.inline_data
          ):
            await session.send_realtime_input(
                media=chunk_part.part.inline_data
            )
          elif chunk_part.substream_name == "realtime" and content_api.is_text(
              chunk_part.mimetype
          ):
            logging.debug(
                "%s - Live Processor: sending realtime input: %s",
                time.perf_counter(),
                chunk_part.text,
            )
            await session.send_realtime_input(text=chunk_part.text)
          elif not chunk_part.substream_name:  # Default substream.
            logging.debug(
                "%s - Live Processor: sending client content: %s",
                time.perf_counter(),
                chunk_part.part,
            )
            turn_complete = chunk_part.get_metadata("turn_complete")
            await session.send_client_content(
                turns=genai_types.Content(
                    parts=[chunk_part.part], role=chunk_part.role
                ),
                turn_complete=(
                    True if turn_complete is None else turn_complete
                ),
            )
          else:
            logging.debug(
                "%s - Live Processor: part passed through: %s",
                time.perf_counter(),
                chunk_part,
            )
            await output_queue.put(chunk_part)
        await output_queue.put(None)
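
      # `consume_content` above forwards the input stream into the live
      # session; `produce_content` below converts every server message back
      # into ProcessorParts on `output_queue`. Each coroutine enqueues a
      # final `None` sentinel so the yield loop at the end of `call` stops
      # as soon as either direction finishes.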

      async def produce_content():
        try:
          while True:
            async for response in session.receive():
              if not (
                  response.server_content
                  and response.server_content.model_turn
                  and response.server_content.model_turn.parts
                  and response.server_content.model_turn.parts[0].inline_data
              ):
                logging.debug(
                    "%s - Live Processor Response: %s",
                    time.perf_counter(),
                    # Remove the None values from the response.
                    re.sub(r"(,\s)?[^\(\s]+=None,?\s?", "", str(response)),
                )
              for part in to_parts(response):
                await output_queue.put(part)
            # Allow `yield` if session.receive() does not return anything.
            await asyncio.sleep(0)
        finally:
          await output_queue.put(None)

      consume_content_task = processor.create_task(consume_content())
      produce_content_task = processor.create_task(produce_content())
      while chunk := await output_queue.get():
        yield chunk
      consume_content_task.cancel()
      produce_content_task.cancel()
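

# A minimal end-to-end sketch, not part of the library API: it wires the
# docstring example into a runnable entry point. Reading the key from the
# GOOGLE_API_KEY environment variable is an assumption; the model name
# mirrors the module docstring.
if __name__ == "__main__":
  import os

  async def _demo():
    p = LiveProcessor(
        api_key=os.environ["GOOGLE_API_KEY"],
        model_name="gemini-2.0-flash-live-001",
    )
    input_stream = processor.stream_content(
        [ProcessorPart("hello"), ProcessorPart("world")]
    )
    async for part in p(input_stream):
      print(part)

  asyncio.run(_demo())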