plugins/wasm-go/extensions/ai-proxy/provider/qwen.go (600 lines of code) (raw):

package provider

import (
	"encoding/json"
	"errors"
	"fmt"
	"math"
	"net/http"
	"strings"
	"time"

	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
	"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

// Constants used by the Qwen (DashScope) provider: endpoint paths, the
// top_p clamping range and model-name markers that trigger special handling.
const (
	qwenResultFormatMessage = "message"

	qwenDefaultDomain               = "dashscope.aliyuncs.com"
	qwenChatCompletionPath          = "/api/v1/services/aigc/text-generation/generation"
	qwenTextEmbeddingPath           = "/api/v1/services/embeddings/text-embedding/text-embedding"
	qwenChatCompatiblePath          = "/compatible-mode/v1/chat/completions"
	qwenTextEmbeddingCompatiblePath = "/compatible-mode/v1/embeddings"
	qwenBailianPath                 = "/api/v1/apps"
	qwenMultimodalGenerationPath    = "/api/v1/services/aigc/multimodal-generation/generation"

	qwenTopPMin = 0.000001
	qwenTopPMax = 0.999999

	qwenDummySystemMessageContent = "You are a helpful assistant."

	qwenLongModelName     = "qwen-long"
	qwenVlModelPrefixName = "qwen-vl"
)

// qwenProviderInitializer validates Qwen provider configs and creates qwenProvider instances.
type qwenProviderInitializer struct{}

// ValidateConfig rejects a config that sets both qwenFileIds and context, or
// that carries no API token.
func (m *qwenProviderInitializer) ValidateConfig(config *ProviderConfig) error {
	if len(config.qwenFileIds) != 0 && config.context != nil {
		return errors.New("qwenFileIds and context cannot be configured at the same time")
	}
	// len of a nil slice/map is 0, so no separate nil check is needed.
	if len(config.apiTokens) == 0 {
		return errors.New("no apiToken found in provider config")
	}
	return nil
}

// DefaultCapabilities returns the default API-name -> upstream-path mapping:
// the OpenAI-compatible endpoints when qwenEnableCompatible is true, the
// native DashScope endpoints otherwise.
func (m *qwenProviderInitializer) DefaultCapabilities(qwenEnableCompatible bool) map[string]string {
	if qwenEnableCompatible {
		return map[string]string{
			string(ApiNameChatCompletion): qwenChatCompatiblePath,
			string(ApiNameEmbeddings):     qwenTextEmbeddingCompatiblePath,
		}
	}
	return map[string]string{
		string(ApiNameChatCompletion): qwenChatCompletionPath,
		string(ApiNameEmbeddings):     qwenTextEmbeddingPath,
	}
}

// CreateProvider applies the default capabilities to config via
// setDefaultCapabilities and returns a qwenProvider with its own context cache.
func (m *qwenProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
	config.setDefaultCapabilities(m.DefaultCapabilities(config.qwenEnableCompatible))
	return &qwenProvider{
		config:       config,
		contextCache: createContextCache(&config),
	}, nil
}

// qwenProvider is the provider for Qwen service.
type qwenProvider struct {
	config       ProviderConfig
	contextCache *contextCache
}

// TransformRequestHeaders points the request at the configured Qwen domain
// (qwenDefaultDomain when none is set), installs the bearer token in use and,
// unless the original protocol is configured, rewrites the path from the
// capability mapping.
func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
	domain := m.config.qwenDomain
	if domain == "" {
		domain = qwenDefaultDomain
	}
	util.OverwriteRequestHostHeader(headers, domain)
	util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
	if !m.config.IsOriginal() {
		util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
	}
}

