router/pkg/mcpserver/server.go (528 lines of code) (raw):
package mcpserver
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/hashicorp/go-retryablehttp"
"github.com/iancoleman/strcase"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/santhosh-tekuri/jsonschema/v6"
"github.com/wundergraph/cosmo/router/pkg/schemaloader"
"github.com/wundergraph/graphql-go-tools/v2/pkg/ast"
"github.com/wundergraph/graphql-go-tools/v2/pkg/astprinter"
"go.uber.org/zap"
)
// authKey is the unexported context key under which the Authorization header
// value of an incoming MCP request is stored. Using a private struct type keeps
// the key from colliding with context values set by other packages.
type authKey struct{}

// withAuthKey returns a copy of ctx that carries auth under authKey.
func withAuthKey(ctx context.Context, auth string) context.Context {
	return context.WithValue(ctx, authKey{}, auth)
}

// authFromRequest copies the request's Authorization header (which may be the
// empty string) into ctx. ServeSSE installs it as the SSE context function so
// that tool handlers can later forward the caller's credentials.
func authFromRequest(ctx context.Context, r *http.Request) context.Context {
	header := r.Header.Get("Authorization")
	return withAuthKey(ctx, header)
}

// tokenFromContext returns the Authorization value stored by withAuthKey.
// An empty string with a nil error means the request had no Authorization
// header; a non-nil error means no value was stored in ctx at all.
func tokenFromContext(ctx context.Context) (string, error) {
	if auth, ok := ctx.Value(authKey{}).(string); ok {
		return auth, nil
	}
	return "", fmt.Errorf("missing auth")
}
// Options represents configuration options for the GraphQLSchemaServer.
// NewGraphQLSchemaServer fills an Options value with defaults and then applies
// the functional options (WithGraphName, WithLogger, WithExposeSchema, ...).
type Options struct {
	// GraphName is the name of the graph to be served. Its kebab-case form is
	// used in the MCP server name ("wundergraph-cosmo-<name>"). Default: "graph".
	GraphName string
	// OperationsDir is the directory where GraphQL operations are stored. It is
	// read on every Reload. Default: "operations".
	OperationsDir string
	// ListenAddr is the address the SSE server listens on. Default: "0.0.0.0:5025".
	ListenAddr string
	// BaseURL of the MCP server. This is the URL advertised to the LLM clients
	// and is passed to the SSE server as its base URL.
	// By default, the base URL is relative to the URL that the router is running on.
	BaseURL string
	// Enabled determines whether the MCP server should be started.
	// NOTE(review): not read anywhere in this file; presumably checked by the caller.
	Enabled bool
	// Logger is the logger to be used. Default: a no-op logger.
	Logger *zap.Logger
	// RequestTimeout is the timeout for HTTP requests. Default: 30s.
	RequestTimeout time.Duration
	// ExcludeMutations determines whether mutation operations should be excluded
	// from the operations exposed as tools (passed to NewOperationsManager).
	ExcludeMutations bool
	// EnableArbitraryOperations determines whether arbitrary GraphQL operations can
	// be executed, i.e. whether the execute_graphql tool is registered.
	EnableArbitraryOperations bool
	// ExposeSchema determines whether the GraphQL schema is exposed through the
	// get_schema tool. Default: true.
	ExposeSchema bool
}
// GraphQLSchemaServer represents an MCP server that works with GraphQL schemas and operations.
// It is created by NewGraphQLSchemaServer, populated with tools by Reload, served
// over SSE by Start and shut down by Stop.
type GraphQLSchemaServer struct {
	// server is the MCP server on which all tools are registered.
	server *server.MCPServer
	// baseURL is handed to the SSE transport as its advertised base URL.
	baseURL string
	// graphName is used in the MCP server name and in log fields.
	graphName string
	// operationsDir is the directory operations are loaded from on Reload.
	operationsDir string
	// listenAddr is the address the SSE transport listens on.
	listenAddr string
	// logger is used for all server logging.
	logger *zap.Logger
	// httpClient is the retrying HTTP client used to call the router endpoint.
	httpClient *http.Client
	// requestTimeout is copied from Options.RequestTimeout.
	// NOTE(review): not read by any method in this file; the HTTP client timeout
	// is configured separately in NewGraphQLSchemaServer.
	requestTimeout time.Duration
	// routerGraphQLEndpoint is the URL GraphQL requests are POSTed to; it always
	// carries a scheme ("http://" is prepended when missing).
	routerGraphQLEndpoint string
	// sseServer is the SSE transport assigned by Start and shut down by Stop.
	sseServer *server.SSEServer
	// excludeMutations is passed to NewOperationsManager on every Reload.
	excludeMutations bool
	// enableArbitraryOperations controls registration of the execute_graphql tool.
	enableArbitraryOperations bool
	// exposeSchema controls registration of the get_schema tool.
	exposeSchema bool
	// operationsManager holds the schema and operations loaded by Reload.
	operationsManager *OperationsManager
	// schemaCompiler validates and compiles operation JSON schemas; recreated on Reload.
	schemaCompiler *SchemaCompiler
	// registeredTools lists the names of tools added by registerTools so that
	// Reload can delete them before re-registering.
	registeredTools []string
}
// graphqlRequest is the JSON body POSTed to the router GraphQL endpoint by
// executeGraphQLQuery. Variables has no omitempty, so a nil value is encoded
// as "variables": null.
type graphqlRequest struct {
	Query     string          `json:"query"`
	Variables json.RawMessage `json:"variables"`
}

