plugins/wasm-go/extensions/ai-proxy/main.go (300 lines of code) (raw):

// File generated by hgctl. Modify as required. // See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6 package main import ( "fmt" "net/url" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/config" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/log" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) const ( pluginName = "ai-proxy" defaultMaxBodyBytes uint32 = 100 * 1024 * 1024 ) func main() { wrapper.SetCtx( pluginName, wrapper.ParseOverrideConfig(parseGlobalConfig, parseOverrideRuleConfig), wrapper.ProcessRequestHeaders(onHttpRequestHeader), wrapper.ProcessRequestBody(onHttpRequestBody), wrapper.ProcessResponseHeaders(onHttpResponseHeaders), wrapper.ProcessStreamingResponseBody(onStreamingResponseBody), wrapper.ProcessResponseBody(onHttpResponseBody), ) } func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig) error { log.Debugf("loading global config: %s", json.String()) pluginConfig.FromJson(json) if err := pluginConfig.Validate(); err != nil { log.Errorf("global rule config is invalid: %v", err) return err } if err := pluginConfig.Complete(); err != nil { log.Errorf("failed to apply global rule config: %v", err) return err } return nil } func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, pluginConfig *config.PluginConfig) error { log.Debugf("loading override rule config: %s", json.String()) *pluginConfig = global pluginConfig.FromJson(json) if err := pluginConfig.Validate(); err != nil { log.Errorf("overriden rule config is invalid: %v", err) return err } if err := pluginConfig.Complete(); err != nil { log.Errorf("failed to apply overriden rule config: %v", err) return err } return nil } func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action { activeProvider := pluginConfig.GetProvider() if activeProvider == nil { log.Debugf("[onHttpRequestHeader] no active provider, skip processing") ctx.DontReadRequestBody() return types.ActionContinue } log.Debugf("[onHttpRequestHeader] provider=%s", activeProvider.GetProviderType()) rawPath := ctx.Path() path, _ := url.Parse(rawPath) apiName := getApiName(path.Path) providerConfig := pluginConfig.GetProviderConfig() if providerConfig.IsOriginal() { if handler, ok := activeProvider.(provider.ApiNameHandler); ok { apiName = handler.GetApiName(path.Path) } } if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) { ctx.DontReadRequestBody() log.Debugf("[onHttpRequestHeader] unsupported content type: %s, will not process the request body", contentType) } if apiName == "" { ctx.DontReadRequestBody() ctx.DontReadResponseBody() log.Warnf("[onHttpRequestHeader] unsupported path: %s, will not process http path and body", path.Path) } ctx.SetContext(provider.CtxKeyApiName, apiName) // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. ctx.DisableReroute() // Always remove the Accept-Encoding header to prevent the LLM from sending compressed responses, // allowing plugins to inspect or modify the response correctly _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok { // Set the apiToken for the current request. providerConfig.SetApiTokenInUse(ctx) // Set available apiTokens of current request in the context, will be used in the retryOnFailure providerConfig.SetAvailableApiTokens(ctx) // save the original request host and path in case they are needed for apiToken health check and retry ctx.SetContext(provider.CtxRequestHost, wrapper.GetRequestHost()) ctx.SetContext(provider.CtxRequestPath, wrapper.GetRequestPath()) err := handler.OnRequestHeaders(ctx, apiName) if err != nil { _ = util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err)) return types.ActionContinue } hasRequestBody := wrapper.HasRequestBody() if hasRequestBody { _ = proxywasm.RemoveHttpRequestHeader("Content-Length") ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) // Delay the header processing to allow changing in OnRequestBody return types.HeaderStopIteration } ctx.DontReadRequestBody() return types.ActionContinue } return types.ActionContinue } func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte) types.Action { activeProvider := pluginConfig.GetProvider() if activeProvider == nil { log.Debugf("[onHttpRequestBody] no active provider, skip processing") return types.ActionContinue } log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType()) if handler, ok := activeProvider.(provider.RequestBodyHandler); ok { apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) providerConfig := pluginConfig.GetProviderConfig() // If retryOnFailure is enabled, save the transformed body to the context in case of retry if providerConfig.IsRetryOnFailureEnabled() { ctx.SetContext(provider.CtxRequestBody, body) } newBody, settingErr := providerConfig.ReplaceByCustomSettings(body) if settingErr != nil { log.Errorf("failed to replace request body by custom settings: %v", settingErr) } if providerConfig.IsOpenAIProtocol() { newBody = normalizeOpenAiRequestBody(newBody) } log.Debugf("[onHttpRequestBody] newBody=%s", newBody) body = newBody action, err := handler.OnRequestBody(ctx, apiName, body) if err == nil { return action } _ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err)) } return types.ActionContinue } func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action { if !wrapper.IsResponseFromUpstream() { // Response is not coming from the upstream. Let it pass through. ctx.DontReadResponseBody() return types.ActionContinue } activeProvider := pluginConfig.GetProvider() if activeProvider == nil { log.Debugf("[onHttpResponseHeaders] no active provider, skip processing") ctx.DontReadResponseBody() return types.ActionContinue } log.Debugf("[onHttpResponseHeaders] provider=%s", activeProvider.GetProviderType()) providerConfig := pluginConfig.GetProviderConfig() apiTokenInUse := providerConfig.GetApiTokenInUse(ctx) apiTokens := providerConfig.GetAvailableApiToken(ctx) status, err := proxywasm.GetHttpResponseHeader(":status") if err != nil || status != "200" { if err != nil { log.Errorf("unable to load :status header from response: %v", err) } ctx.DontReadResponseBody() return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, apiTokens, status) } // Reset ctxApiTokenRequestFailureCount if the request is successful, // the apiToken is removed only when the number of consecutive request failures exceeds the threshold. providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse) headers := util.GetOriginalResponseHeaders() if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok { apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) handler.TransformResponseHeaders(ctx, apiName, headers) } else { providerConfig.DefaultTransformResponseHeaders(ctx, headers) } util.ReplaceResponseHeaders(headers) checkStream(ctx) _, needHandleBody := activeProvider.(provider.TransformResponseBodyHandler) var needHandleStreamingBody bool _, needHandleStreamingBody = activeProvider.(provider.StreamingResponseBodyHandler) if !needHandleStreamingBody { _, needHandleStreamingBody = activeProvider.(provider.StreamingEventHandler) } if !needHandleBody && !needHandleStreamingBody { ctx.DontReadResponseBody() } else if !needHandleStreamingBody { ctx.BufferResponseBody() } return types.ActionContinue } func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, chunk []byte, isLastChunk bool) []byte { activeProvider := pluginConfig.GetProvider() if activeProvider == nil { log.Debugf("[onStreamingResponseBody] no active provider, skip processing") return chunk } log.Debugf("[onStreamingResponseBody] provider=%s", activeProvider.GetProviderType()) log.Debugf("[onStreamingResponseBody] isLastChunk=%v chunk: %s", isLastChunk, string(chunk)) if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok { apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk) if err == nil && modifiedChunk != nil { return modifiedChunk } return chunk } if handler, ok := activeProvider.(provider.StreamingEventHandler); ok { apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) events := provider.ExtractStreamingEvents(ctx, chunk) log.Debugf("[onStreamingResponseBody] %d events received", len(events)) if len(events) == 0 { // No events are extracted, return the original chunk return chunk } var responseBuilder strings.Builder for _, event := range events { log.Debugf("processing event: %v", event) if event.IsEndData() { responseBuilder.WriteString(event.ToHttpString()) continue } outputEvents, err := handler.OnStreamingEvent(ctx, apiName, event) if err != nil { log.Errorf("[onStreamingResponseBody] failed to process streaming event: %v\n%s", err, chunk) return chunk } if outputEvents == nil || len(outputEvents) == 0 { responseBuilder.WriteString(event.ToHttpString()) } else { for _, outputEvent := range outputEvents { responseBuilder.WriteString(outputEvent.ToHttpString()) } } } return []byte(responseBuilder.String()) } return chunk } func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte) types.Action { activeProvider := pluginConfig.GetProvider() if activeProvider == nil { log.Debugf("[onHttpResponseBody] no active provider, skip processing") return types.ActionContinue } log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType()) if handler, ok := activeProvider.(provider.TransformResponseBodyHandler); ok { apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) body, err := handler.TransformResponseBody(ctx, apiName, body) if err != nil { _ = util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err)) return types.ActionContinue } if err = provider.ReplaceResponseBody(body); err != nil { _ = util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err)) } } return types.ActionContinue } func normalizeOpenAiRequestBody(body []byte) []byte { var err error // Default setting include_usage. if gjson.GetBytes(body, "stream").Bool() { body, err = sjson.SetBytes(body, "stream_options.include_usage", true) if err != nil { log.Errorf("set include_usage failed, err:%s", err) } } return body } func checkStream(ctx wrapper.HttpContext) { contentType, err := proxywasm.GetHttpResponseHeader("Content-Type") if err != nil || !strings.HasPrefix(contentType, "text/event-stream") { if err != nil { log.Errorf("unable to load content-type header from response: %v", err) } ctx.BufferResponseBody() ctx.SetResponseBodyBufferLimit(defaultMaxBodyBytes) } } func getApiName(path string) provider.ApiName { // openai style if strings.HasSuffix(path, "/v1/chat/completions") { return provider.ApiNameChatCompletion } if strings.HasSuffix(path, "/v1/completions") { return provider.ApiNameCompletion } if strings.HasSuffix(path, "/v1/embeddings") { return provider.ApiNameEmbeddings } if strings.HasSuffix(path, "/v1/audio/speech") { return provider.ApiNameAudioSpeech } if strings.HasSuffix(path, "/v1/images/generations") { return provider.ApiNameImageGeneration } if strings.HasSuffix(path, "/v1/batches") { return provider.ApiNameBatches } if strings.HasSuffix(path, "/v1/files") { return provider.ApiNameFiles } if strings.HasSuffix(path, "/v1/models") { return provider.ApiNameModels } // cohere style if strings.HasSuffix(path, "/v1/rerank") { return provider.ApiNameCohereV1Rerank } return "" }