// TransformRequestBodyHeaders rewrites the request body for the upstream.
// In compatible mode only the "model" field is remapped (when present).
// Otherwise chat-completion and embeddings bodies are converted to DashScope's
// native format, and any other API goes through defaultTransformRequestBody.
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
	if m.config.qwenEnableCompatible {
		model := gjson.GetBytes(body, "model")
		if !model.Exists() {
			return body, nil
		}
		mappedModel := getMappedModel(model.String(), m.config.modelMapping)
		newBody, err := sjson.SetBytes(body, "model", mappedModel)
		if err != nil {
			log.Errorf("Replace model error: %v", err)
			return newBody, err
		}
		return newBody, nil
	}

	switch apiName {
	case ApiNameChatCompletion:
		return m.onChatCompletionRequestBody(ctx, body, headers)
	case ApiNameEmbeddings:
		return m.onEmbeddingsRequestBody(ctx, body)
	default:
		return m.config.defaultTransformRequestBody(ctx, apiName, body)
	}
}

// GetProviderType returns providerTypeQwen.
func (m *qwenProvider) GetProviderType() string {
	return providerTypeQwen
}

// OnRequestHeaders delegates header handling to the shared config logic and,
// for the original protocol, tells the host not to buffer the request body.
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
	m.config.handleRequestHeaders(m, ctx, apiName)
	if m.config.protocol == protocolOriginal {
		ctx.DontReadRequestBody()
	}
	return nil
}

// OnRequestBody returns errUnsupportedApiName for APIs rejected by
// isSupportedAPI and otherwise delegates to the shared body handler together
// with this provider's context cache.
func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
	if !m.config.isSupportedAPI(apiName) {
		return types.ActionContinue, errUnsupportedApiName
	}
	return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}

// onChatCompletionRequestBody parses an OpenAI-style chat request (mapping its
// model), switches the path to the multimodal endpoint for qwen-vl* models,
// sets the SSE-related headers to match request.Stream and returns the native
// DashScope request body.
func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
	request := &chatCompletionRequest{}
	if err := m.config.parseRequestAndMapModel(ctx, request, body); err != nil {
		return nil, err
	}

	// Use the qwen multimodal model generation API
	if strings.HasPrefix(request.Model, qwenVlModelPrefixName) {
		util.OverwriteRequestPathHeader(headers, qwenMultimodalGenerationPath)
	}

	streaming := request.Stream
	if streaming {
		headers.Set("Accept", "text/event-stream")
		headers.Set("X-DashScope-SSE", "enable")
	} else {
		headers.Set("Accept", "*/*")
		headers.Del("X-DashScope-SSE")
	}

	return m.buildQwenTextGenerationRequest(ctx, request, streaming)
}

// onEmbeddingsRequestBody parses an OpenAI-style embeddings request (mapping
// its model) and returns the JSON body of the native DashScope request.
func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
	request := &embeddingsRequest{}
	if err := m.config.parseRequestAndMapModel(ctx, request, body); err != nil {
		return nil, err
	}
	qwenRequest, err := m.buildQwenTextEmbeddingRequest(request)
	if err != nil {
		return nil, err
	}
	return json.Marshal(qwenRequest)
}

// OnStreamingEvent converts one native DashScope SSE event into zero or more
// OpenAI chat-completion chunk events. It returns (nil, nil) without touching
// the event in compatible mode or for non-chat APIs.
func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent) ([]StreamEvent, error) {
	if m.config.qwenEnableCompatible || name != ApiNameChatCompletion {
		return nil, nil
	}
	incrementalStreaming := ctx.GetBoolContext(ctxKeyIncrementalStreaming, false)
	qwenResponse := &qwenTextGenResponse{}
	if err := json.Unmarshal([]byte(event.Data), qwenResponse); err != nil {
		log.Errorf("unable to unmarshal Qwen response: %v", err)
		return nil, fmt.Errorf("unable to unmarshal Qwen response: %w", err)
	}
	var outputEvents []StreamEvent
	for _, response := range m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming) {
		responseBody, err := json.Marshal(response)
		if err != nil {
			log.Errorf("unable to marshal response: %v", err)
			return nil, fmt.Errorf("unable to marshal response: %w", err)
		}
		modifiedEvent := event
		modifiedEvent.Data = string(responseBody)
		outputEvents = append(outputEvents, modifiedEvent)
	}
	return outputEvents, nil
}