// ExecuteGraphQLInput defines the input structure for the execute_graphql tool.
// Variables is kept as raw JSON and forwarded to the router unchanged.
type ExecuteGraphQLInput struct {
	Query     string          `json:"query"`
	Variables json.RawMessage `json:"variables,omitempty"`
}

// operationHandler holds an operation and its compiled JSON schema.
// compiledSchema is nil when the operation has no JSON schema; handleOperation
// then skips input validation.
type operationHandler struct {
	operation      schemaloader.Operation
	compiledSchema *jsonschema.Schema
}
// OperationInfo contains metadata about a GraphQL operation.
// NOTE(review): not referenced elsewhere in this file.
type OperationInfo struct {
	Name        string          `json:"name"`
	Description string          `json:"description"`
	Schema      json.RawMessage `json:"schema,omitempty"`
	Query       string          `json:"query"`
}

// OperationsResponse is the response structure for the listGraphQLOperations tool.
// NOTE(review): no such tool is registered in this file; the type is unused here.
type OperationsResponse struct {
	Operations  []OperationInfo `json:"operations"`
	Usage       string          `json:"usage"`
	LLMGuidance LLMGuidance     `json:"llmGuidance"`
	Endpoint    string          `json:"endpoint"`
}

// GraphQLOperationInfoResponse is the response structure for the graphql_operation_info tool.
// NOTE(review): the get_operation_info handler in this file returns plain text
// and does not use this type.
type GraphQLOperationInfoResponse struct {
	Name           string          `json:"name"`
	Description    string          `json:"description"`
	OperationType  string          `json:"operationType"`
	HasSideEffects bool            `json:"hasSideEffects"`
	Schema         json.RawMessage `json:"schema,omitempty"`
	Query          string          `json:"query"`
	LLMGuidance    LLMGuidance     `json:"llmGuidance"`
	Endpoint       string          `json:"endpoint"`
}

// GraphQLOperationInfoInput defines the input structure for the graphql_operation_info tool.
// It is decoded from the get_operation_info tool arguments.
type GraphQLOperationInfoInput struct {
	OperationName string `json:"operationName"`
}

// LLMGuidance provides guidance for LLMs on how to use the GraphQL operations.
// NOTE(review): only referenced by the response types above in this file.
type LLMGuidance struct {
	HTTPUsage      string   `json:"httpUsage"`
	GraphQLRequest string   `json:"graphqlRequest"`
	ExecutionTips  []string `json:"executionTips"`
}

