polardb-postgresql-mcp-server/server.py (264 lines of code) (raw):

from starlette.applications import Starlette from mcp.server.sse import SseServerTransport from starlette.requests import Request from starlette.routing import Mount, Route from mcp.server import Server import uvicorn import logging import os import psycopg from psycopg import OperationalError as Error from mcp.types import Resource, ResourceTemplate, Tool, TextContent from pydantic import AnyUrl from dotenv import load_dotenv import asyncio import sqlparse enable_write = False enable_update = False enable_insert = False enable_ddl = False logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(name)s - %(message)s' ) logger = logging.getLogger("polardb-postgresql-mcp-server") VERSION = "0.0.1" def get_db_config(): """Get database configuration from environment variables.""" config = { "host": os.getenv("POLARDB_POSTGRESQL_HOST", "localhost"), "port": int(os.getenv("POLARDB_POSTGRESQL_PORT", "5432")), "user": os.getenv("POLARDB_POSTGRESQL_USER"), "password": os.getenv("POLARDB_POSTGRESQL_PASSWORD"), "dbname": os.getenv("POLARDB_POSTGRESQL_DBNAME"), "application_name": f"polardb-postgresql-mcp-server-{VERSION}" } if not all([config["user"], config["password"], config["dbname"]]): logger.error("Missing required database configuration. Please check environment variables:") logger.error("POLARDB_POSTGRESQL_USER, POLARDB_POSTGRESQL_PASSWORD, and POLARDB_POSTGRESQL_DBNAME are required") raise ValueError("Missing required database configuration") return config # Initialize server app = Server("polardb-postgresql-mcp-server") @app.list_resources() async def list_resources() -> list[Resource]: try: return [ Resource( uri=f"polardb-postgresql://schemas", name="get_schemas", description=" List all schemas for PolarDB PostgreSQL schemas in the current database", mimeType="text/plain" ) ] except Exception as e: logger.error(f"Error listing resources: {str(e)}") raise @app.list_resource_templates() async def list_resource_templates() -> list[ResourceTemplate]: return [ ResourceTemplate( uriTemplate=f"polardb-postgresql://{{schema}}/tables", name="list_tables", description="List all tables in a specific schema", mimeType="text/plain" ), ResourceTemplate( uriTemplate=f"polardb-postgresql://{{schema}}/{{table}}/field", name="table_field_info", description="get the name,type and comment of the field in the table", mimeType="text/plain" ), ResourceTemplate( uriTemplate=f"polardb-postgresql://{{schema}}/{{table}}/data", name="table_data", description="get data from the table,default limit 50 rows", mimeType="text/plain" ) ] @app.read_resource() async def read_resource(uri: AnyUrl) -> str: config = get_db_config() uri_str = str(uri) logger.info(f"Reading resource: {uri_str}") prefix = "polardb-postgresql://" if not uri_str.startswith(prefix): logger.error(f"Invalid URI scheme: {uri_str}") raise ValueError(f"Invalid URI scheme: {uri_str}") try: with psycopg.connect(**config) as conn: conn.autocommit = True with conn.cursor() as cursor: parts = uri_str[len(prefix):].split('/') if len(parts) == 1 and parts[0] == "schemas": #polardb-postgresql://schemas,list all schemas query = """ SELECT schema_name FROM information_schema.schemata WHERE schema_name NOT IN ('cron','information_schema', 'pg_bitmapindex','pg_catalog','pg_toast','polar_catalog','polar_feature_utils') ORDER BY schema_name; """ cursor.execute(query) rows = cursor.fetchall() return "\n".join([row[0] for row in rows]) elif len(parts) == 2 and parts[1] == "tables": #polardb-postgresql://{schema}/tables,list all tables in a schema query = f""" SELECT c.relname AS table_name, obj_description(c.oid) AS table_comment FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE c.relkind = 'r' AND n.nspname = '{parts[0]}' ORDER BY c.relname; """ cursor.execute(query) rows = cursor.fetchall() return "\n".join([f"{row[0]} ({row[1]})" for row in rows]) elif len(parts) == 3 and parts[2] == "field": # polardb-postgresql://{schema}/{table}/field,list all field info(name,type,comment) in a table schema = parts[0] table = parts[1] query = f""" SELECT a.attname AS column_name, pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, col_description(a.attrelid, a.attnum) AS column_comment FROM pg_catalog.pg_attribute a WHERE a.attnum > 0 AND NOT a.attisdropped AND a.attrelid = '{schema}.{table}'::regclass ORDER BY a.attnum; """ cursor.execute(query) rows = cursor.fetchall() result = [",".join(map(str, row)) for row in rows] return "\n".join(result) elif len(parts) == 3 and parts[2] == "data": # polardb-postgresql://{schema}/{table}/data,list all data in a table schema = parts[0] table = parts[1] query = f"SELECT * FROM {schema}.{table} LIMIT 50" cursor.execute(query) rows = cursor.fetchall() result = [",".join(map(str, row)) for row in rows] return "\n".join(result) else: raise ValueError(f"Invalid URI: {uri_str}") except Error as e: logger.error(f"Database error: {str(e)}") raise RuntimeError(f"Database error: {str(e)}") @app.list_tools() async def list_tools() -> list[Tool]: """List available PolarDB PostgreSQL tools.""" logger.info("Listing tools...") return [ Tool( name="execute_sql", description="Execute an SQL query on the PolarDB PostgreSQL server", inputSchema={ "type": "object", "properties": { "query": { "type": "string", "description": "The SQL query to execute" } }, "required": ["query"] } ) ] def get_sql_operation_type(sql): """ get sql operation type :param sql: input sql :return: return sql operation type ('INSERT', 'DELETE', 'UPDATE', 'DDL', or 'OTHER') """ parsed = sqlparse.parse(sql) if not parsed: return 'OTHER' #parse sql failed # get first statement statement = parsed[0] # get first keyword first_token = statement.token_first(skip_ws=True, skip_cm=True) if not first_token: return 'OTHER' keyword = first_token.value.upper() # convert to upper case for uniform comparison # judge sql type if keyword == 'INSERT': return 'INSERT' elif keyword == 'DELETE': return 'DELETE' elif keyword == 'UPDATE': return 'UPDATE' elif keyword in ('CREATE', 'ALTER', 'DROP', 'TRUNCATE'): return 'DDL' else: return 'OTHER' def execute_sql(arguments: str) -> str: config = get_db_config() query = arguments.get("query") if not query: raise ValueError("Query is required") operation_type = get_sql_operation_type(query) logger.info(f"SQL operation type: {operation_type}") global enable_write,enable_update,enable_insert,enable_ddl if operation_type == 'INSERT' and not enable_insert: logger.info(f"INSERT operation is not enabled,please check POLARDB_MYSQL_ENABLE_INSERT") return [TextContent(type="text", text=f"INSERT operation is not enabled in current tool")] elif operation_type == 'UPDATE' and not enable_update: logger.info(f"UPDATE operation is not enabled,please check POLARDB_MYSQL_ENABLE_UPDATE") return [TextContent(type="text", text=f"UPDATE operation is not enabled in current tool")] elif operation_type == 'DELETE' and not enable_write: logger.info(f"DELETE operation is not enabled,please check POLARDB_MYSQL_ENABLE_WRITE") return [TextContent(type="text", text=f"DELETE operation is not enabled in current tool")] elif operation_type == 'DDL' and not enable_ddl: logger.info(f"DDL operation is not enabled,please check POLARDB_MYSQL_ENABLE_DDL") return [TextContent(type="text", text=f"DDL operation is not enabled in current tool")] else: logger.info(f"will Executing SQL: {query}") try: with psycopg.connect(**config) as conn: conn.autocommit = True with conn.cursor() as cursor: cursor.execute(query) if cursor.description is not None: columns = [desc[0] for desc in cursor.description] rows = cursor.fetchall() result = [",".join(map(str, row)) for row in rows] return [TextContent(type="text", text="\n".join([",".join(columns)] + result))] else: conn.commit() return [TextContent(type="text", text=f"Query executed successfully")] except Error as e: logger.error(f"Error executing SQL '{query}': {e}") return [TextContent(type="text", text=f"Error executing query: {str(e)}")] @app.call_tool() async def call_tool(name: str, arguments: dict) -> list[TextContent]: logger.info(f"Calling tool: {name} with arguments: {arguments}") if name == "execute_sql": return execute_sql(arguments) else: raise ValueError(f"Unknown tool: {name}") def create_starlette_app(app: Server, *, debug: bool = False) -> Starlette: """Create a Starlette application that can server the provied mcp server with SSE.""" sse = SseServerTransport("/messages/") async def handle_sse(request: Request) -> None: async with sse.connect_sse( request.scope, request.receive, request._send, # noqa: SLF001 ) as (read_stream, write_stream): await app.run( read_stream, write_stream, app.create_initialization_options(), ) return Starlette( debug=debug, routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ], ) def sse_main(bind_host: str="127.0.0.1", bind_port: int = 8082): # Bind SSE request handling to MCP server starlette_app = create_starlette_app(app, debug=True) logger.info(f"Starting MCP SSE server on {bind_host}:{bind_port}/sse") uvicorn.run(starlette_app, host=bind_host, port=bind_port) async def stdio_main(): """Main entry point to run the MCP server.""" from mcp.server.stdio import stdio_server logger.info("Starting PolarDB PostgreSQL MCP server with stdio mode...") config = get_db_config() logger.info(f"Database config: {config['host']}/{config['dbname']} as {config['user']}") async with stdio_server() as (read_stream, write_stream): try: await app.run( read_stream, write_stream, app.create_initialization_options() ) except Exception as e: logger.error(f"Server error: {str(e)}", exc_info=True) raise def get_bool_env(var_name: str, default: bool = False) -> bool: value = os.getenv(var_name) if value is None: return default return value.lower() in ['true', '1', 't', 'y', 'yes'] def main(): load_dotenv() global enable_write,enable_update,enable_insert,enable_ddl enable_write = get_bool_env("POLARDB_POSTGRESQL_ENABLE_WRITE") enable_update = get_bool_env("POLARDB_POSTGRESQL_ENABLE_UPDATE") enable_insert = get_bool_env("POLARDB_POSTGRESQL_ENABLE_INSERT") enable_ddl = get_bool_env("POLARDB_POSTGRESQL_ENABLE_DDL") logger.info(f"enable_write: {enable_write}, enable_update: {enable_update}, enable_insert: {enable_insert}, enable_ddl: {enable_ddl}") if os.getenv("RUN_MODE")=="stdio": asyncio.run(stdio_main()) else: bind_host = os.getenv("SSE_BIND_HOST") bind_port = int(os.getenv("SSE_BIND_PORT")) sse_main(bind_host,bind_port) if __name__ == "__main__": main()