// TransformResponseBody converts non-streaming native responses to the OpenAI
// format. It passes bodies through unchanged in compatible mode and for other
// supported APIs, and returns errUnsupportedApiName for unsupported ones.
func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
	if m.config.qwenEnableCompatible {
		return body, nil
	}
	switch {
	case apiName == ApiNameChatCompletion:
		return m.onChatCompletionResponseBody(ctx, body)
	case apiName == ApiNameEmbeddings:
		return m.onEmbeddingsResponseBody(ctx, body)
	case m.config.isSupportedAPI(apiName):
		return body, nil
	default:
		return nil, errUnsupportedApiName
	}
}

// onChatCompletionResponseBody converts a native text-generation response body
// into an OpenAI chat-completion response body.
func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
	qwenResponse := &qwenTextGenResponse{}
	if err := json.Unmarshal(body, qwenResponse); err != nil {
		return nil, fmt.Errorf("unable to unmarshal Qwen response: %w", err)
	}
	return json.Marshal(m.buildChatCompletionResponse(ctx, qwenResponse))
}

// onEmbeddingsResponseBody converts a native text-embedding response body into
// an OpenAI embeddings response body.
func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
	qwenResponse := &qwenTextEmbeddingResponse{}
	if err := json.Unmarshal(body, qwenResponse); err != nil {
		return nil, fmt.Errorf("unable to unmarshal Qwen response: %w", err)
	}
	return json.Marshal(m.buildEmbeddingsResponse(ctx, qwenResponse))
}

// buildQwenTextGenerationRequest builds the native DashScope text-generation
// body from an OpenAI chat request.
//
// Incremental output is requested only for streaming requests without tools,
// and that choice is recorded in ctxKeyIncrementalStreaming for the response
// path. For qwen-long with configured file IDs, a system message listing
// "fileid://<id>" references (comma separated) is inserted.
func (m *qwenProvider) buildQwenTextGenerationRequest(ctx wrapper.HttpContext, origRequest *chatCompletionRequest, streaming bool) ([]byte, error) {
	messages := make([]qwenMessage, 0, len(origRequest.Messages))
	for i := range origRequest.Messages {
		messages = append(messages, chatMessage2QwenMessage(origRequest.Messages[i]))
	}

	// Only clamp a top_p the client actually sent: clamping the zero value of an
	// omitted field would send 0.000001 and force near-greedy decoding. A zero
	// top_p is dropped from the JSON by omitempty.
	topP := origRequest.TopP
	if topP != 0 {
		topP = math.Max(qwenTopPMin, math.Min(topP, qwenTopPMax))
	}

	request := &qwenTextGenRequest{
		Model: origRequest.Model,
		Input: qwenTextGenInput{
			Messages: messages,
		},
		Parameters: qwenTextGenParameters{
			ResultFormat:      qwenResultFormatMessage,
			MaxTokens:         origRequest.MaxTokens,
			N:                 origRequest.N,
			Seed:              origRequest.Seed,
			Temperature:       origRequest.Temperature,
			TopP:              topP,
			IncrementalOutput: streaming && len(origRequest.Tools) == 0,
			EnableSearch:      m.config.qwenEnableSearch,
			Tools:             origRequest.Tools,
		},
	}
	if streaming {
		ctx.SetContext(ctxKeyIncrementalStreaming, request.Parameters.IncrementalOutput)
	}

	if len(m.config.qwenFileIds) != 0 && origRequest.Model == qwenLongModelName {
		fileRefs := make([]string, 0, len(m.config.qwenFileIds))
		for _, fileId := range m.config.qwenFileIds {
			fileRefs = append(fileRefs, "fileid://"+fileId)
		}
		body, err := json.Marshal(request)
		if err != nil {
			return nil, fmt.Errorf("unable to marshal request: %w", err)
		}
		return m.insertHttpContextMessage(body, strings.Join(fileRefs, ","), true)
	}

	return json.Marshal(request)
}