// GraphQLError represents an error returned in a GraphQL response.
// Only the message is decoded; any other fields of the error object are ignored.
type GraphQLError struct {
	Message string `json:"message"`
}

// GraphQLResponse represents a GraphQL response structure. executeGraphQLQuery
// decodes router responses into it to detect errors; Data is kept as raw JSON.
type GraphQLResponse struct {
	Errors []GraphQLError  `json:"errors"`
	Data   json.RawMessage `json:"data"`
}
// NewGraphQLSchemaServer creates a new GraphQL schema server that exposes the
// GraphQL API reachable at routerGraphQLEndpoint as MCP tools.
//
// routerGraphQLEndpoint must be non-empty; when it contains no scheme,
// "http://" is prepended. opts are applied in order on top of the defaults.
// Nil option functions are ignored. A nil Logger falls back to a no-op logger,
// and a non-positive RequestTimeout falls back to the 30s default.
//
// No tools are registered here: call Reload to load operations and register
// tools, and Start to begin serving.
//
// Outbound GraphQL requests go through a retrying HTTP client whose timeout is
// Options.RequestTimeout.
func NewGraphQLSchemaServer(routerGraphQLEndpoint string, opts ...func(*Options)) (*GraphQLSchemaServer, error) {
	const defaultRequestTimeout = 30 * time.Second

	if routerGraphQLEndpoint == "" {
		return nil, fmt.Errorf("routerGraphQLEndpoint cannot be empty")
	}
	if !strings.Contains(routerGraphQLEndpoint, "://") {
		routerGraphQLEndpoint = "http://" + routerGraphQLEndpoint
	}

	// Default options
	options := &Options{
		GraphName:      "graph",
		OperationsDir:  "operations",
		ListenAddr:     "0.0.0.0:5025",
		Enabled:        false,
		Logger:         zap.NewNop(),
		RequestTimeout: defaultRequestTimeout,
		ExposeSchema:   true,
	}

	// Apply all option functions
	for _, opt := range opts {
		if opt != nil {
			opt(options)
		}
	}

	// Options may reset values to unusable zero values: a nil logger would panic
	// on first use, and a zero http.Client timeout means "no timeout at all".
	if options.Logger == nil {
		options.Logger = zap.NewNop()
	}
	if options.RequestTimeout <= 0 {
		options.RequestTimeout = defaultRequestTimeout
	}

	// Create the MCP server
	mcpServer := server.NewMCPServer(
		"wundergraph-cosmo-"+strcase.ToKebab(options.GraphName),
		"0.0.1",
		// Prompt, Resources aren't supported yet in any of the popular platforms
		server.WithToolCapabilities(true),
		server.WithPaginationLimit(100),
		server.WithRecovery(),
	)

	retryClient := retryablehttp.NewClient()
	// Silence retryablehttp's own logging; errors are reported by the callers.
	retryClient.Logger = nil
	httpClient := retryClient.StandardClient()
	// Honour the configured request timeout instead of a hard-coded value.
	httpClient.Timeout = options.RequestTimeout

	gs := &GraphQLSchemaServer{
		server:                    mcpServer,
		graphName:                 options.GraphName,
		operationsDir:             options.OperationsDir,
		listenAddr:                options.ListenAddr,
		logger:                    options.Logger,
		httpClient:                httpClient,
		requestTimeout:            options.RequestTimeout,
		routerGraphQLEndpoint:     routerGraphQLEndpoint,
		excludeMutations:          options.ExcludeMutations,
		enableArbitraryOperations: options.EnableArbitraryOperations,
		exposeSchema:              options.ExposeSchema,
		baseURL:                   options.BaseURL,
	}

	return gs, nil
}
// WithGraphName returns an option that sets Options.GraphName.
func WithGraphName(graphName string) func(*Options) {
	return func(opts *Options) { opts.GraphName = graphName }
}

// WithOperationsDir returns an option that sets Options.OperationsDir.
func WithOperationsDir(operationsDir string) func(*Options) {
	return func(opts *Options) { opts.OperationsDir = operationsDir }
}

