plugins/wasm-go/extensions/ai-statistics/main.go (448 lines of code) (raw):

package main import ( "bytes" "encoding/json" "errors" "fmt" "strings" "time" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" ) func main() { wrapper.SetCtx( "ai-statistics", wrapper.ParseConfigBy(parseConfig), wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), wrapper.ProcessRequestBodyBy(onHttpRequestBody), wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders), wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingBody), wrapper.ProcessResponseBodyBy(onHttpResponseBody), ) } const ( defaultMaxBodyBytes uint32 = 100 * 1024 * 1024 // Context consts StatisticsRequestStartTime = "ai-statistics-request-start-time" StatisticsFirstTokenTime = "ai-statistics-first-token-time" CtxGeneralAtrribute = "attributes" CtxLogAtrribute = "logAttributes" CtxStreamingBodyBuffer = "streamingBodyBuffer" RouteName = "route" ClusterName = "cluster" APIName = "api" ConsumerKey = "x-mse-consumer" // Source Type FixedValue = "fixed_value" RequestHeader = "request_header" RequestBody = "request_body" ResponseHeader = "response_header" ResponseStreamingBody = "response_streaming_body" ResponseBody = "response_body" // Inner metric & log attributes Model = "model" InputToken = "input_token" OutputToken = "output_token" LLMFirstTokenDuration = "llm_first_token_duration" LLMServiceDuration = "llm_service_duration" LLMDurationCount = "llm_duration_count" LLMStreamDurationCount = "llm_stream_duration_count" ResponseType = "response_type" ChatID = "chat_id" ChatRound = "chat_round" // Inner span attributes ArmsSpanKind = "gen_ai.span.kind" ArmsModelName = "gen_ai.model_name" ArmsRequestModel = "gen_ai.request.model" ArmsInputToken = "gen_ai.usage.input_tokens" ArmsOutputToken = "gen_ai.usage.output_tokens" ArmsTotalToken = "gen_ai.usage.total_tokens" // Extract Rule RuleFirst = "first" RuleReplace = "replace" RuleAppend = "append" ) // TracingSpan is the tracing span configuration. type Attribute struct { Key string `json:"key"` ValueSource string `json:"value_source"` Value string `json:"value"` DefaultValue string `json:"default_value,omitempty"` Rule string `json:"rule,omitempty"` ApplyToLog bool `json:"apply_to_log,omitempty"` ApplyToSpan bool `json:"apply_to_span,omitempty"` } type AIStatisticsConfig struct { // Metrics // TODO: add more metrics in Gauge and Histogram format counterMetrics map[string]proxywasm.MetricCounter // Attributes to be recorded in log & span attributes []Attribute // If there exist attributes extracted from streaming body, chunks should be buffered shouldBufferStreamingBody bool } func generateMetricName(route, cluster, model, consumer, metricName string) string { return fmt.Sprintf("route.%s.upstream.%s.model.%s.consumer.%s.metric.%s", route, cluster, model, consumer, metricName) } func getRouteName() (string, error) { if raw, err := proxywasm.GetProperty([]string{"route_name"}); err != nil { return "-", err } else { return string(raw), nil } } func getAPIName() (string, error) { if raw, err := proxywasm.GetProperty([]string{"route_name"}); err != nil { return "-", err } else { parts := strings.Split(string(raw), "@") if len(parts) != 5 { return "-", errors.New("not api type") } else { return strings.Join(parts[:3], "@"), nil } } } func getClusterName() (string, error) { if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err != nil { return "-", err } else { return string(raw), nil } } func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64) { if inc == 0 { return } counter, ok := config.counterMetrics[metricName] if !ok { counter = proxywasm.DefineCounterMetric(metricName) config.counterMetrics[metricName] = counter } counter.Increment(inc) } func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log wrapper.Log) error { // Parse tracing span attributes setting. attributeConfigs := configJson.Get("attributes").Array() config.attributes = make([]Attribute, len(attributeConfigs)) for i, attributeConfig := range attributeConfigs { attribute := Attribute{} err := json.Unmarshal([]byte(attributeConfig.Raw), &attribute) if err != nil { log.Errorf("parse config failed, %v", err) return err } if attribute.ValueSource == ResponseStreamingBody { config.shouldBufferStreamingBody = true } if attribute.Rule != "" && attribute.Rule != RuleFirst && attribute.Rule != RuleReplace && attribute.Rule != RuleAppend { return errors.New("value of rule must be one of [nil, first, replace, append]") } config.attributes[i] = attribute } // Metric settings config.counterMetrics = make(map[string]proxywasm.MetricCounter) return nil } func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) types.Action { route, _ := getRouteName() cluster, _ := getClusterName() api, api_error := getAPIName() if api_error == nil { route = api } ctx.SetContext(RouteName, route) ctx.SetContext(ClusterName, cluster) ctx.SetUserAttribute(APIName, api) ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli()) if consumer, _ := proxywasm.GetHttpRequestHeader(ConsumerKey); consumer != "" { ctx.SetContext(ConsumerKey, consumer) } hasRequestBody := wrapper.HasRequestBody() if hasRequestBody { _ = proxywasm.RemoveHttpRequestHeader("Content-Length") ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) } // Set user defined log & span attributes which type is fixed_value setAttributeBySource(ctx, config, FixedValue, nil, log) // Set user defined log & span attributes which type is request_header setAttributeBySource(ctx, config, RequestHeader, nil, log) // Set span attributes for ARMS. setSpanAttribute(ArmsSpanKind, "LLM", log) return types.ActionContinue } func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action { // Set user defined log & span attributes. setAttributeBySource(ctx, config, RequestBody, body, log) // Set span attributes for ARMS. requestModel := gjson.GetBytes(body, "model").String() if requestModel == "" { requestModel = "UNKNOWN" } setSpanAttribute(ArmsRequestModel, requestModel, log) // Set the number of conversation rounds if gjson.GetBytes(body, "messages").Exists() { userPromptCount := 0 for _, msg := range gjson.GetBytes(body, "messages").Array() { if msg.Get("role").String() == "user" { userPromptCount += 1 } } ctx.SetUserAttribute(ChatRound, userPromptCount) } // Write log ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) return types.ActionContinue } func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) types.Action { contentType, _ := proxywasm.GetHttpResponseHeader("content-type") if !strings.Contains(contentType, "text/event-stream") { ctx.BufferResponseBody() } // Set user defined log & span attributes. setAttributeBySource(ctx, config, ResponseHeader, nil, log) return types.ActionContinue } func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool, log wrapper.Log) []byte { // Buffer stream body for record log & span attributes if config.shouldBufferStreamingBody { var streamingBodyBuffer []byte streamingBodyBuffer, ok := ctx.GetContext(CtxStreamingBodyBuffer).([]byte) if !ok { streamingBodyBuffer = data } else { streamingBodyBuffer = append(streamingBodyBuffer, data...) } ctx.SetContext(CtxStreamingBodyBuffer, streamingBodyBuffer) } ctx.SetUserAttribute(ResponseType, "stream") chatID := gjson.GetBytes(data, "id").String() if chatID != "" { ctx.SetUserAttribute(ChatID, chatID) } // Get requestStartTime from http context requestStartTime, ok := ctx.GetContext(StatisticsRequestStartTime).(int64) if !ok { log.Error("failed to get requestStartTime from http context") return data } // If this is the first chunk, record first token duration metric and span attribute if ctx.GetContext(StatisticsFirstTokenTime) == nil { firstTokenTime := time.Now().UnixMilli() ctx.SetContext(StatisticsFirstTokenTime, firstTokenTime) ctx.SetUserAttribute(LLMFirstTokenDuration, firstTokenTime-requestStartTime) } // Set information about this request if model, inputToken, outputToken, ok := getUsage(data); ok { ctx.SetUserAttribute(Model, model) ctx.SetUserAttribute(InputToken, inputToken) ctx.SetUserAttribute(OutputToken, outputToken) // Set span attributes for ARMS. setSpanAttribute(ArmsModelName, model, log) setSpanAttribute(ArmsInputToken, inputToken, log) setSpanAttribute(ArmsOutputToken, outputToken, log) setSpanAttribute(ArmsTotalToken, inputToken+outputToken, log) } // If the end of the stream is reached, record metrics/logs/spans. if endOfStream { responseEndTime := time.Now().UnixMilli() ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime) // Set user defined log & span attributes. if config.shouldBufferStreamingBody { streamingBodyBuffer, ok := ctx.GetContext(CtxStreamingBodyBuffer).([]byte) if !ok { return data } setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer, log) } // Write log ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) // Write metrics writeMetric(ctx, config, log) } return data } func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action { // Get requestStartTime from http context requestStartTime, _ := ctx.GetContext(StatisticsRequestStartTime).(int64) responseEndTime := time.Now().UnixMilli() ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime) ctx.SetUserAttribute(ResponseType, "normal") chatID := gjson.GetBytes(body, "id").String() if chatID != "" { ctx.SetUserAttribute(ChatID, chatID) } // Set information about this request if model, inputToken, outputToken, ok := getUsage(body); ok { ctx.SetUserAttribute(Model, model) ctx.SetUserAttribute(InputToken, inputToken) ctx.SetUserAttribute(OutputToken, outputToken) // Set span attributes for ARMS. setSpanAttribute(ArmsModelName, model, log) setSpanAttribute(ArmsInputToken, inputToken, log) setSpanAttribute(ArmsOutputToken, outputToken, log) setSpanAttribute(ArmsTotalToken, inputToken+outputToken, log) } // Set user defined log & span attributes. setAttributeBySource(ctx, config, ResponseBody, body, log) // Write log ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) // Write metrics writeMetric(ctx, config, log) return types.ActionContinue } func unifySSEChunk(data []byte) []byte { data = bytes.ReplaceAll(data, []byte("\r\n"), []byte("\n")) data = bytes.ReplaceAll(data, []byte("\r"), []byte("\n")) return data } func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsage int64, ok bool) { chunks := bytes.Split(bytes.TrimSpace(unifySSEChunk(data)), []byte("\n\n")) for _, chunk := range chunks { // the feature strings are used to identify the usage data, like: // {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}} if !bytes.Contains(chunk, []byte("prompt_tokens")) { continue } if !bytes.Contains(chunk, []byte("completion_tokens")) { continue } modelObj := gjson.GetBytes(chunk, "model") if modelObj.Exists() { model = modelObj.String() } else { model = "unknown" } inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens") outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens") if inputTokenObj.Exists() && outputTokenObj.Exists() { inputTokenUsage = inputTokenObj.Int() outputTokenUsage = outputTokenObj.Int() ok = true return } } return } // fetches the tracing span value from the specified source. func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log wrapper.Log) { for _, attribute := range config.attributes { var key string var value interface{} if source == attribute.ValueSource { key = attribute.Key switch source { case FixedValue: value = attribute.Value case RequestHeader: value, _ = proxywasm.GetHttpRequestHeader(attribute.Value) case RequestBody: value = gjson.GetBytes(body, attribute.Value).Value() case ResponseHeader: value, _ = proxywasm.GetHttpResponseHeader(attribute.Value) case ResponseStreamingBody: value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log) case ResponseBody: value = gjson.GetBytes(body, attribute.Value).Value() default: } if (value == nil || value == "") && attribute.DefaultValue != "" { value = attribute.DefaultValue } log.Debugf("[attribute] source type: %s, key: %s, value: %+v", source, key, value) if attribute.ApplyToLog { ctx.SetUserAttribute(key, value) } // for metrics if key == Model || key == InputToken || key == OutputToken { ctx.SetContext(key, value) } if attribute.ApplyToSpan { setSpanAttribute(key, value, log) } } } } func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) interface{} { chunks := bytes.Split(bytes.TrimSpace(unifySSEChunk(data)), []byte("\n\n")) var value interface{} if rule == RuleFirst { for _, chunk := range chunks { jsonObj := gjson.GetBytes(chunk, jsonPath) if jsonObj.Exists() { value = jsonObj.Value() break } } } else if rule == RuleReplace { for _, chunk := range chunks { jsonObj := gjson.GetBytes(chunk, jsonPath) if jsonObj.Exists() { value = jsonObj.Value() } } } else if rule == RuleAppend { // extract llm response var strValue string for _, chunk := range chunks { jsonObj := gjson.GetBytes(chunk, jsonPath) if jsonObj.Exists() { strValue += jsonObj.String() } } value = strValue } else { log.Errorf("unsupported rule type: %s", rule) } return value } // Set the tracing span with value. func setSpanAttribute(key string, value interface{}, log wrapper.Log) { if value != "" { traceSpanTag := wrapper.TraceSpanTagPrefix + key if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil { log.Warnf("failed to set %s in filter state: %v", traceSpanTag, e) } } else { log.Debugf("failed to write span attribute [%s], because it's value is empty") } } func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) { // Generate usage metrics var ok bool var route, cluster, model string var inputToken, outputToken uint64 consumer := ctx.GetStringContext(ConsumerKey, "none") route, ok = ctx.GetContext(RouteName).(string) if !ok { log.Warnf("RouteName typd assert failed, skip metric record") return } cluster, ok = ctx.GetContext(ClusterName).(string) if !ok { log.Warnf("ClusterName typd assert failed, skip metric record") return } if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil { log.Warnf("get usage information failed, skip metric record") return } model, ok = ctx.GetUserAttribute(Model).(string) if !ok { log.Warnf("Model typd assert failed, skip metric record") return } inputToken, ok = convertToUInt(ctx.GetUserAttribute(InputToken)) if !ok { log.Warnf("InputToken typd assert failed, skip metric record") return } outputToken, ok = convertToUInt(ctx.GetUserAttribute(OutputToken)) if !ok { log.Warnf("OutputToken typd assert failed, skip metric record") return } if inputToken == 0 || outputToken == 0 { log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record") return } config.incrementCounter(generateMetricName(route, cluster, model, consumer, InputToken), inputToken) config.incrementCounter(generateMetricName(route, cluster, model, consumer, OutputToken), outputToken) // Generate duration metrics var llmFirstTokenDuration, llmServiceDuration uint64 // Is stream response if ctx.GetUserAttribute(LLMFirstTokenDuration) != nil { llmFirstTokenDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMFirstTokenDuration)) if !ok { log.Warnf("LLMFirstTokenDuration typd assert failed") return } config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMFirstTokenDuration), llmFirstTokenDuration) config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMStreamDurationCount), 1) } if ctx.GetUserAttribute(LLMServiceDuration) != nil { llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration)) if !ok { log.Warnf("LLMServiceDuration typd assert failed") return } config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMServiceDuration), llmServiceDuration) config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMDurationCount), 1) } } func convertToUInt(val interface{}) (uint64, bool) { switch v := val.(type) { case float32: return uint64(v), true case float64: return uint64(v), true case int32: return uint64(v), true case int64: return uint64(v), true case uint32: return uint64(v), true case uint64: return v, true default: return 0, false } }