// buildChatCompletionResponse maps a native (non-streaming) text-generation
// response to an OpenAI chat-completion response, reporting the model stored
// under ctxKeyFinalRequestModel.
func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse) *chatCompletionResponse {
	choices := make([]chatCompletionChoice, 0, len(qwenResponse.Output.Choices))
	for _, qwenChoice := range qwenResponse.Output.Choices {
		message := qwenMessageToChatMessage(qwenChoice.Message, m.config.reasoningContentMode)
		choices = append(choices, chatCompletionChoice{
			Message:      &message,
			FinishReason: qwenChoice.FinishReason,
		})
	}
	return &chatCompletionResponse{
		Id:                qwenResponse.RequestId,
		Created:           time.Now().Unix(),
		Model:             ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
		SystemFingerprint: "",
		Object:            objectChatCompletion,
		Choices:           choices,
		Usage: usage{
			PromptTokens:     qwenResponse.Usage.InputTokens,
			CompletionTokens: qwenResponse.Usage.OutputTokens,
			TotalTokens:      qwenResponse.Usage.TotalTokens,
		},
	}
}

// buildChatCompletionStreamingResponse translates one native streaming chunk
// into OpenAI chunk responses.
//
// It can emit up to four chunks: a content/reasoning delta, a tool-call delta
// and, once the chunk carries a finish reason, a finish chunk followed by a
// usage chunk. It returns nil when the chunk has no choices, or (in
// non-incremental mode) while a tool call's arguments are still empty.
//
// In non-incremental mode Qwen resends the whole message accumulated so far,
// so the previously pushed message is kept under ctxKeyPushedMessage and its
// text is stripped from the new one to produce a delta.
func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse, incrementalStreaming bool) []*chatCompletionResponse {
	if len(qwenResponse.Output.Choices) == 0 {
		// Nothing to translate; indexing Choices[0] would panic.
		return nil
	}

	baseMessage := chatCompletionResponse{
		Id:                qwenResponse.RequestId,
		Created:           time.Now().Unix(),
		Model:             ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
		SystemFingerprint: "",
		Object:            objectChatCompletionChunk,
	}
	// newChunk returns a fresh copy of baseMessage carrying exactly one choice.
	newChunk := func(choice chatCompletionChoice) *chatCompletionResponse {
		chunk := baseMessage
		chunk.Choices = []chatCompletionChoice{choice}
		return &chunk
	}

	qwenChoice := qwenResponse.Output.Choices[0]
	// Yes, Qwen uses a string "null" as null.
	finished := qwenChoice.FinishReason != "" && qwenChoice.FinishReason != "null"
	message := qwenChoice.Message
	reasoningContentMode := m.config.reasoningContentMode

	log.Debugf("incrementalStreaming: %v", incrementalStreaming)

	deltaContentMessage := &chatMessage{Role: message.Role, Content: message.Content, ReasoningContent: message.ReasoningContent}
	deltaToolCallsMessage := &chatMessage{Role: message.Role, ToolCalls: append([]toolCall{}, message.ToolCalls...)}
	if incrementalStreaming {
		deltaContentMessage.handleStreamingReasoningContent(ctx, reasoningContentMode)
	} else {
		for _, tc := range message.ToolCalls {
			if tc.Function.Arguments == "" && !finished {
				// We don't push any tool call until its arguments are available.
				return nil
			}
		}
		if pushedMessage, ok := ctx.GetContext(ctxKeyPushedMessage).(qwenMessage); ok {
			if message.Content == "" {
				message.Content = pushedMessage.Content
			} else if message.IsStringContent() {
				deltaContentMessage.Content = util.StripPrefix(deltaContentMessage.StringContent(), pushedMessage.StringContent())
			} else if strings.HasPrefix(baseMessage.Model, qwenVlModelPrefixName) {
				// Use the Qwen multimodal model generation API
				if deltaList, ok := qwenVlContentDelta(deltaContentMessage.Content, pushedMessage.Content); ok {
					deltaContentMessage.Content = deltaList
				}
			}

			if message.ReasoningContent == "" {
				message.ReasoningContent = pushedMessage.ReasoningContent
			} else {
				deltaContentMessage.ReasoningContent = util.StripPrefix(deltaContentMessage.ReasoningContent, pushedMessage.ReasoningContent)
			}

			if len(deltaToolCallsMessage.ToolCalls) > 0 && pushedMessage.ToolCalls != nil {
				for i, tc := range deltaToolCallsMessage.ToolCalls {
					if i >= len(pushedMessage.ToolCalls) {
						break
					}
					pushedFunction := pushedMessage.ToolCalls[i].Function
					tc.Function.Id = util.StripPrefix(tc.Function.Id, pushedFunction.Id)
					tc.Function.Name = util.StripPrefix(tc.Function.Name, pushedFunction.Name)
					tc.Function.Arguments = util.StripPrefix(tc.Function.Arguments, pushedFunction.Arguments)
					deltaToolCallsMessage.ToolCalls[i] = tc
				}
			}
		}
		// Apply the reasoning-content mode to every delta, including the first
		// one (no pushed message yet), exactly as the incremental branch does.
		deltaContentMessage.handleStreamingReasoningContent(ctx, reasoningContentMode)
		ctx.SetContext(ctxKeyPushedMessage, message)
	}

	responses := make([]*chatCompletionResponse, 0)
	if !deltaContentMessage.IsEmpty() {
		responses = append(responses, newChunk(chatCompletionChoice{Delta: deltaContentMessage}))
	}
	if !deltaToolCallsMessage.IsEmpty() {
		responses = append(responses, newChunk(chatCompletionChoice{Delta: deltaToolCallsMessage}))
	}
	if finished {
		finishResponse := newChunk(chatCompletionChoice{Delta: &chatMessage{}, FinishReason: qwenChoice.FinishReason})
		usageResponse := newChunk(chatCompletionChoice{Delta: &chatMessage{}})
		usageResponse.Usage = usage{
			PromptTokens:     qwenResponse.Usage.InputTokens,
			CompletionTokens: qwenResponse.Usage.OutputTokens,
			TotalTokens:      qwenResponse.Usage.TotalTokens,
		}
		responses = append(responses, finishResponse, usageResponse)
	}
	return responses
}

