plugins/wasm-go/extensions/ai-json-resp/main.go (452 lines of code) (raw):

// Copyright (c) 2022 Alibaba Group Holding Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package main import ( "encoding/json" "errors" "net/http" "strconv" "strings" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/santhosh-tekuri/jsonschema" "github.com/tidwall/gjson" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" ) const ( DEFAULT_SCHEMA = "defaultSchema" HTTP_STATUS_OK = uint32(200) HTTP_STATUS_INTERNAL_SERVER_ERROR = uint32(500) FROM_THIS_PLUGIN_KEY = "fromThisPlugin" EXTEND_HEADER_KEY = "X-HIGRESS-AI-JSON-RESP" JSON_SCHEMA_INVALID_CODE = 1001 JSON_SCHEMA_COMPILE_FAILED_CODE = 1002 CANNOT_FIND_JSON_IN_RESPONSE_CODE = 1003 CONTENT_IS_EMPTY_CODE = 1004 JSON_MISMATCH_SCHEMA_CODE = 1005 REACH_MAX_RETRY_COUNT_CODE = 1006 SERVICE_UNAVAILABLE_CODE = 1007 SERVICE_CONFIG_INVALID_CODE = 1008 ) type RejectStruct struct { RejectCode uint32 `json:"Code"` RejectMsg string `json:"Msg"` } func (r RejectStruct) GetBytes() []byte { jsonData, _ := json.Marshal(r) return jsonData } func (r RejectStruct) GetShortMsg() string { return "ai-json-resp." + strings.Split(r.RejectMsg, ":")[0] } type PluginConfig struct { // @Title zh-CN 服务名称 // @Description zh-CN 用以请求服务的名称(网关或其他AI服务) serviceName string `required:"true" json:"serviceName" yaml:"serviceName"` // @Title zh-CN 服务域名 // @Description zh-CN 用以请求服务的域名 serviceDomain string `required:"false" json:"serviceDomain" yaml:"serviceDomain"` // @Title zh-CN 服务端口 // @Description zh-CN 用以请求服务的端口 servicePort int `required:"false" json:"servicePort" yaml:"servicePort"` // @Title zh-CN 服务URL // @Description zh-CN 用以请求服务的URL,若提供则会覆盖serviceDomain和servicePort serviceUrl string `required:"false" json:"serviceUrl" yaml:"serviceUrl"` // @Title zh-CN API Key // @Description zh-CN 若使用AI服务,需要填写请求服务的API Key apiKey string `required:"false" json: "apiKey" yaml:"apiKey"` // @Title zh-CN 请求端点 // @Description zh-CN 用以请求服务的端点, 默认为"/v1/chat/completions" servicePath string `required:"false" json: "servicePath" yaml:"servicePath"` // @Title zh-CN 服务超时时间 // @Description zh-CN 用以请求服务的超时时间 serviceTimeout int `required:"false" json:"serviceTimeout" yaml:"serviceTimeout"` // @Title zh-CN 最大重试次数 // @Description zh-CN 用以请求服务的最大重试次数 maxRetry int `required:"false" json:"maxRetry" yaml:"maxRetry"` // @Title zh-CN 内容路径 // @Description zh-CN 从AI服务返回的响应中提取json的gpath路径 contentPath string `required:"false" json:"contentPath" yaml:"contentPath"` // @Title zh-CN Json Schema // @Description zh-CN 用以验证响应json的Json Schema, 为空则只验证返回的响应是否为合法json jsonSchema map[string]interface{} `required:"false" json:"jsonSchema" yaml:"jsonSchema"` // @Title zh-CN 是否启用swagger // @Description zh-CN 是否启用swagger进行Json Schema验证 enableSwagger bool `required:"false" json:"enableSwagger" yaml:"enableSwagger"` // @Title zh-CN 是否启用oas3 // @Description zh-CN 是否启用oas3进行Json Schema验证 enableOas3 bool `required:"false" json:"enableOas3" yaml:"enableOas3"` // @Title zh-CN 是否启用Content-Disposition // @Description zh-CN 是否启用Content-Disposition, 若启用则会在响应头中添加Content-Disposition: attachment; filename="response.json" enableContentDisposition bool `required:"false" json:"enableContentDisposition" yaml:"enableContentDisposition"` serviceClient wrapper.HttpClient draft *jsonschema.Draft compiler *jsonschema.Compiler compile *jsonschema.Schema rejectStruct RejectStruct jsonSchemaMaxDepth int enableJsonSchemaValidation bool } func main() { wrapper.SetCtx( "ai-json-resp", wrapper.ParseConfigBy(parseConfig), wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), wrapper.ProcessRequestBodyBy(onHttpRequestBody), ) } type RequestContext struct { Path string ReqHeaders [][2]string ReqBody []byte RespHeader [][2]string RespBody []byte HistoryMessages []chatMessage } func parseUrl(url string) (string, string) { if url == "" { return "", "" } url = strings.TrimPrefix(url, "http://") url = strings.TrimPrefix(url, "https://") index := strings.Index(url, "/") if index == -1 { return url, "" } return url[:index], url[index:] } func parseConfig(result gjson.Result, config *PluginConfig, log wrapper.Log) error { config.serviceName = result.Get("serviceName").String() config.serviceUrl = result.Get("serviceUrl").String() config.serviceDomain = result.Get("serviceDomain").String() config.servicePath = result.Get("servicePath").String() config.servicePort = int(result.Get("servicePort").Int()) if config.serviceUrl != "" { domain, url := parseUrl(config.serviceUrl) log.Debugf("serviceUrl: %s, the parsed domain: %s, the parsed url: %s", config.serviceUrl, domain, url) if config.serviceDomain == "" { config.serviceDomain = domain } if config.servicePath == "" { config.servicePath = url } } if config.servicePort == 0 { config.servicePort = 443 } config.serviceTimeout = int(result.Get("serviceTimeout").Int()) config.apiKey = result.Get("apiKey").String() config.rejectStruct = RejectStruct{HTTP_STATUS_OK, ""} if config.serviceTimeout == 0 { config.serviceTimeout = 50000 } config.maxRetry = int(result.Get("maxRetry").Int()) if config.maxRetry == 0 { config.maxRetry = 3 } config.contentPath = result.Get("contentPath").String() if config.contentPath == "" { config.contentPath = "choices.0.message.content" } if jsonSchemaValue := result.Get("jsonSchema"); jsonSchemaValue.Exists() { if schemaValue, ok := jsonSchemaValue.Value().(map[string]interface{}); ok { config.jsonSchema = schemaValue } else { config.rejectStruct = RejectStruct{JSON_SCHEMA_INVALID_CODE, "Json Schema is not valid"} } } else { config.jsonSchema = nil } if config.serviceDomain == "" { config.rejectStruct = RejectStruct{JSON_SCHEMA_INVALID_CODE, "service domain is empty"} } config.serviceClient = wrapper.NewClusterClient(wrapper.DnsCluster{ ServiceName: config.serviceName, Port: int64(config.servicePort), Domain: config.serviceDomain, }) enableSwagger := result.Get("enableSwagger").Bool() enableOas3 := result.Get("enableOas3").Bool() // set draft version if enableSwagger { config.draft = jsonschema.Draft4 } if enableOas3 { config.draft = jsonschema.Draft7 } if !enableSwagger && !enableOas3 { config.draft = jsonschema.Draft7 } // create compiler compiler := jsonschema.NewCompiler() compiler.Draft = config.draft config.compiler = compiler // set max depth of json schema config.jsonSchemaMaxDepth = 6 enableContentDispositionValue := result.Get("enableContentDisposition") if !enableContentDispositionValue.Exists() { config.enableContentDisposition = true } else { config.enableContentDisposition = enableContentDispositionValue.Bool() } config.enableJsonSchemaValidation = true jsonSchemaBytes, err := json.Marshal(config.jsonSchema) if err != nil { config.rejectStruct = RejectStruct{JSON_SCHEMA_INVALID_CODE, "Json Schema marshal failed"} return err } maxDepth := GetMaxDepth(config.jsonSchema) log.Debugf("max depth of json schema: %d", maxDepth) if maxDepth > config.jsonSchemaMaxDepth { config.enableJsonSchemaValidation = false log.Infof("Json Schema depth exceeded: %d from %d , Json Schema validation will not be used.", maxDepth, config.jsonSchemaMaxDepth) } if config.enableJsonSchemaValidation { jsonSchemaStr := string(jsonSchemaBytes) config.compiler.AddResource(DEFAULT_SCHEMA, strings.NewReader(jsonSchemaStr)) // Test if the Json Schema is valid compile, err := config.compiler.Compile(DEFAULT_SCHEMA) if err != nil { log.Infof("Json Schema compile failed: %v", err) config.rejectStruct = RejectStruct{JSON_SCHEMA_COMPILE_FAILED_CODE, "Json Schema compile failed: " + err.Error()} config.compile = nil } else { config.compile = compile } } return nil } func (r *RequestContext) assembleReqBody(config PluginConfig) []byte { var reqBodystrut chatCompletionRequest json.Unmarshal(r.ReqBody, &reqBodystrut) content := gjson.ParseBytes(r.RespBody).Get(config.contentPath).String() jsonSchemaBytes, _ := json.Marshal(config.jsonSchema) jsonSchemaStr := string(jsonSchemaBytes) askQuestion := "Given the Json Schema: " + jsonSchemaStr + ", please help me convert the following content to a pure json: " + content askQuestion += "\n Do not respond other content except the pure json!!!!" reqBodystrut.Messages = append(r.HistoryMessages, []chatMessage{ { Role: "user", Content: askQuestion, }, }...) reqBody, _ := json.Marshal(reqBodystrut) return reqBody } func (r *RequestContext) SaveBodyToHistMsg(log wrapper.Log, reqBody []byte, respBody []byte) { r.RespBody = respBody lastUserMessage := "" lastSystemMessage := "" var reqBodystrut chatCompletionRequest err := json.Unmarshal(reqBody, &reqBodystrut) if err != nil { log.Debugf("unmarshal reqBody failed: %v", err) } else { if len(reqBodystrut.Messages) != 0 { lastUserMessage = reqBodystrut.Messages[len(reqBodystrut.Messages)-1].Content } } var respBodystrut chatCompletionResponse err = json.Unmarshal(respBody, &respBodystrut) if err != nil { log.Debugf("unmarshal respBody failed: %v", err) } else { if len(respBodystrut.Choices) != 0 { lastSystemMessage = respBodystrut.Choices[len(respBodystrut.Choices)-1].Message.Content } } if lastUserMessage != "" { r.HistoryMessages = append(r.HistoryMessages, chatMessage{ Role: "user", Content: lastUserMessage, }) } if lastSystemMessage != "" { r.HistoryMessages = append(r.HistoryMessages, chatMessage{ Role: "system", Content: lastSystemMessage, }) } } func (r *RequestContext) SaveStrToHistMsg(log wrapper.Log, errMsg string) { r.HistoryMessages = append(r.HistoryMessages, chatMessage{ Role: "system", Content: errMsg, }) } func (c *PluginConfig) ValidateBody(body []byte) error { var respJsonStrct chatCompletionResponse err := json.Unmarshal(body, &respJsonStrct) if err != nil { c.rejectStruct = RejectStruct{SERVICE_UNAVAILABLE_CODE, "service unavailable: " + string(body)} return errors.New(c.rejectStruct.RejectMsg) } content := gjson.ParseBytes(body).Get(c.contentPath) if !content.Exists() { c.rejectStruct = RejectStruct{SERVICE_UNAVAILABLE_CODE, "response body does not contain the content: " + string(body)} return errors.New(c.rejectStruct.RejectMsg) } return nil } func (c *PluginConfig) ValidateJson(body []byte, log wrapper.Log) (string, error) { content := gjson.ParseBytes(body).Get(c.contentPath).String() // first extract json from response body if content == "" { log.Infof("response body does not contain the content") c.rejectStruct = RejectStruct{CONTENT_IS_EMPTY_CODE, "response body does not contain the content"} return "", errors.New(c.rejectStruct.RejectMsg) } jsonStr, err := c.ExtractJson(content) if err != nil { log.Infof("response body does not contain the valid json: %v", err.Error()) c.rejectStruct = RejectStruct{CANNOT_FIND_JSON_IN_RESPONSE_CODE, "response body does not contain the valid json: " + err.Error()} return "", errors.New(c.rejectStruct.RejectMsg) } if c.jsonSchema != nil && c.enableJsonSchemaValidation { compile, err := c.compiler.Compile(DEFAULT_SCHEMA) if err != nil { log.Infof("Json Schema compile failed: %v", err) c.rejectStruct = RejectStruct{JSON_SCHEMA_COMPILE_FAILED_CODE, "Json Schema compile failed: " + err.Error()} c.compile = nil } else { c.compile = compile } // validate the json err = c.compile.Validate(strings.NewReader(jsonStr)) if err != nil { log.Infof("response body does not match the Json Schema: %v", err) c.rejectStruct = RejectStruct{JSON_MISMATCH_SCHEMA_CODE, "response body does not match the Json Schema: " + err.Error()} return "", errors.New(c.rejectStruct.RejectMsg) } } c.rejectStruct = RejectStruct{HTTP_STATUS_OK, ""} return jsonStr, nil } func (c *PluginConfig) ExtractJson(bodyStr string) (string, error) { // simply extract json from response body string startIndex := strings.Index(bodyStr, "{") endIndex := strings.LastIndex(bodyStr, "}") + 1 // if not found if startIndex == -1 || endIndex == -1 || startIndex >= endIndex { return "", errors.New("cannot find json in the response body") } jsonStr := bodyStr[startIndex:endIndex] // attempt to parse the JSON var result map[string]interface{} err := json.Unmarshal([]byte(jsonStr), &result) if err != nil { return "", err } return jsonStr, nil } func sendResponse(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log, body []byte) { log.Infof("Final send: Code %d, Message %s, Body: %s", config.rejectStruct.RejectCode, config.rejectStruct.RejectMsg, string(body)) header := [][2]string{ {"Content-Type", "application/json"}, } if body != nil && config.enableContentDisposition { header = append(header, [2]string{"Content-Disposition", "attachment; filename=\"response.json\""}) } if config.rejectStruct.RejectCode != HTTP_STATUS_OK { proxywasm.SendHttpResponseWithDetail(HTTP_STATUS_INTERNAL_SERVER_ERROR, config.rejectStruct.GetShortMsg(), nil, config.rejectStruct.GetBytes(), -1) } else { proxywasm.SendHttpResponse(HTTP_STATUS_OK, header, body, -1) } } func recursiveRefineJson(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log, retryCount int, requestContext *RequestContext) { // if retry count exceeds max retry count, return the response if retryCount >= config.maxRetry { log.Debugf("retry count exceeds max retry count") // report more useful error by appending the last of previous error message config.rejectStruct = RejectStruct{REACH_MAX_RETRY_COUNT_CODE, "retry count exceeds max retry count: " + config.rejectStruct.RejectMsg} sendResponse(ctx, config, log, nil) return } // recursively refine json config.serviceClient.Post(requestContext.Path, requestContext.ReqHeaders, requestContext.assembleReqBody(config), func(statusCode int, responseHeaders http.Header, responseBody []byte) { err := config.ValidateBody(responseBody) if err != nil { sendResponse(ctx, config, log, nil) return } retryCount++ requestContext.SaveBodyToHistMsg(log, requestContext.assembleReqBody(config), responseBody) log.Debugf("[retry request %d/%d] resp code: %d", retryCount, config.maxRetry, statusCode) validateJson, err := config.ValidateJson(responseBody, log) if err == nil { sendResponse(ctx, config, log, []byte(validateJson)) } else { requestContext.SaveStrToHistMsg(log, err.Error()) recursiveRefineJson(ctx, config, log, retryCount, requestContext) } }, uint32(config.serviceTimeout)) } func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { if config.rejectStruct.RejectCode != HTTP_STATUS_OK { sendResponse(ctx, config, log, nil) return types.ActionPause } // verify if the request is from this plugin extendHeaderValue, err := proxywasm.GetHttpRequestHeader(EXTEND_HEADER_KEY) if err == nil { fromThisPlugin, convErr := strconv.ParseBool(extendHeaderValue) if convErr != nil { log.Debugf("failed to parse header value as bool: %v", convErr) ctx.SetContext(FROM_THIS_PLUGIN_KEY, false) } if fromThisPlugin { ctx.SetContext(FROM_THIS_PLUGIN_KEY, true) return types.ActionContinue } } else { ctx.SetContext(FROM_THIS_PLUGIN_KEY, false) } path, err := proxywasm.GetHttpRequestHeader(":path") if err != nil { log.Infof("get request path failed: %v", err) path = "" } else { ctx.SetContext("path", path) } headers, err := proxywasm.GetHttpRequestHeaders() if err != nil { log.Infof("get request header failed: %v", err) } apiKey, err := proxywasm.GetHttpRequestHeader("Authorization") if err != nil { log.Infof("get request header failed: %v", err) apiKey = "" } if apiKey != "" { // remove the Authorization header proxywasm.RemoveHttpRequestHeader("Authorization") // remove the Authorization header from the headers for i, header := range headers { if header[0] == "Authorization" { headers = append(headers[:i], headers[i+1:]...) break } } } if config.apiKey != "" { log.Debugf("add Authorization header %s", "Bearer "+config.apiKey) headers = append(headers, [2]string{"Authorization", "Bearer " + config.apiKey}) } ctx.SetContext("headers", headers) return types.ActionContinue } func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action { // if the request is from this plugin, continue the request fromThisPlugin, ok := ctx.GetContext(FROM_THIS_PLUGIN_KEY).(bool) if ok && fromThisPlugin { log.Debugf("detected buffer_request, sending request to AI service") return types.ActionContinue } var headers [][2]string if h, ok := ctx.GetContext("headers").([][2]string); ok { headers = append(h, [2]string{EXTEND_HEADER_KEY, "true"}) } else { log.Debugf("cannot get headers from context, use default headers") headers = [][2]string{ {"Content-Type", "application/json"}, {EXTEND_HEADER_KEY, "true"}, } } // if there is any error in the config, return the response directly if config.rejectStruct.RejectCode != HTTP_STATUS_OK { sendResponse(ctx, config, log, nil) return types.ActionContinue } var path string if path, ok := ctx.GetContext("path").(string); ok { log.Debugf("use path: %s", path) } else { log.Debugf("cannot get path from context, use default path") path = "/v1/chat/completions" } if config.servicePath != "" { log.Debugf("use base path: %s", config.servicePath) path = config.servicePath } requestContext := &RequestContext{ Path: path, ReqHeaders: headers, ReqBody: body, } config.serviceClient.Post(requestContext.Path, requestContext.ReqHeaders, requestContext.ReqBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { err := config.ValidateBody(responseBody) if err != nil { sendResponse(ctx, config, log, nil) return } requestContext.SaveBodyToHistMsg(log, body, responseBody) log.Debugf("[first request] resp code: %d", statusCode) validateJson, err := config.ValidateJson(responseBody, log) if err == nil { sendResponse(ctx, config, log, []byte(validateJson)) return } else { retryCount := 0 requestContext.SaveStrToHistMsg(log, err.Error()) recursiveRefineJson(ctx, config, log, retryCount, requestContext) } }, uint32(config.serviceTimeout)) return types.ActionPause }