agents/utils/connections.py (108 lines of code) (raw):

"""Connection handling for MCP servers.""" from abc import ABC, abstractmethod from contextlib import AsyncExitStack from typing import Any from mcp import ClientSession, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client from ..tools.mcp_tool import MCPTool class MCPConnection(ABC): """Base class for MCP server connections.""" def __init__(self): self.session = None self._rw_ctx = None self._session_ctx = None @abstractmethod async def _create_rw_context(self): """Create the read/write context based on connection type.""" async def __aenter__(self): """Initialize MCP server connection.""" self._rw_ctx = await self._create_rw_context() read_write = await self._rw_ctx.__aenter__() read, write = read_write self._session_ctx = ClientSession(read, write) self.session = await self._session_ctx.__aenter__() await self.session.initialize() return self async def __aexit__(self, exc_type, exc_val, exc_tb): """Clean up MCP server connection resources.""" try: if self._session_ctx: await self._session_ctx.__aexit__(exc_type, exc_val, exc_tb) if self._rw_ctx: await self._rw_ctx.__aexit__(exc_type, exc_val, exc_tb) except Exception as e: print(f"Error during cleanup: {e}") finally: self.session = None self._session_ctx = None self._rw_ctx = None async def list_tools(self) -> Any: """Retrieve available tools from the MCP server.""" response = await self.session.list_tools() return response.tools async def call_tool( self, tool_name: str, arguments: dict[str, Any] ) -> Any: """Call a tool on the MCP server with provided arguments.""" return await self.session.call_tool(tool_name, arguments=arguments) class MCPConnectionStdio(MCPConnection): """MCP connection using standard input/output.""" def __init__( self, command: str, args: list[str] = [], env: dict[str, str] = None ): super().__init__() self.command = command self.args = args self.env = env async def _create_rw_context(self): return stdio_client( StdioServerParameters( command=self.command, args=self.args, env=self.env ) ) class MCPConnectionSSE(MCPConnection): """MCP connection using Server-Sent Events.""" def __init__(self, url: str, headers: dict[str, str] = None): super().__init__() self.url = url self.headers = headers or {} async def _create_rw_context(self): return sse_client(url=self.url, headers=self.headers) def create_mcp_connection(config: dict[str, Any]) -> MCPConnection: """Factory function to create the appropriate MCP connection.""" conn_type = config.get("type", "stdio").lower() if conn_type == "stdio": if not config.get("command"): raise ValueError("Command is required for STDIO connections") return MCPConnectionStdio( command=config["command"], args=config.get("args"), env=config.get("env"), ) elif conn_type == "sse": if not config.get("url"): raise ValueError("URL is required for SSE connections") return MCPConnectionSSE( url=config["url"], headers=config.get("headers") ) else: raise ValueError(f"Unsupported connection type: {conn_type}") async def setup_mcp_connections( mcp_servers: list[dict[str, Any]] | None, stack: AsyncExitStack, ) -> list[MCPTool]: """Set up MCP server connections and create tool interfaces.""" if not mcp_servers: return [] mcp_tools = [] for config in mcp_servers: try: connection = create_mcp_connection(config) await stack.enter_async_context(connection) tool_definitions = await connection.list_tools() for tool_info in tool_definitions: mcp_tools.append( MCPTool( name=tool_info.name, description=tool_info.description or f"MCP tool: {tool_info.name}", input_schema=tool_info.inputSchema, connection=connection, ) ) except Exception as e: print(f"Error setting up MCP server {config}: {e}") print( f"Loaded {len(mcp_tools)} MCP tools from {len(mcp_servers)} servers." ) return mcp_tools