// qwenVlContentDelta returns a fresh copy of the multimodal content parts in
// current, in which each part's text has the text of the same-index part in
// pushed stripped as a prefix. The copy keeps current (which is also stored as
// the pushed message) unmodified. It returns false, after logging a warning,
// when either value is not a recognised multimodal content list.
func qwenVlContentDelta(current, pushed any) ([]qwenVlMessageContent, bool) {
	currentList, ok := toQwenVlMessageContents(current)
	if !ok {
		log.Warnf("unexpected deltaContentMessage content type: %T", current)
		return nil, false
	}
	pushedList, ok := toQwenVlMessageContents(pushed)
	if !ok {
		log.Warnf("unexpected pushedMessage content type: %T", pushed)
		return nil, false
	}
	for i := range currentList {
		if i >= len(pushedList) {
			break
		}
		currentList[i].Text = util.StripPrefix(currentList[i].Text, pushedList[i].Text)
	}
	return currentList, true
}

// toQwenVlMessageContents copies multimodal message content into a new
// []qwenVlMessageContent. Content decoded by encoding/json into an `any` field
// arrives as []any of map[string]any, so that shape is accepted as well as an
// already-typed []qwenVlMessageContent.
func toQwenVlMessageContents(content any) ([]qwenVlMessageContent, bool) {
	switch list := content.(type) {
	case []qwenVlMessageContent:
		return append([]qwenVlMessageContent(nil), list...), true
	case []any:
		parts := make([]qwenVlMessageContent, 0, len(list))
		for _, item := range list {
			fields, ok := item.(map[string]any)
			if !ok {
				return nil, false
			}
			var part qwenVlMessageContent
			part.Image, _ = fields["image"].(string)
			part.Text, _ = fields["text"].(string)
			parts = append(parts, part)
		}
		return parts, true
	default:
		return nil, false
	}
}

