gemini/mcp/adk_mcp_app/main.py (95 lines of code) (raw):

import asyncio
import json
import logging
import os
from pathlib import Path

from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from google.adk.agents.llm_agent import LlmAgent
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset, StdioServerParameters
from google.genai import types
from starlette.websockets import WebSocketDisconnect

load_dotenv()

logger = logging.getLogger(__name__)

APP_NAME = "ADK MCP example"

# In-memory services created once at import time and shared by every
# websocket connection handled by this process.
session_service = InMemorySessionService()
artifacts_service = InMemoryArtifactService()


async def get_tools_async(server_params):
    """Gets tools from MCP Server.

    Returns:
        A ``(tools, exit_stack)`` pair. The MCP connection stays open until
        the caller awaits ``exit_stack.aclose()``.
    """
    # MCP requires maintaining a connection to the local MCP Server.
    # The exit_stack is used to clean up the server connection before exit.
    tools, exit_stack = await MCPToolset.from_server(connection_params=server_params)
    return tools, exit_stack


async def get_agent_async(server_params):
    """Creates an ADK Agent with tools from MCP Server.

    Returns:
        A ``(root_agent, exit_stack)`` pair. The caller owns ``exit_stack``
        and must close it. If building the agent fails, the MCP connection
        is closed here before the error propagates.
    """
    tools, exit_stack = await get_tools_async(server_params)
    try:
        root_agent = LlmAgent(
            model="gemini-2.5-pro-preview-03-25",
            name="ai_assistant",
            instruction=(
                "You're a helpful assistant.\n"
                "Use tools to get information to answer user questions, "
                "please format your answer in markdown format."
            ),
            tools=tools,
        )
    except Exception:
        # Don't leak the MCP connection when the agent cannot be built.
        await exit_stack.aclose()
        raise
    return root_agent, exit_stack


# Runs the cocktail MCP server script with `python` over stdio. The script
# path is relative, so it resolves against the current working directory.
ct_server_params = StdioServerParameters(
    command="python",
    args=["./mcp_server/cocktail.py"],
)


def _model_text(event):
    """Return the text of the first part of a model-authored event, else None.

    Guards against events whose ``content`` is None or whose ``parts`` list
    is empty/None, which would otherwise raise AttributeError/IndexError.
    """
    content = event.content
    if content is None or content.role != "model" or not content.parts:
        return None
    return content.parts[0].text


async def run_agent(server_params, session_id: str, question: str) -> list[str]:
    """Answers one user question with a freshly built agent.

    A new MCP server connection is opened for this call and is always closed
    before returning, even if the agent run raises.

    Args:
        server_params: MCP server connection parameters.
        session_id: ADK session id; also used as the user id.
        question: The user's message text.

    Returns:
        The non-empty text of each model event, in the order received.
    """
    print("[user]: ", question)
    content = types.Content(role="user", parts=[types.Part(text=question)])

    root_agent, exit_stack = await get_agent_async(server_params)
    try:
        runner = Runner(
            app_name=APP_NAME,
            agent=root_agent,
            artifact_service=artifacts_service,
            session_service=session_service,
        )
        events_async = runner.run_async(
            session_id=session_id, user_id=session_id, new_message=content
        )
        response = []
        async for event in events_async:
            text = _model_text(event)
            if text:
                print("[agent]:", text)
                response.append(text)
        return response
    finally:
        # Always release the MCP server connection opened for this question.
        await exit_stack.aclose()


async def run_adk_agent_async(websocket, server_params, session_id):
    """Client to agent communication.

    Each received text frame is answered by `run_agent`. The non-empty reply
    texts are joined with newlines and sent back as JSON
    ``{"message": ...}``. The loop runs until the client disconnects or an
    unexpected error occurs; neither is re-raised.
    """
    logger.info("Agent task started for session %s", session_id)
    try:
        while True:
            text = await websocket.receive_text()
            response = await run_agent(server_params, session_id, text)
            if not response:
                continue
            ai_message = "\n".join(response)
            await websocket.send_text(json.dumps({"message": ai_message}))
    except WebSocketDisconnect:
        logger.info("Client %s disconnected.", session_id)
    except Exception as e:
        # Top-level boundary for this connection: log with traceback and end
        # the task instead of propagating into the ASGI server.
        logger.exception("Error in agent task for session %s: %s", session_id, e)
    finally:
        logger.info("Agent task ending for session %s", session_id)


# FastAPI web app
app = FastAPI()

# Relative path: resolved against the current working directory.
STATIC_DIR = Path("static")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


@app.get("/")
async def root():
    """Serves the index.html"""
    return FileResponse(STATIC_DIR / "index.html")


@app.websocket("/ws/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: int):
    """Client websocket endpoint.

    Accepts the connection and creates an ADK session keyed by the
    stringified ``session_id`` (also used as the user id). It then serves
    the client until the agent loop ends.
    """
    await websocket.accept()
    print(f"Client #{session_id} connected")

    session_id = str(session_id)
    # NOTE(review): create_session is called without `await`, i.e. assumed
    # synchronous in the ADK version this file targets. Verify this before
    # upgrading ADK.
    session_service.create_session(
        app_name=APP_NAME, user_id=session_id, session_id=session_id, state={}
    )

    await run_adk_agent_async(websocket, ct_server_params, session_id)

    # Disconnected
    print(f"Client #{session_id} disconnected")