plugins/golang-filter/mcp-session/filter.go (408 lines of code) (raw):
package mcp_session
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/handler"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/mark3labs/mcp-go/mcp"
)
const (
// RedisNotEnabledResponseBody is the response body this filter writes for a
// request to the default SSE endpoint when no Redis client is configured.
RedisNotEnabledResponseBody = "Redis is not enabled, SSE connection is not supported"
)
// filter holds the state that the MCP session callbacks share while handling a
// request and its response.
//
// It embeds api.PassThroughStreamFilter, which provides default implementations
// of the stream callbacks, so only the callbacks that are needed (such as
// DecodeHeaders) have to be implemented here.
type filter struct {
api.PassThroughStreamFilter
// callbacks is used to send local replies and is handed to the default SSE server.
callbacks api.FilterCallbackHandler
// path is the request path, recorded in DecodeHeaders.
path string
// config is the plugin configuration (match list, Redis client, default server, ...).
config *config
// stopChan is passed to the default SSE server and closed in OnDestroy.
stopChan chan struct{}
// req carries only the method and parsed URL of the current request.
req *http.Request
// serverName is non-empty when the request is served by the default SSE server.
serverName string
// proxyURL is the request URL of a request forwarded to a REST/streamable
// upstream; its "sessionId" query parameter selects the Redis channel the
// response is published to.
proxyURL *url.URL
// matchedRule is the match_list rule that the request matched.
matchedRule common.MatchRule
// needProcess is true while the filter still has work to do for this request.
needProcess bool
// skipRequestBody makes DecodeData pass the request body through untouched.
skipRequestBody bool
// skipResponseBody makes EncodeData pass the response body through untouched.
skipResponseBody bool
// cachedResponseBody is not referenced by the code visible in this part of the file.
cachedResponseBody []byte
// userLevelConfig is set for non-GET user-level config requests, whose body is
// handed to mcpConfigHandler in DecodeData.
userLevelConfig bool
mcpConfigHandler *handler.MCPConfigHandler
// ratelimit is set for forwarded requests when user-level servers are enabled;
// DecodeData then runs the rate-limit checks.
ratelimit bool
mcpRatelimitHandler *handler.MCPRatelimitHandler
}
// DecodeHeaders is called on the request path once the request headers arrive.
// endStream is true if the request has no body.
//
// Requests that cannot be parsed or that match no rule in match_list are passed
// through. For a matching request the rule and a minimal http.Request are
// recorded. User-level config requests (path ending in ConfigPathSuffix while
// user-level servers are enabled) are handled here; every other request is
// dispatched by upstream type.
func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.StatusType {
	reqURL := common.NewRequestURL(header)
	if reqURL == nil {
		return api.Continue
	}
	f.path = reqURL.ParsedURL.Path

	// Only requests matching a rule in match_list are processed.
	ok, rule := common.IsMatch(f.config.matchList, reqURL.Host, f.path)
	if !ok {
		api.LogDebugf("Request does not match any rule in match_list: %s", reqURL.ParsedURL.String())
		return api.Continue
	}

	f.needProcess = true
	f.matchedRule = rule
	f.req = &http.Request{
		Method: reqURL.Method,
		URL:    reqURL.ParsedURL,
	}

	isConfigRequest := strings.HasSuffix(f.path, ConfigPathSuffix) && f.config.enableUserLevelServer
	if !isConfigRequest {
		return f.processMcpRequestHeaders(header, endStream)
	}

	// The config endpoint is reachable from internal IP addresses only.
	if !reqURL.InternalIP {
		api.LogWarnf("Access denied: non-Internal IP address %s", reqURL.ParsedURL.String())
		f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
		return api.LocalReply
	}

	// A GET carries no body, so it can be answered right away.
	if reqURL.Method == http.MethodGet {
		api.LogDebugf("Handling config request: %s", f.path)
		f.mcpConfigHandler.HandleConfigRequest(f.req, []byte{})
		return api.LocalReply
	}

	// Any other method is handled in DecodeData once the whole body is buffered.
	f.userLevelConfig = true
	if endStream {
		return api.Continue
	}
	return api.StopAndBuffer
}
// processMcpRequestHeaders dispatches header processing by the upstream type of
// the matched rule. For an upstream type that is neither REST, streamable nor
// SSE, processing is switched off for the rest of the request.
func (f *filter) processMcpRequestHeaders(header api.RequestHeaderMap, endStream bool) api.StatusType {
	upstream := f.matchedRule.UpstreamType
	if upstream == common.RestUpstream || upstream == common.StreamableUpstream {
		return f.processMcpRequestHeadersForRestUpstream(header, endStream)
	}
	if upstream == common.SSEUpstream {
		return f.processMcpRequestHeadersForSSEUpstream(header, endStream)
	}
	f.needProcess = false
	return api.Continue
}
// processMcpRequestHeadersForRestUpstream handles the request headers for rules
// whose upstream type is REST or streamable HTTP.
//
// A request whose path does not end with GlobalSSEPathSuffix is forwarded
// upstream: its URL is remembered in f.proxyURL and, when user-level servers
// are enabled, the encoded config for the "/<serverName>/<uid>/..." path is
// attached as the x-higress-mcpserver-config header and rate limiting is
// switched on. The request body, if any, is buffered for DecodeData.
//
// A request to the global SSE endpoint is answered locally: any method other
// than GET gets a 405, a GET creates the default SSE server and gets a 200.
func (f *filter) processMcpRequestHeadersForRestUpstream(header api.RequestHeaderMap, endStream bool) api.StatusType {
	method := f.req.Method
	requestUrl := f.req.URL

	if !strings.HasSuffix(requestUrl.Path, GlobalSSEPathSuffix) {
		f.proxyURL = requestUrl
		if f.config.enableUserLevelServer {
			// The path splits into "", serverName, uid, ...
			parts := strings.Split(requestUrl.Path, "/")
			if len(parts) >= 3 {
				serverName := parts[1]
				uid := parts[2]
				encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
				if err != nil {
					// Attaching the config is best effort at this stage, so the
					// request is not rejected here. The failure is logged instead
					// of being dropped silently, because DecodeData (which repeats
					// the lookup) is never called for requests without a body.
					api.LogDebugf("Failed to get encoded config for %s:%s: %v", serverName, uid, err)
				}
				if encodedConfig != "" {
					header.Set("x-higress-mcpserver-config", encodedConfig)
					api.LogDebugf("Set x-higress-mcpserver-config Header for %s:%s", serverName, uid)
				}
			}
			f.ratelimit = true
		}
		if endStream {
			return api.Continue
		}
		return api.StopAndBuffer
	}

	if method != http.MethodGet {
		f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
		return api.LocalReply
	}

	// NOTE(review): this stores a per-request server in f.config. If the config
	// object is shared between concurrent requests this is racy — confirm.
	f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
		common.WithSSEEndpoint(GlobalSSEPathSuffix),
		common.WithMessageEndpoint(strings.TrimSuffix(requestUrl.Path, GlobalSSEPathSuffix)),
		common.WithRedisClient(f.config.redisClient))
	f.serverName = f.config.defaultServer.GetServerName()
	body := "SSE connection create"
	f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
	return api.LocalReply
}
// processMcpRequestHeadersForSSEUpstream handles the request headers for rules
// whose upstream type is SSE: the request path is rewritten if the rule asks
// for it, and the request body is marked to be passed through unprocessed.
func (f *filter) processMcpRequestHeadersForSSEUpstream(header api.RequestHeaderMap, endStream bool) api.StatusType {
	f.rewritePathForSSEUpstream(header)
	// The request body of an SSE upstream never needs processing.
	f.skipRequestBody = true
	return api.Continue
}
// rewritePathForSSEUpstream replaces the matched path prefix of the request
// with the rule's PathRewritePrefix and writes the result, including the
// original query string, into the :path header.
//
// Rewriting only applies to prefix-match rules with path rewrite enabled. For
// any other rule the response body is marked as not needing processing.
func (f *filter) rewritePathForSSEUpstream(header api.RequestHeaderMap) {
	rule := f.matchedRule
	if !(rule.EnablePathRewrite && rule.MatchRuleType == common.PrefixMatch) {
		// Nothing is rewritten, so the response body needs no processing either.
		f.skipResponseBody = true
		return
	}

	originalPath := f.req.URL.Path
	if !strings.HasPrefix(originalPath, rule.MatchRulePath) {
		api.LogWarnf("Unexpected: Path %s does not match the configured prefix %s", originalPath, rule.MatchRulePath)
		return
	}

	target := rule.PathRewritePrefix
	if remainder := strings.TrimPrefix(originalPath, rule.MatchRulePath); remainder != "" {
		prefixEndsWithSlash := strings.HasSuffix(rule.PathRewritePrefix, "/")
		remainderStartsWithSlash := strings.HasPrefix(remainder, "/")
		// Join the two parts with exactly one slash between them.
		switch {
		case prefixEndsWithSlash && remainderStartsWithSlash:
			target += remainder[1:]
		case !prefixEndsWithSlash && !remainderStartsWithSlash:
			target += "/" + remainder
		default:
			target += remainder
		}
	}

	if query := f.req.URL.RawQuery; query != "" {
		target += "?" + query
	}
	header.SetPath(target)
}
// DecodeData may be called several times while the request body is received.
// endStream is true for the last piece of the body.
//
// Only REST and streamable upstreams are processed, and only once the whole
// body has been buffered. A user-level config request is handed to the config
// handler. Otherwise, when rate limiting is on, a "tools/call" request for
// which no encoded config is stored goes through the rate-limit handler.
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
	if !f.needProcess || f.skipRequestBody {
		return api.Continue
	}
	switch f.matchedRule.UpstreamType {
	case common.RestUpstream, common.StreamableUpstream:
		// Processed below.
	default:
		return api.Continue
	}
	if !endStream {
		return api.StopAndBuffer
	}

	if f.userLevelConfig {
		// Config request with a body (e.g. POST).
		api.LogDebugf("Handling config request: %s", f.path)
		f.mcpConfigHandler.HandleConfigRequest(f.req, buffer.Bytes())
		return api.LocalReply
	}
	if !f.ratelimit {
		return api.Continue
	}

	body := buffer.Bytes()
	if checkJSONRPCMethod(body, "tools/list") {
		api.LogDebugf("Not a tools call request, skipping ratelimit")
		return api.Continue
	}

	// The path splits into "", serverName, uid, ...
	segments := strings.Split(f.req.URL.Path, "/")
	if len(segments) < 3 {
		api.LogWarnf("Access denied: no valid uid found")
		f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
		return api.LocalReply
	}
	serverName, uid := segments[1], segments[2]

	encodedConfig, err := f.mcpConfigHandler.GetEncodedConfig(serverName, uid)
	if err != nil {
		api.LogWarnf("Access denied: no valid config found for uid %s", uid)
		f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusForbidden, "", nil, 0, "")
		return api.LocalReply
	}
	if encodedConfig != "" || !checkJSONRPCMethod(body, "tools/call") {
		return api.Continue
	}

	api.LogDebugf("Empty config found for %s:%s", serverName, uid)
	if !f.mcpRatelimitHandler.HandleRatelimit(f.req, body) {
		return api.LocalReply
	}
	return api.Continue
}
// EncodeHeaders is called on the response path once the response headers
// arrive. endStream is true if the response has no body.
//
// For upstreams other than REST/streamable, the response body is only
// processed when the content type is text/event-stream. For a request served
// by the default SSE server, the headers are rewritten either for an event
// stream (Redis configured) or for the fixed "Redis is not enabled" body.
func (f *filter) EncodeHeaders(header api.ResponseHeaderMap, endStream bool) api.StatusType {
	if !f.needProcess {
		return api.Continue
	}

	switch f.matchedRule.UpstreamType {
	case common.RestUpstream, common.StreamableUpstream:
		// Handled below.
	default:
		contentType, found := header.Get("content-type")
		if !found || !strings.HasPrefix(contentType, "text/event-stream") {
			api.LogDebugf("Skip response body for non-SSE upstream. Content-Type: %s", contentType)
			f.skipResponseBody = true
		}
		return api.Continue
	}

	if f.serverName == "" {
		return api.Continue
	}
	if f.config.redisClient == nil {
		header.Set("Content-Length", strconv.Itoa(len(RedisNotEnabledResponseBody)))
		return api.Continue
	}

	header.Set("Content-Type", "text/event-stream")
	header.Set("Cache-Control", "no-cache")
	header.Set("Connection", "keep-alive")
	header.Set("Access-Control-Allow-Origin", "*")
	header.Del("Content-Length")
	return api.Continue
}
// EncodeData may be called several times while the response body is received.
// endStream is true for the last piece of the body.
//
// The work is delegated by upstream type. For an SSE upstream the result is
// forced to api.Continue once the stream has ended.
func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
	if !f.needProcess || f.skipResponseBody {
		return api.Continue
	}

	api.LogDebugf("Upstream Type: %s", f.matchedRule.UpstreamType)
	switch f.matchedRule.UpstreamType {
	case common.RestUpstream, common.StreamableUpstream:
		api.LogDebugf("Encoding data from Rest upstream")
		return f.encodeDataFromRestUpstream(buffer, endStream)
	case common.SSEUpstream:
		api.LogDebugf("Encoding data from SSE upstream")
		status := f.encodeDataFromSSEUpstream(buffer, endStream)
		if endStream {
			// Nothing more will arrive, so never hold the data back.
			return api.Continue
		}
		return status
	default:
		return api.Continue
	}
}
// encodeDataFromRestUpstream processes the fully buffered response body of a
// REST or streamable upstream.
//
// If the forwarded request carried a "sessionId" query parameter and Redis is
// configured, the body is published as an SSE "message" event on that
// session's channel. If the request is served by the default SSE server, the
// body is replaced: with Redis configured the buffer is emptied and the SSE
// handler takes over (api.Running); without Redis the fixed "Redis is not
// enabled" body is returned.
func (f *filter) encodeDataFromRestUpstream(buffer api.BufferInstance, endStream bool) api.StatusType {
	if !f.needProcess {
		return api.Continue
	}
	// Wait until the whole body has been buffered.
	if !endStream {
		return api.StopAndBuffer
	}

	if f.proxyURL != nil && f.config.redisClient != nil {
		if sessionID := f.proxyURL.Query().Get("sessionId"); sessionID != "" {
			message := fmt.Sprintf("event: message\ndata: %s\n\n", buffer.String())
			if err := f.config.redisClient.Publish(common.GetSSEChannelName(sessionID), message); err != nil {
				api.LogErrorf("Failed to publish wasm mcp server message to Redis: %v", err)
			}
		}
	}

	if f.serverName == "" {
		return api.Continue
	}
	if f.config.redisClient == nil {
		_ = buffer.SetString(RedisNotEnabledResponseBody)
		return api.Continue
	}

	// Default server: drop the placeholder body and start streaming events.
	buffer.Reset()
	f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
	return api.Running
}
// encodeDataFromSSEUpstream looks for the initial "endpoint" event in the SSE
// data received so far and rewrites the endpoint URL it announces, so that it
// matches the path the client used.
//
// It returns api.StopAndBuffer while the data is still incomplete (no line
// break yet, or no complete endpoint event). Once the endpoint has been
// handled, or the data turns out to be unusable, f.needProcess is cleared so
// that the rest of the stream passes through untouched.
func (f *filter) encodeDataFromSSEUpstream(buffer api.BufferInstance, endStream bool) api.StatusType {
	payload := string(buffer.Bytes())

	err, lineBreak := f.findSSELineBreak(payload)
	if err != nil {
		api.LogWarnf("Failed to find line break in SSE data: %v", err)
		f.needProcess = false
		return api.Continue
	}
	if lineBreak == "" {
		// No line break received yet; buffer and try again with more data.
		return api.StopAndBuffer
	}
	api.LogDebugf("Line break sequence: %v", []byte(lineBreak))

	err, endpointUrl := f.findEndpointUrl(payload, lineBreak)
	if err != nil {
		api.LogWarnf("Failed to find endpoint URL in SSE data: %v", err)
		f.needProcess = false
		return api.Continue
	}
	if endpointUrl == "" {
		// The endpoint event is not complete yet; buffer and try again.
		return api.StopAndBuffer
	}

	// The query string is left as it is, so only the part before it is rewritten.
	if queryStart := strings.IndexByte(endpointUrl, '?'); queryStart != -1 {
		endpointUrl = endpointUrl[:queryStart]
	}

	changed, newEndpointUrl := f.rewriteEndpointUrl(endpointUrl)
	if !changed {
		api.LogDebugf("The endpoint URL %s is not changed", endpointUrl)
	} else {
		api.LogDebugf("The endpoint URL is changed.\n Old: %s\n New: %s", endpointUrl, newEndpointUrl)
		if pos := strings.Index(payload, endpointUrl); pos == -1 {
			api.LogWarnf("Something wrong, the previously found endpoint URL %s not found in the SSE data now", endpointUrl)
		} else {
			payload = payload[:pos] + newEndpointUrl + payload[pos+len(endpointUrl):]
			_ = buffer.SetString(payload)
		}
	}

	f.needProcess = false
	return api.Continue
}
// rewriteEndpointUrl maps an endpoint URL announced by the upstream back onto
// the matched rule's path: the rule's PathRewritePrefix at the start of the
// URL path is replaced with MatchRulePath. A scheme and host, if present, are
// dropped first, so the result is always a path.
//
// It returns (false, "") when path rewrite is disabled for the rule or the URL
// path does not start with PathRewritePrefix; otherwise (true, rewritten path).
func (f *filter) rewriteEndpointUrl(endpointUrl string) (bool, string) {
	rule := f.matchedRule
	if !rule.EnablePathRewrite {
		return false, ""
	}

	// Reduce an absolute URL ("scheme://host/path") to its path.
	urlPath := endpointUrl
	if schemeEnd := strings.Index(urlPath, "://"); schemeEnd != -1 {
		hostAndPath := urlPath[schemeEnd+3:]
		if pathStart := strings.Index(hostAndPath, "/"); pathStart != -1 {
			urlPath = hostAndPath[pathStart:]
		} else {
			urlPath = "/"
		}
	}

	if !strings.HasPrefix(urlPath, rule.PathRewritePrefix) {
		// Without the rewrite prefix there is no way to map the URL back.
		api.LogWarnf("The endpoint URL %s does not match the path rewrite prefix %s", urlPath, rule.PathRewritePrefix)
		return false, ""
	}

	suffix := urlPath[len(rule.PathRewritePrefix):]
	if suffix == "" {
		return true, rule.MatchRulePath
	}

	// Join the two parts with exactly one slash between them.
	matchPathEndsWithSlash := strings.HasSuffix(rule.MatchRulePath, "/")
	suffixStartsWithSlash := strings.HasPrefix(suffix, "/")
	switch {
	case matchPathEndsWithSlash && suffixStartsWithSlash:
		return true, rule.MatchRulePath + suffix[1:]
	case !matchPathEndsWithSlash && !suffixStartsWithSlash:
		return true, rule.MatchRulePath + "/" + suffix
	default:
		return true, rule.MatchRulePath + suffix
	}
}
// findSSELineBreak detects which line break sequence the SSE data uses, based
// on the first CR and the first LF in bufferData.
//
// It returns "\r\n", "\r" or "\n", or "" when the data contains no line break
// yet. An error is returned when both a CR and an LF are present but the first
// CR is not immediately followed by the first LF.
//
// See https://html.spec.whatwg.org/multipage/server-sent-events.html
func (f *filter) findSSELineBreak(bufferData string) (error, string) {
	firstCR := strings.IndexByte(bufferData, '\r')
	firstLF := strings.IndexByte(bufferData, '\n')

	switch {
	case firstCR == -1 && firstLF == -1:
		// No line break received yet.
		return nil, ""
	case firstCR == -1:
		return nil, "\n"
	case firstLF == -1:
		return nil, "\r"
	case firstCR+1 == firstLF:
		return nil, "\r\n"
	default:
		// Both are present but not adjacent; the body cannot be processed.
		return errors.New("found non-adjacent CR and LF"), ""
	}
}
// findEndpointUrl extracts the URL announced by the initial "endpoint" event in
// bufferData, whose lines are separated by lineBreak.
//
// It returns "" with a nil error when the event or its data line has not been
// received completely yet. An error is returned when the first event is not
// named "endpoint" or when the line following the event line is not a "data:"
// field.
func (f *filter) findEndpointUrl(bufferData, lineBreak string) (error, string) {
	const (
		eventField = "event:"
		dataField  = "data:"
	)

	eventStart := strings.Index(bufferData, eventField)
	if eventStart == -1 {
		return nil, ""
	}
	rest := bufferData[eventStart:]

	eventLineEnd := strings.Index(rest, lineBreak)
	if eventLineEnd == -1 {
		return nil, ""
	}
	if name := strings.TrimSpace(rest[len(eventField):eventLineEnd]); name != "endpoint" {
		return fmt.Errorf("the initial event [%s] is not an endpoint event. Skip processing", name), ""
	}

	rest = rest[eventLineEnd+len(lineBreak):]
	dataLineEnd := strings.Index(rest, lineBreak)
	if dataLineEnd == -1 {
		// The data line has not been received completely.
		return nil, ""
	}

	dataLine := rest[:dataLineEnd]
	if !strings.HasPrefix(dataLine, dataField) {
		return fmt.Errorf("an unexpected non-data field found in the event. Skip processing. Field: %s", dataLine), ""
	}
	return nil, strings.TrimSpace(dataLine[len(dataField):])
}
// OnDestroy is called when the filter is destroyed. For a request served by
// the default SSE server it closes stopChan, unless the channel is already
// closed.
func (f *filter) OnDestroy(reason api.DestroyReason) {
	api.LogDebugf("OnDestroy: reason=%v", reason)
	if f.serverName == "" || f.stopChan == nil {
		return
	}

	select {
	case <-f.stopChan:
		// A receive succeeded, so there is nothing left to close.
	default:
		api.LogDebug("Stopping SSE connection")
		close(f.stopChan)
	}
}
// checkJSONRPCMethod reports whether body is a JSON-RPC request whose method
// equals the given method. A body that cannot be unmarshalled is logged and
// reported as a match (true), whatever method was asked for.
func checkJSONRPCMethod(body []byte, method string) bool {
	var req mcp.CallToolRequest
	err := json.Unmarshal(body, &req)
	if err == nil {
		return req.Method == method
	}
	api.LogWarnf("Failed to unmarshal request body: %v, not a JSON RPC request", err)
	return true
}