// insertHttpContextMessage inserts a system message with the given content
// into a native DashScope request body and returns the re-marshaled body.
//
// When the conversation starts with system messages, the new message is
// placed after them. If onlyOneSystemBeforeFile is true, those leading system
// messages are first merged (newline separated) into a single one. If the
// inserted message would end up first, a dummy system message is prepended,
// because it must not be the first message.
func (m *qwenProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
	request := &qwenTextGenRequest{}
	if err := json.Unmarshal(body, request); err != nil {
		return nil, fmt.Errorf("unable to unmarshal request: %w", err)
	}

	fileMessage := qwenMessage{
		Role:    roleSystem,
		Content: content,
	}

	messages := request.Input.Messages
	// Index of the first non-system message; len(messages) when every message is
	// a system message (so "all system" is not mistaken for "none").
	firstNonSystemMessageIndex := len(messages)
	for i, message := range messages {
		if message.Role != roleSystem {
			firstNonSystemMessageIndex = i
			break
		}
	}

	switch {
	case firstNonSystemMessageIndex == 0:
		request.Input.Messages = append([]qwenMessage{fileMessage}, messages...)
	case !onlyOneSystemBeforeFile:
		// The inner append copies the tail before the outer one may overwrite it.
		request.Input.Messages = append(messages[:firstNonSystemMessageIndex], append([]qwenMessage{fileMessage}, messages[firstNonSystemMessageIndex:]...)...)
	default:
		var builder strings.Builder
		for _, message := range messages[:firstNonSystemMessageIndex] {
			if builder.Len() != 0 {
				builder.WriteString("\n")
			}
			builder.WriteString(message.StringContent())
		}
		request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: builder.String()}, fileMessage}, messages[firstNonSystemMessageIndex:]...)
		firstNonSystemMessageIndex = 1
	}
	if firstNonSystemMessageIndex == 0 {
		// The context message cannot come first. We need to add another dummy system message before it.
		request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...)
	}

	return json.Marshal(request)
}

// appendStreamEvent writes event.Data to responseBuilder, prefixed with
// streamDataItemKey and followed by "\n\n".
func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *StreamEvent) {
	responseBuilder.WriteString(streamDataItemKey)
	responseBuilder.WriteString(event.Data)
	responseBuilder.WriteString("\n\n")
}

// buildQwenTextEmbeddingRequest converts an OpenAI embeddings request into the
// native DashScope form. Input must be a string or an array of strings; any
// other shape (including nil) yields an error naming the offending type.
func (m *qwenProvider) buildQwenTextEmbeddingRequest(request *embeddingsRequest) (*qwenTextEmbeddingRequest, error) {
	var texts []string
	switch input := request.Input.(type) {
	case string:
		texts = []string{input}
	case []interface{}:
		texts = make([]string, 0, len(input))
		for _, item := range input {
			str, ok := item.(string)
			if !ok {
				// %T prints the same name as reflect.TypeOf(item).String() but,
				// unlike it, does not panic when item is nil.
				return nil, fmt.Errorf("unsupported input type in array: %T", item)
			}
			texts = append(texts, str)
		}
	default:
		return nil, fmt.Errorf("unsupported input type: %T", request.Input)
	}
	return &qwenTextEmbeddingRequest{
		Model: request.Model,
		Input: qwenTextEmbeddingInput{
			Texts: texts,
		},
	}, nil
}

