plugins/wasm-go/extensions/ai-proxy/main.go (300 lines of code) (raw):
// File generated by hgctl. Modify as required.
// See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6
package main
import (
"fmt"
"net/url"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
	// pluginName is the name under which the plugin is registered via wrapper.SetCtx in main.
	pluginName = "ai-proxy"
	// defaultMaxBodyBytes (100 MiB) is the buffer limit applied to request bodies in
	// onHttpRequestHeader and to non-streaming response bodies in checkStream.
	defaultMaxBodyBytes uint32 = 100 << 20
)
// main registers the ai-proxy plugin with the wrapper framework: the global/override config
// parsers plus the request-header, request-body, response-header, streaming-response-body and
// buffered-response-body callbacks defined in this file.
func main() {
	wrapper.SetCtx(
		pluginName,
		// Configuration: global rules and per-rule overrides.
		wrapper.ParseOverrideConfig(parseGlobalConfig, parseOverrideRuleConfig),
		// Request phase.
		wrapper.ProcessRequestHeaders(onHttpRequestHeader),
		wrapper.ProcessRequestBody(onHttpRequestBody),
		// Response phase.
		wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
		wrapper.ProcessStreamingResponseBody(onStreamingResponseBody),
		wrapper.ProcessResponseBody(onHttpResponseBody),
	)
}
// parseGlobalConfig loads the plugin-wide configuration from json into pluginConfig and then
// runs Validate followed by Complete. An error from either step is logged and returned, and
// Complete is not run when Validate fails.
func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig) error {
	log.Debugf("loading global config: %s", json.String())
	pluginConfig.FromJson(json)

	err := pluginConfig.Validate()
	if err != nil {
		log.Errorf("global rule config is invalid: %v", err)
		return err
	}

	err = pluginConfig.Complete()
	if err != nil {
		log.Errorf("failed to apply global rule config: %v", err)
		return err
	}
	return nil
}
// parseOverrideRuleConfig builds the configuration of an overriding rule. pluginConfig starts
// as a copy of global, then the rule's json is loaded via FromJson, and the result is run
// through Validate and Complete. An error from either step is logged and returned.
//
// NOTE(review): `*pluginConfig = global` is a shallow struct copy. Any slice/map/pointer fields
// of config.PluginConfig would be shared with the global config; confirm FromJson/Complete
// never mutate such shared state in place.
func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, pluginConfig *config.PluginConfig) error {
	log.Debugf("loading override rule config: %s", json.String())

	*pluginConfig = global
	pluginConfig.FromJson(json)

	if validateErr := pluginConfig.Validate(); validateErr != nil {
		log.Errorf("overriden rule config is invalid: %v", validateErr)
		return validateErr
	}
	if completeErr := pluginConfig.Complete(); completeErr != nil {
		log.Errorf("failed to apply overriden rule config: %v", completeErr)
		return completeErr
	}
	return nil
}
// onHttpRequestHeader is the request-headers hook.
//
// It does the following:
//   - resolves the API name from the request path and stores it under provider.CtxKeyApiName;
//   - disables route re-calculation;
//   - removes Accept-Encoding;
//   - when the active provider implements provider.RequestHeadersHandler, selects the API
//     token(s) for this request and lets the provider rewrite the request headers.
//
// If the request carries a body and the provider handles headers, it removes Content-Length,
// sets the body buffer limit and returns HeaderStopIteration. This delays header processing so
// the headers can still be changed in onHttpRequestBody.
func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action {
	activeProvider := pluginConfig.GetProvider()
	if activeProvider == nil {
		log.Debugf("[onHttpRequestHeader] no active provider, skip processing")
		ctx.DontReadRequestBody()
		return types.ActionContinue
	}
	log.Debugf("[onHttpRequestHeader] provider=%s", activeProvider.GetProviderType())

	rawPath := ctx.Path()
	requestPath := rawPath
	if parsedURL, err := url.Parse(rawPath); err == nil {
		requestPath = parsedURL.Path
	} else {
		// url.Parse returns a nil *url.URL on failure (e.g. an invalid percent-escape), so
		// dereferencing it would panic. Fall back to the raw path without its query string.
		log.Warnf("[onHttpRequestHeader] failed to parse request path %q: %v", rawPath, err)
		requestPath, _, _ = strings.Cut(rawPath, "?")
	}

	apiName := getApiName(requestPath)
	providerConfig := pluginConfig.GetProviderConfig()
	if providerConfig.IsOriginal() {
		// In "original" protocol mode the provider may map its own native paths to an API name.
		if nameHandler, ok := activeProvider.(provider.ApiNameHandler); ok {
			apiName = nameHandler.GetApiName(requestPath)
		}
	}

	if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) {
		ctx.DontReadRequestBody()
		log.Debugf("[onHttpRequestHeader] unsupported content type: %s, will not process the request body", contentType)
	}

	if apiName == "" {
		ctx.DontReadRequestBody()
		ctx.DontReadResponseBody()
		log.Warnf("[onHttpRequestHeader] unsupported path: %s, will not process http path and body", requestPath)
	}

	ctx.SetContext(provider.CtxKeyApiName, apiName)
	// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
	ctx.DisableReroute()
	// Always remove the Accept-Encoding header to prevent the LLM from sending compressed responses,
	// allowing plugins to inspect or modify the response correctly.
	_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")

	headersHandler, ok := activeProvider.(provider.RequestHeadersHandler)
	if !ok {
		return types.ActionContinue
	}

	// Set the apiToken for the current request.
	providerConfig.SetApiTokenInUse(ctx)
	// Set available apiTokens of the current request in the context; used by retryOnFailure.
	providerConfig.SetAvailableApiTokens(ctx)
	// Save the original request host and path in case they are needed for apiToken health check and retry.
	ctx.SetContext(provider.CtxRequestHost, wrapper.GetRequestHost())
	ctx.SetContext(provider.CtxRequestPath, wrapper.GetRequestPath())

	if err := headersHandler.OnRequestHeaders(ctx, apiName); err != nil {
		_ = util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
		return types.ActionContinue
	}

	if wrapper.HasRequestBody() {
		// Content-Length is dropped here, presumably because the body may be rewritten later.
		_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
		ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
		// Delay the header processing to allow changing in OnRequestBody.
		return types.HeaderStopIteration
	}
	ctx.DontReadRequestBody()
	return types.ActionContinue
}
// onHttpRequestBody is the request-body hook. It runs only when the active provider implements
// provider.RequestBodyHandler.
//
// Processing steps:
//  1. If retryOnFailure is enabled, the incoming body is saved under provider.CtxRequestBody.
//  2. The provider's custom settings are applied to the body.
//  3. For OpenAI-protocol configs, the body is normalized.
//  4. The result is passed to the provider's OnRequestBody.
//
// Errors from OnRequestBody are reported via util.ErrorHandler and the request continues.
func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte) types.Action {
	activeProvider := pluginConfig.GetProvider()
	if activeProvider == nil {
		log.Debugf("[onHttpRequestBody] no active provider, skip processing")
		return types.ActionContinue
	}
	log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType())

	handler, ok := activeProvider.(provider.RequestBodyHandler)
	if !ok {
		return types.ActionContinue
	}

	apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
	providerConfig := pluginConfig.GetProviderConfig()
	// If retryOnFailure is enabled, save the body as received by this hook (before custom
	// settings and protocol normalization are applied) in case of retry.
	if providerConfig.IsRetryOnFailureEnabled() {
		ctx.SetContext(provider.CtxRequestBody, body)
	}

	newBody, settingErr := providerConfig.ReplaceByCustomSettings(body)
	if settingErr != nil {
		// The returned body is not meaningful when an error is returned (it may be empty or
		// partially rewritten), so keep forwarding the original body instead.
		log.Errorf("failed to replace request body by custom settings: %v", settingErr)
		newBody = body
	}
	if providerConfig.IsOpenAIProtocol() {
		newBody = normalizeOpenAiRequestBody(newBody)
	}
	log.Debugf("[onHttpRequestBody] newBody=%s", newBody)

	action, err := handler.OnRequestBody(ctx, apiName, newBody)
	if err == nil {
		return action
	}
	_ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
	return types.ActionContinue
}
// onHttpResponseHeaders is the response-headers hook.
//
// Responses that did not come from the upstream, or arrive while no provider is active, are
// passed through without reading the body.
//
// A non-200 status (or an unreadable :status) is handed to providerConfig.OnRequestFailed,
// which decides the action. A 200 response resets the failure counter of the API token in use,
// has its headers transformed (by the provider if it implements
// TransformResponseHeadersHandler, otherwise by the default transform), and then selects how
// the body will be read.
func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action {
	// Response is not coming from the upstream. Let it pass through.
	if !wrapper.IsResponseFromUpstream() {
		ctx.DontReadResponseBody()
		return types.ActionContinue
	}

	activeProvider := pluginConfig.GetProvider()
	if activeProvider == nil {
		log.Debugf("[onHttpResponseHeaders] no active provider, skip processing")
		ctx.DontReadResponseBody()
		return types.ActionContinue
	}
	log.Debugf("[onHttpResponseHeaders] provider=%s", activeProvider.GetProviderType())

	providerConfig := pluginConfig.GetProviderConfig()
	tokenInUse := providerConfig.GetApiTokenInUse(ctx)
	candidateTokens := providerConfig.GetAvailableApiToken(ctx)

	status, statusErr := proxywasm.GetHttpResponseHeader(":status")
	if statusErr != nil {
		log.Errorf("unable to load :status header from response: %v", statusErr)
	}
	if statusErr != nil || status != "200" {
		ctx.DontReadResponseBody()
		return providerConfig.OnRequestFailed(activeProvider, ctx, tokenInUse, candidateTokens, status)
	}

	// A successful request resets the token's failure counter: the apiToken is only removed
	// when the number of consecutive request failures exceeds the threshold.
	providerConfig.ResetApiTokenRequestFailureCount(tokenInUse)

	responseHeaders := util.GetOriginalResponseHeaders()
	if headerTransformer, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok {
		apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
		headerTransformer.TransformResponseHeaders(ctx, apiName, responseHeaders)
	} else {
		providerConfig.DefaultTransformResponseHeaders(ctx, responseHeaders)
	}
	util.ReplaceResponseHeaders(responseHeaders)
	checkStream(ctx)

	_, transformsBody := activeProvider.(provider.TransformResponseBodyHandler)
	_, handlesChunks := activeProvider.(provider.StreamingResponseBodyHandler)
	_, handlesEvents := activeProvider.(provider.StreamingEventHandler)
	handlesStreaming := handlesChunks || handlesEvents

	switch {
	case !transformsBody && !handlesStreaming:
		// Nothing will inspect the body.
		ctx.DontReadResponseBody()
	case !handlesStreaming:
		// Only whole-body transformation is supported, so the body must be buffered.
		ctx.BufferResponseBody()
	}
	return types.ActionContinue
}
// onStreamingResponseBody is the streaming response-body hook, called once per chunk.
//
// Two provider interfaces are supported, checked in this order:
//   - provider.StreamingResponseBodyHandler: rewrites raw chunks. Its (non-nil) result replaces
//     the chunk.
//   - provider.StreamingEventHandler: receives each streaming event extracted from the chunk
//     and may replace it with zero or more output events. End-of-data events are forwarded
//     untouched.
//
// On any processing error the original chunk is returned unchanged.
func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, chunk []byte, isLastChunk bool) []byte {
	activeProvider := pluginConfig.GetProvider()
	if activeProvider == nil {
		log.Debugf("[onStreamingResponseBody] no active provider, skip processing")
		return chunk
	}
	log.Debugf("[onStreamingResponseBody] provider=%s", activeProvider.GetProviderType())
	log.Debugf("[onStreamingResponseBody] isLastChunk=%v chunk: %s", isLastChunk, string(chunk))

	if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
		apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
		modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk)
		if err != nil {
			// Previously swallowed silently; surface it so failures are diagnosable.
			log.Errorf("[onStreamingResponseBody] failed to process streaming response body: %v", err)
			return chunk
		}
		if modifiedChunk == nil {
			return chunk
		}
		return modifiedChunk
	}

	if handler, ok := activeProvider.(provider.StreamingEventHandler); ok {
		apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
		events := provider.ExtractStreamingEvents(ctx, chunk)
		log.Debugf("[onStreamingResponseBody] %d events received", len(events))
		if len(events) == 0 {
			// No events are extracted, return the original chunk.
			return chunk
		}

		var responseBuilder strings.Builder
		for _, event := range events {
			log.Debugf("processing event: %v", event)
			if event.IsEndData() {
				responseBuilder.WriteString(event.ToHttpString())
				continue
			}
			outputEvents, err := handler.OnStreamingEvent(ctx, apiName, event)
			if err != nil {
				log.Errorf("[onStreamingResponseBody] failed to process streaming event: %v\n%s", err, chunk)
				return chunk
			}
			if len(outputEvents) == 0 {
				// The handler produced no replacement events; forward the original event.
				responseBuilder.WriteString(event.ToHttpString())
				continue
			}
			for _, outputEvent := range outputEvents {
				responseBuilder.WriteString(outputEvent.ToHttpString())
			}
		}
		return []byte(responseBuilder.String())
	}

	return chunk
}
// onHttpResponseBody is the buffered response-body hook.
//
// When the active provider implements provider.TransformResponseBodyHandler, the body is
// transformed and written back with provider.ReplaceResponseBody. A transformation or
// replacement failure is reported via util.ErrorHandler and the response continues as-is.
func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte) types.Action {
	activeProvider := pluginConfig.GetProvider()
	if activeProvider == nil {
		log.Debugf("[onHttpResponseBody] no active provider, skip processing")
		return types.ActionContinue
	}
	log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType())

	handler, ok := activeProvider.(provider.TransformResponseBodyHandler)
	if !ok {
		return types.ActionContinue
	}

	apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
	// Use a distinct name instead of shadowing the body parameter with `:=`.
	transformedBody, err := handler.TransformResponseBody(ctx, apiName, body)
	if err != nil {
		_ = util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
		return types.ActionContinue
	}
	if err := provider.ReplaceResponseBody(transformedBody); err != nil {
		_ = util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
	}
	return types.ActionContinue
}
// normalizeOpenAiRequestBody applies OpenAI-protocol defaults to a request body.
//
// For streaming requests ("stream": true) it sets stream_options.include_usage to true. This
// asks an OpenAI-compatible upstream to report token usage in the stream. Other bodies are
// returned unchanged.
//
// If the JSON update fails, the error is logged and the original body is returned. The result
// of sjson.SetBytes is not usable on error, and assigning it directly would drop the request
// body.
func normalizeOpenAiRequestBody(body []byte) []byte {
	if !gjson.GetBytes(body, "stream").Bool() {
		return body
	}
	normalized, err := sjson.SetBytes(body, "stream_options.include_usage", true)
	if err != nil {
		log.Errorf("set include_usage failed, err:%s", err)
		return body
	}
	return normalized
}
// checkStream chooses the response-body read mode from the response Content-Type.
//
// A response whose Content-Type starts with "text/event-stream" is left alone. Any other
// response, including one whose Content-Type cannot be read, is fully buffered with a limit
// of defaultMaxBodyBytes.
func checkStream(ctx wrapper.HttpContext) {
	contentType, headerErr := proxywasm.GetHttpResponseHeader("Content-Type")
	switch {
	case headerErr != nil:
		log.Errorf("unable to load content-type header from response: %v", headerErr)
	case strings.HasPrefix(contentType, "text/event-stream"):
		return
	}
	ctx.BufferResponseBody()
	ctx.SetResponseBodyBufferLimit(defaultMaxBodyBytes)
}
// getApiName maps a request path to the API it targets by matching well-known path suffixes.
//
// OpenAI-style and Cohere-style endpoints are recognized. Cases are evaluated top to bottom,
// and the first matching suffix wins. An empty ApiName is returned when nothing matches.
func getApiName(path string) provider.ApiName {
	switch {
	// OpenAI style
	case strings.HasSuffix(path, "/v1/chat/completions"):
		return provider.ApiNameChatCompletion
	case strings.HasSuffix(path, "/v1/completions"):
		return provider.ApiNameCompletion
	case strings.HasSuffix(path, "/v1/embeddings"):
		return provider.ApiNameEmbeddings
	case strings.HasSuffix(path, "/v1/audio/speech"):
		return provider.ApiNameAudioSpeech
	case strings.HasSuffix(path, "/v1/images/generations"):
		return provider.ApiNameImageGeneration
	case strings.HasSuffix(path, "/v1/batches"):
		return provider.ApiNameBatches
	case strings.HasSuffix(path, "/v1/files"):
		return provider.ApiNameFiles
	case strings.HasSuffix(path, "/v1/models"):
		return provider.ApiNameModels
	// Cohere style
	case strings.HasSuffix(path, "/v1/rerank"):
		return provider.ApiNameCohereV1Rerank
	default:
		return ""
	}
}