// WithBaseURL returns an option that sets Options.BaseURL.
func WithBaseURL(baseURL string) func(*Options) {
	return func(opts *Options) { opts.BaseURL = baseURL }
}

// WithListenAddr returns an option that sets Options.ListenAddr.
func WithListenAddr(listenAddr string) func(*Options) {
	return func(opts *Options) { opts.ListenAddr = listenAddr }
}

// WithLogger returns an option that sets Options.Logger.
func WithLogger(logger *zap.Logger) func(*Options) {
	return func(opts *Options) { opts.Logger = logger }
}

// WithExcludeMutations returns an option that sets Options.ExcludeMutations.
func WithExcludeMutations(excludeMutations bool) func(*Options) {
	return func(opts *Options) { opts.ExcludeMutations = excludeMutations }
}

// WithEnableArbitraryOperations returns an option that sets
// Options.EnableArbitraryOperations.
func WithEnableArbitraryOperations(enableArbitraryOperations bool) func(*Options) {
	return func(opts *Options) { opts.EnableArbitraryOperations = enableArbitraryOperations }
}

// WithExposeSchema returns an option that sets Options.ExposeSchema.
func WithExposeSchema(exposeSchema bool) func(*Options) {
	return func(opts *Options) { opts.ExposeSchema = exposeSchema }
}
// ServeSSE creates the SSE transport for the MCP server and starts it in a
// background goroutine on s.listenAddr.
//
// The transport is mounted at /mcp, advertises s.baseURL, sends keep-alives
// every 10 seconds and copies each request's Authorization header into the
// request context via authFromRequest. ServeSSE returns immediately. Listen
// errors other than http.ErrServerClosed are only logged, so the returned
// error is always nil.
func (s *GraphQLSchemaServer) ServeSSE() (*server.SSEServer, error) {
	const ssePath = "/mcp"

	sse := server.NewSSEServer(s.server,
		server.WithBaseURL(s.baseURL),
		server.WithSSEEndpoint(ssePath),
		server.WithSSEContextFunc(authFromRequest),
		server.WithKeepAlive(true),
		server.WithKeepAliveInterval(10*time.Second),
	)

	s.logger.Info("MCP server started",
		zap.String("listen_addr", s.listenAddr),
		zap.String("path", ssePath),
		zap.String("operations_dir", s.operationsDir),
		zap.String("graph_name", s.graphName),
		zap.Bool("exclude_mutations", s.excludeMutations),
		zap.Bool("enable_arbitrary_operations", s.enableArbitraryOperations),
		zap.Bool("expose_schema", s.exposeSchema),
	)

	go func() {
		defer s.logger.Info("MCP server stopped")
		// ErrServerClosed is the normal result of Shutdown and is not an error.
		if startErr := sse.Start(s.listenAddr); startErr != nil && !errors.Is(startErr, http.ErrServerClosed) {
			s.logger.Error("failed to start SSE server", zap.Error(startErr))
		}
	}()

	return sse, nil
}
// Start begins serving the MCP server over SSE and remembers the transport so
// that Stop can shut it down. It does not load operations; Reload does that.
func (s *GraphQLSchemaServer) Start() error {
	srv, serveErr := s.ServeSSE()
	if serveErr != nil {
		return fmt.Errorf("failed to create SSE server: %w", serveErr)
	}

	s.sseServer = srv
	return nil
}
// Reload loads the operations for schema and re-registers all MCP tools.
//
// A fresh OperationsManager is built from schema and populated from
// s.operationsDir. The server's operations manager and schema compiler are
// replaced only if loading succeeds. On failure the previously loaded state and
// tools stay in place. Every tool registered by an earlier call is then deleted,
// and the tool list is rebuilt from scratch, so s.registeredTools contains only
// the current set.
//
// NOTE(review): tool handlers read s.operationsManager and s.schemaCompiler
// without synchronization. Reload running concurrently with tool calls is a data
// race unless the caller serializes them; confirm.
func (s *GraphQLSchemaServer) Reload(schema *ast.Document) error {
	if s.server == nil {
		return fmt.Errorf("server is not started")
	}

	operationsManager := NewOperationsManager(schema, s.logger, s.excludeMutations)
	if err := operationsManager.LoadOperationsFromDirectory(s.operationsDir); err != nil {
		return fmt.Errorf("failed to load operations: %w", err)
	}

	s.schemaCompiler = NewSchemaCompiler(s.logger)
	s.operationsManager = operationsManager

	// Remove the previous generation of tools and forget their names. Without
	// the reset, registerTools would append to the old list and it would keep
	// growing (with stale and duplicate names) on every reload.
	s.server.DeleteTools(s.registeredTools...)
	s.registeredTools = nil

	if err := s.registerTools(); err != nil {
		return fmt.Errorf("failed to register tools: %w", err)
	}

	return nil
}
// Stop gracefully shuts down the MCP server's SSE transport, waiting at most
// 5 seconds (or less if ctx is done sooner).
//
// Stop is a no-op when Start has not successfully run, since there is no
// transport to shut down. Previously this dereferenced a nil sseServer.
func (s *GraphQLSchemaServer) Stop(ctx context.Context) error {
	s.logger.Debug("shutting down MCP server")

	if s.sseServer == nil {
		return nil
	}

	// Create a shutdown context with timeout
	shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	if err := s.sseServer.Shutdown(shutdownCtx); err != nil {
		return fmt.Errorf("failed to gracefully shutdown SSE server: %w", err)
	}

	return nil
}
// registerTools registers the MCP tools derived from the server configuration
// and the loaded operations. Every tool name added is appended to
// s.registeredTools so that Reload can delete it later.
//
// Registered tools:
//   - get_schema, when exposeSchema is enabled;
//   - execute_graphql, when enableArbitraryOperations is enabled;
//   - execute_operation_<snake_case name>, one per operation returned by
//     GetFilteredOperations. An operation is skipped if its JSON schema fails to
//     validate or compile.
//   - get_operation_info, when at least one operation is available.
//
// Operations without a JSON schema are advertised with an empty object input
// schema. The only error returned is a failure to validate the built-in
// execute_graphql schema; per-operation schema problems are logged and skipped.
func (s *GraphQLSchemaServer) registerTools() error {
	// Only register the schema tool if exposeSchema is enabled
	if s.exposeSchema {
		s.server.AddTool(
			mcp.NewTool(
				"get_schema",
				mcp.WithDescription("Provides the full GraphQL schema of the API."),
				mcp.WithToolAnnotation(mcp.ToolAnnotation{
					Title:        "Get GraphQL Schema",
					ReadOnlyHint: mcp.ToBoolPtr(true),
				}),
			),
			s.handleGetGraphQLSchema(),
		)
		s.registeredTools = append(s.registeredTools, "get_schema")
	}

	// Only register the execute_graphql tool if enableArbitraryOperations is enabled
	if s.enableArbitraryOperations {
		// Add a tool to execute arbitrary GraphQL queries
		executeGraphQLSchema := []byte(`{
"type": "object",
"description": "The query and variables to execute.",
"properties": {
"query": {
"type": "string",
"description": "The GraphQL query or mutation string to execute."
},
"variables": {
"type": "object",
"additionalProperties": true,
"description": "The variables to pass to the GraphQL query as a JSON object."
}
},
"additionalProperties": false,
"required": ["query"]
}`)

		// Validate the schema before using it
		if err := s.schemaCompiler.ValidateJSONSchema(executeGraphQLSchema); err != nil {
			return fmt.Errorf("invalid schema for execute_graphql tool: %w", err)
		}

		tool := mcp.NewToolWithRawSchema(
			"execute_graphql",
			"Executes a GraphQL query or mutation.",
			executeGraphQLSchema,
		)
		tool.Annotations = mcp.ToolAnnotation{
			Title:           "Execute GraphQL Query",
			DestructiveHint: mcp.ToBoolPtr(true),
			IdempotentHint:  mcp.ToBoolPtr(false),
			OpenWorldHint:   mcp.ToBoolPtr(true),
		}
		s.server.AddTool(
			tool,
			s.handleExecuteGraphQL(),
		)
		s.registeredTools = append(s.registeredTools, "execute_graphql")
	}

	// Operations without variables carry no JSON schema. An MCP tool still has to
	// declare an object input schema, so advertise an empty object schema instead
	// of an empty raw schema.
	emptyInputSchema := json.RawMessage(`{"type":"object","properties":{}}`)

	// Get operations filtered by the excludeMutations setting
	operations := s.operationsManager.GetFilteredOperations()
	graphqlOperationNames := make([]string, 0, len(operations))

	for _, op := range operations {
		var compiledSchema *jsonschema.Schema

		// Every loaded operation is listed in get_operation_info, even when its
		// execute tool is skipped below because of an invalid schema.
		graphqlOperationNames = append(graphqlOperationNames, op.Name)

		if len(op.JSONSchema) > 0 {
			// Validate the JSON schema before compiling it
			if err := s.schemaCompiler.ValidateJSONSchema(op.JSONSchema); err != nil {
				s.logger.Error("invalid schema for operation",
					zap.String("operation", op.Name),
					zap.Error(err))
				continue
			}

			// Now compile the validated schema
			schemaName := fmt.Sprintf("schema-%s.json", op.Name)
			compiled, err := s.schemaCompiler.CompileJSONSchema(op.JSONSchema, schemaName)
			if err != nil {
				s.logger.Error("failed to compile schema for operation",
					zap.String("operation", op.Name),
					zap.Error(err))
				continue
			}
			compiledSchema = compiled
		}

		// Create handler with pre-compiled schema
		handler := &operationHandler{
			operation:      op,
			compiledSchema: compiledSchema,
		}

		// Convert the operation name to snake_case for consistent tool naming
		operationToolName := strcase.ToSnake(op.Name)

		var toolDescription string
		if op.Description != "" {
			toolDescription = fmt.Sprintf("Executes the GraphQL operation '%s' of type %s. %s", op.Name, op.OperationType, op.Description)
		} else {
			toolDescription = fmt.Sprintf("Executes the GraphQL operation '%s' of type %s.", op.Name, op.OperationType)
		}

		inputSchema := json.RawMessage(op.JSONSchema)
		if len(inputSchema) == 0 {
			inputSchema = emptyInputSchema
		}

		toolName := fmt.Sprintf("execute_operation_%s", operationToolName)
		tool := mcp.NewToolWithRawSchema(
			toolName,
			toolDescription,
			inputSchema,
		)
		tool.Annotations = mcp.ToolAnnotation{
			IdempotentHint: mcp.ToBoolPtr(op.OperationType != "mutation"),
			Title:          fmt.Sprintf("Execute operation %s", op.Name),
			ReadOnlyHint:   mcp.ToBoolPtr(op.OperationType == "query"),
			OpenWorldHint:  mcp.ToBoolPtr(true),
		}

		s.server.AddTool(
			tool,
			s.handleOperation(handler),
		)
		s.registeredTools = append(s.registeredTools, toolName)
	}

	// get_operation_info restricts operationName to an enum of known operations.
	// With no operations that enum would be empty and the tool uncallable.
	if len(graphqlOperationNames) == 0 {
		return nil
	}

	s.server.AddTool(
		mcp.NewTool(
			"get_operation_info",
			mcp.WithDescription("Provides instructions on how to execute the GraphQL operation via HTTP and how to integrate it into your application."),
			mcp.WithToolAnnotation(mcp.ToolAnnotation{
				Title:        "Get GraphQL Operation Info",
				ReadOnlyHint: mcp.ToBoolPtr(true),
			}),
			mcp.WithString("operationName",
				mcp.Required(),
				mcp.Description("The exact name of the GraphQL operation to retrieve information for."),
				mcp.Enum(graphqlOperationNames...),
			),
		),
		s.handleGraphQLOperationInfo(),
	)
	s.registeredTools = append(s.registeredTools, "get_operation_info")

	return nil
}
// handleOperation builds the MCP tool handler for one pre-registered operation.
//
// The handler re-encodes the tool call arguments as JSON. If the operation has
// a compiled schema, the arguments are validated against it first. They are then
// sent as the variables of the operation. A validation failure is reported as a
// tool error result ("Input validation Error"), not as a Go error.
func (s *GraphQLSchemaServer) handleOperation(handler *operationHandler) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		variables, marshalErr := json.Marshal(request.GetArguments())
		if marshalErr != nil {
			return nil, fmt.Errorf("failed to marshal arguments: %w", marshalErr)
		}

		// Operations without input variables have no compiled schema to check.
		if schema := handler.compiledSchema; schema != nil {
			if validationErr := s.schemaCompiler.ValidateInput(variables, schema); validationErr != nil {
				return mcp.NewToolResultErrorFromErr("Input validation Error", validationErr), nil
			}
		}

		return s.executeGraphQLQuery(ctx, handler.operation.OperationString, variables)
	}
}
// handleGraphQLOperationInfo returns the handler for the get_operation_info tool.
//
// The handler decodes {"operationName": "..."} from the tool arguments and looks
// the operation up in the operations manager. It then returns a plain-text
// description containing:
//   - the operation's name, type and description;
//   - its input schema (if any) and query;
//   - the router endpoint, method and headers;
//   - the expected request body format.
//
// Missing or unknown operation names are returned as Go errors.
func (s *GraphQLSchemaServer) handleGraphQLOperationInfo() func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		rawArgs, err := json.Marshal(request.GetArguments())
		if err != nil {
			return nil, fmt.Errorf("failed to marshal input arguments: %w", err)
		}

		var input GraphQLOperationInfoInput
		if err = json.Unmarshal(rawArgs, &input); err != nil {
			return nil, fmt.Errorf("failed to unmarshal input arguments: %w. Ensure you provide {\"operationName\": \"<n>\"}", err)
		}

		name := input.OperationName
		if name == "" {
			return nil, fmt.Errorf("input validation failed: operationName is required")
		}

		op := s.operationsManager.GetOperation(name)
		if op == nil {
			return nil, fmt.Errorf("operation '%s' not found or excluded by configuration", name)
		}

		hasVariables := len(op.JSONSchema) > 0
		var sb strings.Builder

		// Overview
		fmt.Fprintf(&sb, "Operation: %s\nType: %s\n", op.Name, op.OperationType)
		if op.Description != "" {
			fmt.Fprintf(&sb, "Description: %s\n", op.Description)
		}

		// Input schema
		if hasVariables {
			fmt.Fprintf(&sb, "\nInput Schema:\n```json\n%s\n```\n", op.JSONSchema)
		} else {
			sb.WriteString("\nThis operation does not require any input variables.\n")
		}

		// Query
		fmt.Fprintf(&sb, "\nGraphQL Query:\n```\n%s\n```\n", op.OperationString)

		// Usage instructions
		fmt.Fprintf(&sb, `
Usage Instructions:
1. Endpoint: %s
2. HTTP Method: POST
3. Headers Required:
- Content-Type: application/json; charset=utf-8
`, s.routerGraphQLEndpoint)

		// Request format
		sb.WriteString("\nRequest Format:\n```json\n")
		if hasVariables {
			sb.WriteString(`{
"query": "<operation_query>",
"variables": <your_variables_object>
}
`)
		} else {
			sb.WriteString(`{
"query": "<operation_query>"
}
`)
		}
		sb.WriteString("```")

		// Important notes
		sb.WriteString(`
Important Notes:
1. Use the query string exactly as provided above
2. Do not modify or reformat the query string`)

		return mcp.NewToolResultText(sb.String()), nil
	}
}
// executeGraphQLQuery POSTs a GraphQL request (query plus optional variables)
// to the router endpoint and converts the response into an MCP tool result.
//
// ctx is attached to the outgoing request, so cancelling the tool call aborts
// the upstream request. The Authorization value stored in ctx by
// authFromRequest, if non-empty, is forwarded to the router.
//
// Results:
//   - marshalling, transport and body-read failures are returned as Go errors;
//   - a GraphQL response with errors becomes a tool error result containing the
//     joined error messages, plus the data when it is partially successful;
//   - a non-2xx HTTP status without GraphQL errors becomes a tool error result
//     containing the status code and body;
//   - otherwise the raw response body is returned as text.
func (s *GraphQLSchemaServer) executeGraphQLQuery(ctx context.Context, query string, variables json.RawMessage) (*mcp.CallToolResult, error) {
	requestBody, err := json.Marshal(graphqlRequest{
		Query:     query,
		Variables: variables,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal GraphQL request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.routerGraphQLEndpoint, bytes.NewReader(requestBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Content-Type", "application/json; charset=utf-8")

	// Forward the Authorization header of the MCP client's request, if provided.
	token, err := tokenFromContext(ctx)
	if err != nil {
		s.logger.Debug("failed to get token from context", zap.Error(err))
	} else if token != "" {
		req.Header.Set("Authorization", token)
	}

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	// Parse the GraphQL response; bodies that are not JSON fall through and are
	// handled by the status check or returned verbatim.
	var graphqlResponse GraphQLResponse
	if err := json.Unmarshal(body, &graphqlResponse); err == nil && len(graphqlResponse.Errors) > 0 {
		errorMessages := make([]string, 0, len(graphqlResponse.Errors))
		for _, gqlErr := range graphqlResponse.Errors {
			errorMessages = append(errorMessages, gqlErr.Message)
		}
		errorMessage := strings.Join(errorMessages, "; ")

		// If there are errors but no data, return only the errors. The messages
		// must be passed explicitly: the decode error at this point is nil.
		if len(graphqlResponse.Data) == 0 || string(graphqlResponse.Data) == "null" {
			return mcp.NewToolResultErrorFromErr("Response Error", errors.New(errorMessage)), nil
		}

		// If we have both errors and data, include data in the error message
		dataString := string(graphqlResponse.Data)
		combinedErrorMsg := fmt.Sprintf("Response error with partial success, Error: %s, Data: %s", errorMessage, dataString)
		return mcp.NewToolResultError(combinedErrorMsg), nil
	}

	// A failed HTTP status without GraphQL errors (e.g. proxy or server error)
	// must not be reported to the client as a successful result.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return mcp.NewToolResultErrorFromErr("Request failed", fmt.Errorf("HTTP status %d: %s", resp.StatusCode, body)), nil
	}

	return mcp.NewToolResultText(string(body)), nil
}
// handleExecuteGraphQL returns the handler for the execute_graphql tool.
//
// The handler decodes {"query": "...", "variables": {...}} from the tool
// arguments, rejects an empty query with a Go error, and otherwise sends the
// query and its raw variables to the router via executeGraphQLQuery.
func (s *GraphQLSchemaServer) handleExecuteGraphQL() func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		payload, err := json.Marshal(request.GetArguments())
		if err != nil {
			return nil, fmt.Errorf("failed to marshal arguments: %w", err)
		}

		input := ExecuteGraphQLInput{}
		if err = json.Unmarshal(payload, &input); err != nil {
			return nil, fmt.Errorf("failed to unmarshal input arguments: %w", err)
		}

		if len(input.Query) == 0 {
			return nil, fmt.Errorf("input validation failed: query is required")
		}
		return s.executeGraphQLQuery(ctx, input.Query, input.Variables)
	}
}
// handleGetGraphQLSchema returns the handler for the get_schema tool, which
// prints the schema held by the operations manager back to GraphQL SDL text.
// A missing schema or a printing failure is returned as a Go error.
func (s *GraphQLSchemaServer) handleGetGraphQLSchema() func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	return func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		doc := s.operationsManager.GetSchema()
		if doc == nil {
			return nil, fmt.Errorf("GraphQL schema is not available")
		}

		printed, printErr := astprinter.PrintString(doc)
		if printErr != nil {
			return nil, fmt.Errorf("failed to convert schema to string: %w", printErr)
		}
		return mcp.NewToolResultText(printed), nil
	}
}