plugins/wasm-go/extensions/ai-search/main.go (649 lines of code) (raw):
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"bytes"
_ "embed"
"errors"
"fmt"
"net/http"
"strings"
"time"
"unicode"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/arxiv"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/bing"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/elasticsearch"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/google"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/quark"
)
type SearchRewrite struct {
client wrapper.HttpClient
url string
apiKey string
modelName string
timeoutMillisecond uint32
prompt string
promptTemplate string // Original prompt template before replacing placeholders
maxCount int
}
type Config struct {
engine []engine.SearchEngine
promptTemplate string
referenceFormat string
defaultLanguage string
needReference bool
referenceLocation string // "head" or "tail"
searchRewrite *SearchRewrite
defaultEnable bool
}
const (
DEFAULT_MAX_BODY_BYTES uint32 = 100 * 1024 * 1024
)
//go:embed prompts/full.md
var fullSearchPrompts string
//go:embed prompts/arxiv.md
var arxivSearchPrompts string
//go:embed prompts/internet.md
var internetSearchPrompts string
//go:embed prompts/private.md
var privateSearchPrompts string
//go:embed prompts/chinese-internet.md
var chineseInternetSearchPrompts string
func main() {
wrapper.SetCtx(
"ai-search",
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
wrapper.ProcessStreamingResponseBodyBy(onStreamingResponseBody),
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
)
}
func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
config.defaultEnable = true // Default to true if not specified
if json.Get("defaultEnable").Exists() {
config.defaultEnable = json.Get("defaultEnable").Bool()
}
config.needReference = json.Get("needReference").Bool()
if config.needReference {
config.referenceFormat = json.Get("referenceFormat").String()
if config.referenceFormat == "" {
config.referenceFormat = "**References:**\n%s"
} else if !strings.Contains(config.referenceFormat, "%s") {
return fmt.Errorf("invalid referenceFormat:%s", config.referenceFormat)
}
config.referenceLocation = json.Get("referenceLocation").String()
if config.referenceLocation == "" {
config.referenceLocation = "head" // Default to head if not specified
} else if config.referenceLocation != "head" && config.referenceLocation != "tail" {
return fmt.Errorf("invalid referenceLocation:%s, must be 'head' or 'tail'", config.referenceLocation)
}
}
config.defaultLanguage = json.Get("defaultLang").String()
config.promptTemplate = json.Get("promptTemplate").String()
if config.promptTemplate == "" {
if config.needReference {
config.promptTemplate = `# 以下内容是基于用户发送的消息的搜索结果:
{search_results}
在我给你的搜索结果中,每个结果都是[webpage X begin]...[webpage X end]格式的,X代表每篇文章的数字索引。请在适当的情况下在句子末尾引用上下文。请按照引用编号[X]的格式在答案中对应部分引用上下文。如果一句话源自多个上下文,请列出所有相关的引用编号,例如[3][5],切记不要将引用集中在最后返回引用编号,而是在答案对应部分列出。
在回答时,请注意以下几点:
- 今天是北京时间:{cur_date}。
- 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。
- 对于列举类的问题(如列举所有航班信息),尽量将答案控制在10个要点以内,并告诉用户可以查看搜索来源、获得完整信息。优先提供信息完整、最相关的列举项;如非必要,不要主动告诉用户搜索结果未提供的内容。
- 对于创作类的问题(如写论文),请务必在正文的段落中引用对应的参考编号,例如[3][5],不能只在文章末尾引用。你需要解读并概括用户的题目要求,选择合适的格式,充分利用搜索结果并抽取重要信息,生成符合用户要求、极具思想深度、富有创造力与专业性的答案。你的创作篇幅需要尽可能延长,对于每一个要点的论述要推测用户的意图,给出尽可能多角度的回答要点,且务必信息量大、论述详尽。
- 如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。
- 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。
- 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。
- 你的回答应该综合多个相关网页来回答,不能重复引用一个网页。
- 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。
# 用户消息为:
{question}`
} else {
config.promptTemplate = `# 以下内容是基于用户发送的消息的搜索结果:
{search_results}
在我给你的搜索结果中,每个结果都是[webpage begin]...[webpage end]格式的。
在回答时,请注意以下几点:
- 今天是北京时间:{cur_date}。
- 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。
- 对于列举类的问题(如列举所有航班信息),尽量将答案控制在10个要点以内。如非必要,不要主动告诉用户搜索结果未提供的内容。
- 对于创作类的问题(如写论文),你需要解读并概括用户的题目要求,选择合适的格式,充分利用搜索结果并抽取重要信息,生成符合用户要求、极具思想深度、富有创造力与专业性的答案。你的创作篇幅需要尽可能延长,对于每一个要点的论述要推测用户的意图,给出尽可能多角度的回答要点,且务必信息量大、论述详尽。
- 如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。
- 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。
- 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。
- 你的回答应该综合多个相关网页来回答,但回答中不要给出网页的引用来源。
- 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。
# 用户消息为:
{question}`
}
}
if !strings.Contains(config.promptTemplate, "{search_results}") ||
!strings.Contains(config.promptTemplate, "{question}") {
return fmt.Errorf("invalid promptTemplate, must contains {search_results} and {question}:%s", config.promptTemplate)
}
var internetExists, privateExists, arxivExists bool
var onlyQuark bool = true
for _, e := range json.Get("searchFrom").Array() {
switch e.Get("type").String() {
case "bing":
searchEngine, err := bing.NewBingSearch(&e)
if err != nil {
return fmt.Errorf("bing search engine init failed:%s", err)
}
config.engine = append(config.engine, searchEngine)
internetExists = true
onlyQuark = false
case "google":
searchEngine, err := google.NewGoogleSearch(&e)
if err != nil {
return fmt.Errorf("google search engine init failed:%s", err)
}
config.engine = append(config.engine, searchEngine)
internetExists = true
onlyQuark = false
case "arxiv":
searchEngine, err := arxiv.NewArxivSearch(&e)
if err != nil {
return fmt.Errorf("arxiv search engine init failed:%s", err)
}
config.engine = append(config.engine, searchEngine)
arxivExists = true
onlyQuark = false
case "elasticsearch":
searchEngine, err := elasticsearch.NewElasticsearchSearch(&e, config.needReference)
if err != nil {
return fmt.Errorf("elasticsearch search engine init failed:%s", err)
}
config.engine = append(config.engine, searchEngine)
privateExists = true
onlyQuark = false
case "quark":
searchEngine, err := quark.NewQuarkSearch(&e)
if err != nil {
return fmt.Errorf("quark search engine init failed:%s", err)
}
config.engine = append(config.engine, searchEngine)
internetExists = true
default:
return fmt.Errorf("unkown search engine:%s", e.Get("type").String())
}
}
searchRewriteJson := json.Get("searchRewrite")
if searchRewriteJson.Exists() {
searchRewrite := &SearchRewrite{}
llmServiceName := searchRewriteJson.Get("llmServiceName").String()
if llmServiceName == "" {
return errors.New("llm_service_name not found")
}
llmServicePort := searchRewriteJson.Get("llmServicePort").Int()
if llmServicePort == 0 {
return errors.New("llmServicePort not found")
}
searchRewrite.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: llmServiceName,
Port: llmServicePort,
})
llmApiKey := searchRewriteJson.Get("llmApiKey").String()
searchRewrite.apiKey = llmApiKey
llmUrl := searchRewriteJson.Get("llmUrl").String()
if llmUrl == "" {
return errors.New("llmUrl not found")
}
searchRewrite.url = llmUrl
llmModelName := searchRewriteJson.Get("llmModelName").String()
if llmModelName == "" {
return errors.New("llmModelName not found")
}
searchRewrite.modelName = llmModelName
llmTimeout := searchRewriteJson.Get("timeoutMillisecond").Uint()
if llmTimeout == 0 {
llmTimeout = 30000
}
searchRewrite.timeoutMillisecond = uint32(llmTimeout)
maxCount := searchRewriteJson.Get("maxCount").Int()
if maxCount == 0 {
maxCount = 3 // Default value
}
searchRewrite.maxCount = int(maxCount)
// The consideration here is that internet searches are generally available, but arxiv and private sources may not be.
if arxivExists {
if privateExists {
// private + internet + arxiv
searchRewrite.prompt = fullSearchPrompts
} else {
// internet + arxiv
searchRewrite.prompt = arxivSearchPrompts
}
} else if privateExists {
// private + internet
searchRewrite.prompt = privateSearchPrompts
} else if internetExists {
// only internet
if onlyQuark {
// When only quark is used, use chinese-internet.md
searchRewrite.prompt = chineseInternetSearchPrompts
} else {
searchRewrite.prompt = internetSearchPrompts
}
}
// Store the original prompt template before replacing placeholders
searchRewrite.promptTemplate = searchRewrite.prompt
// Replace {max_count} placeholder in the prompt with the configured value
searchRewrite.prompt = strings.Replace(searchRewrite.prompt, "{max_count}", fmt.Sprintf("%d", searchRewrite.maxCount), -1)
config.searchRewrite = searchRewrite
}
if len(config.engine) == 0 {
return fmt.Errorf("no avaliable search engine found")
}
log.Debugf("ai search enabled, config: %#v", config)
return nil
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log wrapper.Log) types.Action {
contentType, _ := proxywasm.GetHttpRequestHeader("content-type")
// The request does not have a body.
if contentType == "" {
return types.ActionContinue
}
if !strings.Contains(contentType, "application/json") {
log.Warnf("content is not json, can't process: %s", contentType)
ctx.DontReadRequestBody()
return types.ActionContinue
}
ctx.SetRequestBodyBufferLimit(DEFAULT_MAX_BODY_BYTES)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
return types.ActionContinue
}
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log wrapper.Log) types.Action {
// Check if plugin should be enabled based on config and request
webSearchOptions := gjson.GetBytes(body, "web_search_options")
if !config.defaultEnable {
// When defaultEnable is false, we need to check if web_search_options exists in the request
if !webSearchOptions.Exists() {
log.Debugf("Plugin disabled by config and no web_search_options in request")
return types.ActionContinue
}
log.Debugf("Plugin enabled by web_search_options in request")
}
var queryIndex int
var query string
messages := gjson.GetBytes(body, "messages").Array()
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Get("role").String() == "user" {
queryIndex = i
query = messages[i].Get("content").String()
break
}
}
if query == "" {
log.Errorf("not found user query in body:%s", body)
return types.ActionContinue
}
searchRewrite := config.searchRewrite
if searchRewrite != nil {
// Check if web_search_options.search_context_size exists and adjust maxCount accordingly
if webSearchOptions.Exists() {
searchContextSize := webSearchOptions.Get("search_context_size").String()
if searchContextSize != "" {
originalMaxCount := searchRewrite.maxCount
switch searchContextSize {
case "low":
searchRewrite.maxCount = 1
log.Debugf("Setting maxCount to 1 based on search_context_size=low")
case "medium":
searchRewrite.maxCount = 3
log.Debugf("Setting maxCount to 3 based on search_context_size=medium")
case "high":
searchRewrite.maxCount = 5
log.Debugf("Setting maxCount to 5 based on search_context_size=high")
default:
log.Warnf("Unknown search_context_size value: %s, using configured maxCount: %d",
searchContextSize, searchRewrite.maxCount)
}
// If maxCount changed, regenerate the prompt from the template
if originalMaxCount != searchRewrite.maxCount && searchRewrite.promptTemplate != "" {
searchRewrite.prompt = strings.Replace(
searchRewrite.promptTemplate,
"{max_count}",
fmt.Sprintf("%d", searchRewrite.maxCount),
-1)
}
}
}
startTime := time.Now()
rewritePrompt := strings.Replace(searchRewrite.prompt, "{question}", query, 1)
rewriteBody, _ := sjson.SetBytes([]byte(fmt.Sprintf(
`{"stream":false,"max_tokens":4096,"model":"%s","messages":[{"role":"user","content":""}]}`,
searchRewrite.modelName)), "messages.0.content", rewritePrompt)
err := searchRewrite.client.Post(searchRewrite.url,
[][2]string{
{"Content-Type", "application/json"},
{"Authorization", fmt.Sprintf("Bearer %s", searchRewrite.apiKey)},
}, rewriteBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode != http.StatusOK {
log.Errorf("search rewrite failed, status: %d", statusCode)
// After a rewrite failure, no further search is performed, thus quickly identifying the failure.
proxywasm.ResumeHttpRequest()
return
}
content := gjson.GetBytes(responseBody, "choices.0.message.content").String()
log.Infof("LLM rewritten query response: %s (took %v), original search query:%s",
strings.ReplaceAll(content, "\n", `\n`), time.Since(startTime), query)
if strings.Contains(content, "none") {
log.Debugf("no search required")
proxywasm.ResumeHttpRequest()
return
}
// Parse search queries from LLM response
var searchContexts []engine.SearchContext
for _, line := range strings.Split(content, "\n") {
line = strings.TrimSpace(line)
if line == "" {
continue
}
parts := strings.SplitN(line, ":", 2)
if len(parts) != 2 {
continue
}
engineType := strings.TrimSpace(parts[0])
queryStr := strings.TrimSpace(parts[1])
var ctx engine.SearchContext
ctx.Language = config.defaultLanguage
switch {
case engineType == "internet":
ctx.EngineType = engineType
ctx.Querys = []string{queryStr}
case engineType == "private":
ctx.EngineType = engineType
ctx.Querys = strings.Split(queryStr, ",")
for i := range ctx.Querys {
ctx.Querys[i] = strings.TrimSpace(ctx.Querys[i])
}
default:
// Arxiv category
ctx.EngineType = "arxiv"
ctx.ArxivCategory = engineType
ctx.Querys = strings.Split(queryStr, ",")
for i := range ctx.Querys {
ctx.Querys[i] = strings.TrimSpace(ctx.Querys[i])
}
}
if len(ctx.Querys) > 0 {
searchContexts = append(searchContexts, ctx)
if ctx.ArxivCategory != "" {
// Conduct i/nquiries in all areas to increase recall.
backupCtx := ctx
backupCtx.ArxivCategory = ""
searchContexts = append(searchContexts, backupCtx)
}
}
}
if len(searchContexts) == 0 {
log.Errorf("no valid search contexts found")
proxywasm.ResumeHttpRequest()
return
}
if types.ActionContinue == executeSearch(ctx, config, queryIndex, body, searchContexts, log) {
proxywasm.ResumeHttpRequest()
}
}, searchRewrite.timeoutMillisecond)
if err != nil {
log.Errorf("search rewrite call llm service failed:%s", err)
// After a rewrite failure, no further search is performed, thus quickly identifying the failure.
return types.ActionContinue
}
return types.ActionPause
}
// Execute search without rewrite
return executeSearch(ctx, config, queryIndex, body, []engine.SearchContext{{
Querys: []string{query},
Language: config.defaultLanguage,
}}, log)
}
func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext, log wrapper.Log) types.Action {
searchResultGroups := make([][]engine.SearchResult, len(config.engine))
var finished int
var searching int
for i := 0; i < len(config.engine); i++ {
configEngine := config.engine[i]
// Check if engine needs to execute for any of the search contexts
var needsExecute bool
for _, searchCtx := range searchContexts {
if configEngine.NeedExectue(searchCtx) {
needsExecute = true
break
}
}
if !needsExecute {
continue
}
// Process all search contexts for this engine
for _, searchCtx := range searchContexts {
if !configEngine.NeedExectue(searchCtx) {
continue
}
args := configEngine.CallArgs(searchCtx)
index := i
err := configEngine.Client().Call(args.Method, args.Url, args.Headers, args.Body,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
defer func() {
finished++
if finished == searching {
// Merge search results from all engines with deduplication
var mergedResults []engine.SearchResult
seenLinks := make(map[string]bool)
for _, results := range searchResultGroups {
for _, result := range results {
if !seenLinks[result.Link] {
seenLinks[result.Link] = true
mergedResults = append(mergedResults, result)
}
}
}
if len(mergedResults) == 0 {
log.Warnf("no search result found, searchContexts:%#v", searchContexts)
proxywasm.ResumeHttpRequest()
return
}
// Format search results for prompt template
var formattedResults []string
var formattedReferences []string
for j, result := range mergedResults {
if config.needReference {
formattedResults = append(formattedResults,
fmt.Sprintf("[webpage %d begin]\n%s\n[webpage %d end]", j+1, result.Content, j+1))
formattedReferences = append(formattedReferences,
fmt.Sprintf("[%d] [%s](%s)", j+1, result.Title, result.Link))
} else {
formattedResults = append(formattedResults,
fmt.Sprintf("[webpage begin]\n%s\n[webpage end]", result.Content))
}
}
// Prepare template variables
curDate := time.Now().In(time.FixedZone("CST", 8*3600)).Format("2006年1月2日")
searchResults := strings.Join(formattedResults, "\n")
log.Debugf("searchResults: %s", searchResults)
// Fill prompt template
prompt := strings.Replace(config.promptTemplate, "{search_results}", searchResults, 1)
prompt = strings.Replace(prompt, "{question}", searchContexts[0].Querys[0], 1)
prompt = strings.Replace(prompt, "{cur_date}", curDate, 1)
// Update request body with processed prompt
modifiedBody, err := sjson.SetBytes(body, fmt.Sprintf("messages.%d.content", queryIndex), prompt)
if err != nil {
log.Errorf("modify request message content failed, err:%v, body:%s", err, body)
} else {
log.Debugf("modifeid body:%s", modifiedBody)
proxywasm.ReplaceHttpRequestBody(modifiedBody)
if config.needReference {
ctx.SetContext("References", strings.Join(formattedReferences, "\n\n"))
}
}
proxywasm.ResumeHttpRequest()
}
}()
if statusCode != http.StatusOK {
log.Errorf("search call failed, status: %d, engine: %#v", statusCode, configEngine)
return
}
// Append results to existing slice for this engine
searchResultGroups[index] = append(searchResultGroups[index], configEngine.ParseResult(searchCtx, responseBody)...)
}, args.TimeoutMillisecond)
if err != nil {
log.Errorf("search call failed, engine: %#v", configEngine)
continue
}
searching++
}
}
if searching > 0 {
return types.ActionPause
}
return types.ActionContinue
}
func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log wrapper.Log) types.Action {
if !config.needReference {
ctx.DontReadResponseBody()
return types.ActionContinue
}
proxywasm.RemoveHttpResponseHeader("content-length")
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
if err != nil {
log.Errorf("unable to load content-type header from response: %v", err)
}
ctx.BufferResponseBody()
ctx.SetResponseBodyBufferLimit(DEFAULT_MAX_BODY_BYTES)
}
return types.ActionContinue
}
func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log wrapper.Log) types.Action {
references := ctx.GetStringContext("References", "")
if references == "" {
return types.ActionContinue
}
content := gjson.GetBytes(body, "choices.0.message.content").String()
var modifiedContent string
formattedReferences := fmt.Sprintf(config.referenceFormat, references)
if strings.HasPrefix(strings.TrimLeftFunc(content, unicode.IsSpace), "<think>") {
thinkEnd := strings.Index(content, "</think>")
if thinkEnd != -1 {
if config.referenceLocation == "tail" {
// Add references at the end
modifiedContent = content + fmt.Sprintf("\n\n%s", formattedReferences)
} else {
// Default: add references after </think> tag
modifiedContent = content[:thinkEnd+8] +
fmt.Sprintf("\n%s\n\n%s", formattedReferences, content[thinkEnd+8:])
}
}
}
if modifiedContent == "" {
if config.referenceLocation == "tail" {
// Add references at the end
modifiedContent = fmt.Sprintf("%s\n\n%s", content, formattedReferences)
} else {
// Default: add references at the beginning
modifiedContent = fmt.Sprintf("%s\n\n%s", formattedReferences, content)
}
}
body, err := sjson.SetBytes(body, "choices.0.message.content", modifiedContent)
if err != nil {
log.Errorf("modify response message content failed, err:%v, body:%s", err, body)
return types.ActionContinue
}
proxywasm.ReplaceHttpResponseBody(body)
return types.ActionContinue
}
func unifySSEChunk(data []byte) []byte {
data = bytes.ReplaceAll(data, []byte("\r\n"), []byte("\n"))
data = bytes.ReplaceAll(data, []byte("\r"), []byte("\n"))
return data
}
const (
PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage"
BUFFER_CONTENT_CONTEXT_KEY = "bufferContent"
BUFFER_SIZE = 30
)
func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
if ctx.GetBoolContext("ReferenceAppended", false) {
return chunk
}
references := ctx.GetStringContext("References", "")
if references == "" {
return chunk
}
chunk = unifySSEChunk(chunk)
var partialMessage []byte
partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY)
log.Debugf("[handleStreamChunk] buffer content: %v", ctx.GetContext(BUFFER_CONTENT_CONTEXT_KEY))
if partialMessageI != nil {
partialMessage = append(partialMessageI.([]byte), chunk...)
} else {
partialMessage = chunk
}
messages := strings.Split(string(partialMessage), "\n\n")
var newMessages []string
for i, msg := range messages {
if i < len(messages)-1 {
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), config.referenceLocation == "tail", log)
if newMsg != "" {
newMessages = append(newMessages, newMsg)
}
}
}
if !strings.HasSuffix(string(partialMessage), "\n\n") {
ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, []byte(messages[len(messages)-1]))
} else {
ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil)
}
if len(newMessages) > 0 {
return []byte(fmt.Sprintf("%s\n\n", strings.Join(newMessages, "\n\n")))
} else {
return []byte("")
}
}
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool, log wrapper.Log) string {
log.Debugf("single sse message: %s", sseMessage)
subMessages := strings.Split(sseMessage, "\n")
var message string
for _, msg := range subMessages {
if strings.HasPrefix(msg, "data:") {
message = msg
break
}
}
if len(message) < 6 {
log.Errorf("[processSSEMessage] invalid message: %s", message)
return sseMessage
}
// Skip the prefix "data:"
bodyJson := message[5:]
if strings.TrimSpace(bodyJson) == "[DONE]" {
return sseMessage
}
bodyJson = strings.TrimPrefix(bodyJson, " ")
bodyJson = strings.TrimSuffix(bodyJson, "\n")
// If tailReference is true, only check if this is the last message
if tailReference {
// Check if this is the last message in the stream (finish_reason is "stop")
finishReason := gjson.Get(bodyJson, "choices.0.finish_reason").String()
if finishReason == "stop" {
// This is the last message, append references at the end
deltaContent := gjson.Get(bodyJson, "choices.0.delta.content").String()
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", deltaContent+fmt.Sprintf("\n\n%s", references))
if err != nil {
log.Errorf("update message failed:%s", err)
}
ctx.SetContext("ReferenceAppended", true)
return fmt.Sprintf("data: %s", modifiedMessage)
}
// Not the last message, return original message
return sseMessage
}
// Original head reference logic
deltaContent := gjson.Get(bodyJson, "choices.0.delta.content").String()
// Skip the preceding content that might be empty due to the presence of a separate reasoning_content field.
if deltaContent == "" {
return sseMessage
}
bufferContent := ctx.GetStringContext(BUFFER_CONTENT_CONTEXT_KEY, "") + deltaContent
if len(bufferContent) < BUFFER_SIZE {
ctx.SetContext(BUFFER_CONTENT_CONTEXT_KEY, bufferContent)
return ""
}
if !ctx.GetBoolContext("FirstMessageChecked", false) {
ctx.SetContext("FirstMessageChecked", true)
if !strings.Contains(strings.TrimLeftFunc(bufferContent, unicode.IsSpace), "<think>") {
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", fmt.Sprintf("%s\n\n%s", references, bufferContent))
if err != nil {
log.Errorf("update message failed:%s", err)
}
ctx.SetContext("ReferenceAppended", true)
return fmt.Sprintf("data: %s", modifiedMessage)
}
}
// Content has <think> prefix
// Check for complete </think> tag
thinkEnd := strings.Index(bufferContent, "</think>")
if thinkEnd != -1 {
modifiedContent := bufferContent[:thinkEnd+8] +
fmt.Sprintf("\n%s\n\n%s", references, bufferContent[thinkEnd+8:])
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", modifiedContent)
if err != nil {
log.Errorf("update message failed:%s", err)
}
ctx.SetContext("ReferenceAppended", true)
return fmt.Sprintf("data: %s", modifiedMessage)
}
// Check for partial </think> tag at end of buffer
// Look for any partial match that could be completed in next message
for i := 1; i < len("</think>"); i++ {
if strings.HasSuffix(bufferContent, "</think>"[:i]) {
// Store only the partial match for the next message
ctx.SetContext(BUFFER_CONTENT_CONTEXT_KEY, bufferContent[len(bufferContent)-i:])
// Return the content before the partial match
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", bufferContent[:len(bufferContent)-i])
if err != nil {
log.Errorf("update message failed:%s", err)
}
return fmt.Sprintf("data: %s", modifiedMessage)
}
}
ctx.SetContext(BUFFER_CONTENT_CONTEXT_KEY, "")
return sseMessage
}