// buildEmbeddingsResponse maps a native text-embedding response to an OpenAI
// embeddings response. The model is read from ctxKeyFinalRequestModel (empty
// when unset), and total tokens are reported as prompt tokens too.
func (m *qwenProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextEmbeddingResponse) *embeddingsResponse {
	data := make([]embedding, 0, len(qwenResponse.Output.Embeddings))
	for _, qwenEmbedding := range qwenResponse.Output.Embeddings {
		data = append(data, embedding{
			Object:    "embedding",
			Index:     qwenEmbedding.TextIndex,
			Embedding: qwenEmbedding.Embedding,
		})
	}
	return &embeddingsResponse{
		Object: "list",
		Data:   data,
		// GetStringContext avoids the panic of an unchecked .(string) assertion
		// when the model was never stored in the context.
		Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
		Usage: usage{
			PromptTokens: qwenResponse.Usage.TotalTokens,
			TotalTokens:  qwenResponse.Usage.TotalTokens,
		},
	}
}

// qwenTextGenRequest is the native DashScope text-generation request body.
type qwenTextGenRequest struct {
	Model      string                `json:"model"`
	Input      qwenTextGenInput      `json:"input"`
	Parameters qwenTextGenParameters `json:"parameters,omitempty"`
}

// qwenTextGenInput holds the conversation of a text-generation request.
type qwenTextGenInput struct {
	Messages []qwenMessage `json:"messages"`
}

// qwenTextGenParameters holds the generation parameters; zero values are omitted.
type qwenTextGenParameters struct {
	ResultFormat      string  `json:"result_format,omitempty"`
	MaxTokens         int     `json:"max_tokens,omitempty"`
	RepetitionPenalty float64 `json:"repetition_penalty,omitempty"`
	N                 int     `json:"n,omitempty"`
	Seed              int     `json:"seed,omitempty"`
	Temperature       float64 `json:"temperature,omitempty"`
	TopP              float64 `json:"top_p,omitempty"`
	IncrementalOutput bool    `json:"incremental_output,omitempty"`
	EnableSearch      bool    `json:"enable_search,omitempty"`
	Tools             []tool  `json:"tools,omitempty"`
}

// qwenTextGenResponse is a native text-generation response (or streaming chunk).
type qwenTextGenResponse struct {
	RequestId string            `json:"request_id"`
	Output    qwenTextGenOutput `json:"output"`
	Usage     qwenUsage         `json:"usage"`
}

// qwenTextGenOutput is the output section of a text-generation response.
type qwenTextGenOutput struct {
	FinishReason string              `json:"finish_reason"`
	Choices      []qwenTextGenChoice `json:"choices"`
}

// qwenTextGenChoice is one choice of a text-generation response.
type qwenTextGenChoice struct {
	FinishReason string      `json:"finish_reason"`
	Message      qwenMessage `json:"message"`
}

// qwenUsage reports token usage of a DashScope response.
type qwenUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

// qwenMessage is a native chat message. Content is either a string or a list
// of multimodal parts.
type qwenMessage struct {
	Name             string     `json:"name,omitempty"`
	Role             string     `json:"role"`
	Content          any        `json:"content"`
	ReasoningContent string     `json:"reasoning_content,omitempty"`
	ToolCalls        []toolCall `json:"tool_calls,omitempty"`
}

// qwenVlMessageContent is one multimodal content part (an image or a text).
type qwenVlMessageContent struct {
	Image string `json:"image,omitempty"`
	Text  string `json:"text,omitempty"`
}

// qwenTextEmbeddingRequest is the native DashScope text-embedding request body.
type qwenTextEmbeddingRequest struct {
	Model      string                      `json:"model"`
	Input      qwenTextEmbeddingInput      `json:"input"`
	Parameters qwenTextEmbeddingParameters `json:"parameters,omitempty"`
}

// qwenTextEmbeddingInput holds the texts to embed.
type qwenTextEmbeddingInput struct {
	Texts []string `json:"texts"`
}

// qwenTextEmbeddingParameters holds optional text-embedding parameters.
type qwenTextEmbeddingParameters struct {
	TextType string `json:"text_type,omitempty"`
}

// qwenTextEmbeddingResponse is the native text-embedding response.
type qwenTextEmbeddingResponse struct {
	RequestId string                  `json:"request_id"`
	Output    qwenTextEmbeddingOutput `json:"output"`
	Usage     qwenUsage               `json:"usage"`
}

// qwenTextEmbeddingOutput is the output section of a text-embedding response.
type qwenTextEmbeddingOutput struct {
	RequestId  string               `json:"request_id"`
	Embeddings []qwenTextEmbeddings `json:"embeddings"`
}

// qwenTextEmbeddings is one embedding vector with the index of its input text.
type qwenTextEmbeddings struct {
	TextIndex int       `json:"text_index"`
	Embedding []float64 `json:"embedding"`
}

// qwenMessageToChatMessage converts a native message to a chatMessage and
// applies the non-streaming reasoning-content mode to it.
func qwenMessageToChatMessage(qwenMessage qwenMessage, reasoningContentMode string) chatMessage {
	msg := chatMessage{
		Name:             qwenMessage.Name,
		Role:             qwenMessage.Role,
		Content:          qwenMessage.Content,
		ReasoningContent: qwenMessage.ReasoningContent,
		ToolCalls:        qwenMessage.ToolCalls,
	}
	msg.handleNonStreamingReasoningContent(reasoningContentMode)
	return msg
}

// IsStringContent reports whether Content holds a plain string.
func (m *qwenMessage) IsStringContent() bool {
	_, ok := m.Content.(string)
	return ok
}

// StringContent returns Content as text. For a string it is returned as is.
// For a content list, the "text" values of its parts are concatenated; the
// list may be JSON-decoded ([]any of map[string]any) or []qwenVlMessageContent.
// Any other content yields "".
func (m *qwenMessage) StringContent() string {
	switch content := m.Content.(type) {
	case string:
		return content
	case []any:
		var builder strings.Builder
		for _, contentItem := range content {
			contentMap, ok := contentItem.(map[string]any)
			if !ok {
				continue
			}
			if text, ok := contentMap["text"].(string); ok {
				builder.WriteString(text)
			}
		}
		return builder.String()
	case []qwenVlMessageContent:
		var builder strings.Builder
		for _, part := range content {
			builder.WriteString(part.Text)
		}
		return builder.String()
	default:
		return ""
	}
}

// chatMessage2QwenMessage converts a chatMessage to a native message. String
// content is kept as a string. Otherwise each parsed part becomes a
// qwenVlMessageContent: text parts fill Text, image_url parts fill Image, and
// parts of any other type become an empty element.
func chatMessage2QwenMessage(chatMessage chatMessage) qwenMessage {
	if chatMessage.IsStringContent() {
		return qwenMessage{
			Name:      chatMessage.Name,
			Role:      chatMessage.Role,
			Content:   chatMessage.StringContent(),
			ToolCalls: chatMessage.ToolCalls,
		}
	}

	var contents []qwenVlMessageContent
	for _, part := range chatMessage.ParseContent() {
		var content qwenVlMessageContent
		switch part.Type {
		case contentTypeText:
			content.Text = part.Text
		case contentTypeImageUrl:
			content.Image = part.ImageUrl.Url
		}
		contents = append(contents, content)
	}
	return qwenMessage{
		Name:      chatMessage.Name,
		Role:      chatMessage.Role,
		Content:   contents,
		ToolCalls: chatMessage.ToolCalls,
	}
}

// GetApiName maps a request path to an API name by substring match: native,
// multimodal, Bailian app and compatible chat paths map to chat completion,
// native and compatible embedding paths map to embeddings, and anything else
// maps to "".
func (m *qwenProvider) GetApiName(path string) ApiName {
	switch {
	case strings.Contains(path, qwenChatCompletionPath),
		strings.Contains(path, qwenMultimodalGenerationPath),
		strings.Contains(path, qwenBailianPath),
		strings.Contains(path, qwenChatCompatiblePath):
		return ApiNameChatCompletion
	case strings.Contains(path, qwenTextEmbeddingPath),
		strings.Contains(path, qwenTextEmbeddingCompatiblePath):
		return ApiNameEmbeddings
	default:
		return ""
	}
}