agents/live_api/app/server.py (120 lines of code) (raw):
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import logging
from collections.abc import Callable
from typing import Any, Literal
import backoff
from fastapi import FastAPI, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from google.cloud import logging as google_cloud_logging
from google.genai import types
from google.genai.types import LiveServerToolCall
from pydantic import BaseModel
from websockets.exceptions import ConnectionClosedError
from app.agent import MODEL_ID, genai_client, live_connect_config, tool_functions
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
logging_client = google_cloud_logging.Client()
logger = logging_client.logger(__name__)
logging.basicConfig(level=logging.INFO)
class GeminiSession:
"""Manages bidirectional communication between a client and the Gemini model."""
def __init__(
self, session: Any, websocket: WebSocket, tool_functions: dict[str, Callable]
) -> None:
"""Initialize the Gemini session.
Args:
session: The Gemini session
websocket: The client websocket connection
user_id: Unique identifier for this client
tool_functions: Dictionary of available tool functions
"""
self.session = session
self.websocket = websocket
self.run_id = "n/a"
self.user_id = "n/a"
self.tool_functions = tool_functions
async def receive_from_client(self) -> None:
"""Listen for and process messages from the client.
Continuously receives messages and forwards audio data to Gemini.
Handles connection errors gracefully.
"""
while True:
try:
data = await self.websocket.receive_json()
if isinstance(data, dict) and (
"realtimeInput" in data or "clientContent" in data
):
await self.session._ws.send(json.dumps(data))
elif "setup" in data:
self.run_id = data["setup"]["run_id"]
self.user_id = data["setup"]["user_id"]
logger.log_struct(
{**data["setup"], "type": "setup"}, severity="INFO"
)
else:
logging.warning(f"Received unexpected input from client: {data}")
except ConnectionClosedError as e:
logging.warning(f"Client {self.user_id} closed connection: {e}")
break
except Exception as e:
logging.error(f"Error receiving from client {self.user_id}: {e!s}")
break
def _get_func(self, action_label: str) -> Callable | None:
"""Get the tool function for a given action label."""
return None if action_label == "" else self.tool_functions.get(action_label)
async def _handle_tool_call(
self, session: Any, tool_call: LiveServerToolCall
) -> None:
"""Process tool calls from Gemini and send back responses.
Args:
session: The Gemini session
tool_call: Tool call request from Gemini
"""
for fc in tool_call.function_calls:
logging.debug(f"Calling tool function: {fc.name} with args: {fc.args}")
response = self._get_func(fc.name)(**fc.args)
tool_response = types.LiveClientToolResponse(
function_responses=[
types.FunctionResponse(name=fc.name, id=fc.id, response=response)
]
)
logging.debug(f"Tool response: {tool_response}")
await session.send(input=tool_response)
async def receive_from_gemini(self) -> None:
"""Listen for and process messages from Gemini.
Continuously receives messages from Gemini, forwards them to the client,
and handles any tool calls. Handles connection errors gracefully.
"""
while result := await self.session._ws.recv(decode=False):
await self.websocket.send_bytes(result)
raw_message = json.loads(result)
if "toolCall" in raw_message:
message = types.LiveServerMessage.model_validate(raw_message)
tool_call = LiveServerToolCall.model_validate(message.tool_call)
await self._handle_tool_call(self.session, tool_call)
def get_connect_and_run_callable(websocket: WebSocket) -> Callable:
"""Create a callable that handles Gemini connection with retry logic.
Args:
websocket: The client websocket connection
Returns:
Callable: An async function that establishes and manages the Gemini connection
"""
async def on_backoff(details: backoff._typing.Details) -> None:
await websocket.send_json(
{
"status": f"Model connection error, retrying in {details['wait']} seconds..."
}
)
@backoff.on_exception(
backoff.expo, ConnectionClosedError, max_tries=10, on_backoff=on_backoff
)
async def connect_and_run() -> None:
async with genai_client.aio.live.connect(
model=MODEL_ID, config=live_connect_config
) as session:
await websocket.send_json({"status": "Backend is ready for conversation"})
gemini_session = GeminiSession(
session=session, websocket=websocket, tool_functions=tool_functions
)
logging.info("Starting bidirectional communication")
await asyncio.gather(
gemini_session.receive_from_client(),
gemini_session.receive_from_gemini(),
)
return connect_and_run
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
"""Handle new websocket connections."""
await websocket.accept()
connect_and_run = get_connect_and_run_callable(websocket)
await connect_and_run()
class Feedback(BaseModel):
"""Represents feedback for a conversation."""
score: int | float
text: str | None = ""
run_id: str
user_id: str | None
log_type: Literal["feedback"] = "feedback"
@app.post("/feedback")
async def collect_feedback(feedback_dict: Feedback) -> None:
"""Collect and log feedback."""
feedback_data = feedback_dict.model_dump()
logger.log_struct(feedback_data, severity="INFO")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug")