in plugins/wasm-go/extensions/ai-search/main.go [295:452]
// onHttpRequestBody decides whether this chat request should be augmented with
// search results and, if so, starts the search pipeline.
//
// The plugin is active when config.defaultEnable is true, or when the request
// body carries a "web_search_options" object. The content of the most recent
// message with role "user" is used as the search query.
//
// When config.searchRewrite is nil, the raw query is searched directly and the
// result of executeSearch is returned. Otherwise the query is first sent to the
// rewrite LLM; the request is paused (types.ActionPause) and resumed from the
// HTTP callback once the rewritten queries have been searched, or immediately
// on any failure / "no search required" answer.
//
// web_search_options.search_context_size ("low"/"medium"/"high") overrides the
// rewrite maxCount (1/3/5) for this request only; the shared config is never
// modified.
//
// Returns types.ActionContinue when the plugin is disabled, no user query is
// found, or the rewrite call could not be issued.
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log wrapper.Log) types.Action {
	// Check if plugin should be enabled based on config and request
	webSearchOptions := gjson.GetBytes(body, "web_search_options")
	if !config.defaultEnable {
		// When defaultEnable is false, we need to check if web_search_options exists in the request
		if !webSearchOptions.Exists() {
			log.Debugf("Plugin disabled by config and no web_search_options in request")
			return types.ActionContinue
		}
		log.Debugf("Plugin enabled by web_search_options in request")
	}

	// Use the most recent user message as the search query.
	var queryIndex int
	var query string
	messages := gjson.GetBytes(body, "messages").Array()
	for i := len(messages) - 1; i >= 0; i-- {
		if messages[i].Get("role").String() == "user" {
			queryIndex = i
			query = messages[i].Get("content").String()
			break
		}
	}
	if query == "" {
		log.Errorf("not found user query in body:%s", body)
		return types.ActionContinue
	}

	if config.searchRewrite == nil {
		// Execute search without rewrite
		return executeSearch(ctx, config, queryIndex, body, []engine.SearchContext{{
			Querys:   []string{query},
			Language: config.defaultLanguage,
		}}, log)
	}

	// config.searchRewrite points at state shared by every request handled by
	// this plugin instance. Per-request overrides are therefore applied to a
	// private shallow copy; writing through the shared pointer would leak one
	// request's search_context_size into all later requests.
	searchRewrite := config.searchRewrite
	if webSearchOptions.Exists() {
		searchContextSize := webSearchOptions.Get("search_context_size").String()
		if searchContextSize != "" {
			maxCount := searchRewrite.maxCount
			switch searchContextSize {
			case "low":
				maxCount = 1
				log.Debugf("Setting maxCount to 1 based on search_context_size=low")
			case "medium":
				maxCount = 3
				log.Debugf("Setting maxCount to 3 based on search_context_size=medium")
			case "high":
				maxCount = 5
				log.Debugf("Setting maxCount to 5 based on search_context_size=high")
			default:
				log.Warnf("Unknown search_context_size value: %s, using configured maxCount: %d",
					searchContextSize, searchRewrite.maxCount)
			}
			if maxCount != searchRewrite.maxCount {
				// NOTE(review): assumes the rewrite config struct holds no
				// no-copy fields (e.g. sync.Mutex) — confirm if its definition changes.
				rewriteCopy := *searchRewrite
				rewriteCopy.maxCount = maxCount
				// Regenerate the prompt from the template for the new maxCount.
				if rewriteCopy.promptTemplate != "" {
					rewriteCopy.prompt = strings.ReplaceAll(
						rewriteCopy.promptTemplate,
						"{max_count}",
						fmt.Sprintf("%d", maxCount))
				}
				searchRewrite = &rewriteCopy
				// config is this call's own value copy; pointing it at the
				// per-request settings keeps executeSearch consistent with them.
				config.searchRewrite = searchRewrite
			}
		}
	}

	startTime := time.Now()
	rewritePrompt := strings.Replace(searchRewrite.prompt, "{question}", query, 1)
	// Build the request with sjson for every dynamic value so model names or
	// prompts containing quotes/backslashes are escaped correctly.
	rewriteBody, err := sjson.SetBytes(
		[]byte(`{"stream":false,"max_tokens":4096,"model":"","messages":[{"role":"user","content":""}]}`),
		"model", searchRewrite.modelName)
	if err == nil {
		rewriteBody, err = sjson.SetBytes(rewriteBody, "messages.0.content", rewritePrompt)
	}
	if err != nil {
		log.Errorf("build search rewrite request failed:%s", err)
		return types.ActionContinue
	}

	err = searchRewrite.client.Post(searchRewrite.url,
		[][2]string{
			{"Content-Type", "application/json"},
			{"Authorization", fmt.Sprintf("Bearer %s", searchRewrite.apiKey)},
		}, rewriteBody,
		func(statusCode int, responseHeaders http.Header, responseBody []byte) {
			if statusCode != http.StatusOK {
				log.Errorf("search rewrite failed, status: %d", statusCode)
				// After a rewrite failure, no further search is performed, thus quickly identifying the failure.
				proxywasm.ResumeHttpRequest()
				return
			}
			content := gjson.GetBytes(responseBody, "choices.0.message.content").String()
			log.Infof("LLM rewritten query response: %s (took %v), original search query:%s",
				strings.ReplaceAll(content, "\n", `\n`), time.Since(startTime), query)
			// NOTE(review): substring match means any rewritten query containing
			// "none" (e.g. "nonetheless") also skips search; kept as-is because
			// the rewrite prompt's exact output contract is defined elsewhere.
			if strings.Contains(content, "none") {
				log.Debugf("no search required")
				proxywasm.ResumeHttpRequest()
				return
			}
			// Parse search queries from LLM response. Each non-empty line is
			// expected as "<engine>: <query>"; lines without a colon are ignored.
			var searchContexts []engine.SearchContext
			for _, line := range strings.Split(content, "\n") {
				line = strings.TrimSpace(line)
				if line == "" {
					continue
				}
				parts := strings.SplitN(line, ":", 2)
				if len(parts) != 2 {
					continue
				}
				engineType := strings.TrimSpace(parts[0])
				queryStr := strings.TrimSpace(parts[1])
				// Named searchCtx to avoid shadowing the HttpContext parameter ctx.
				var searchCtx engine.SearchContext
				searchCtx.Language = config.defaultLanguage
				switch engineType {
				case "internet":
					searchCtx.EngineType = engineType
					searchCtx.Querys = []string{queryStr}
				case "private":
					searchCtx.EngineType = engineType
					searchCtx.Querys = strings.Split(queryStr, ",")
					for i := range searchCtx.Querys {
						searchCtx.Querys[i] = strings.TrimSpace(searchCtx.Querys[i])
					}
				default:
					// Any other prefix is treated as an arXiv category.
					searchCtx.EngineType = "arxiv"
					searchCtx.ArxivCategory = engineType
					searchCtx.Querys = strings.Split(queryStr, ",")
					for i := range searchCtx.Querys {
						searchCtx.Querys[i] = strings.TrimSpace(searchCtx.Querys[i])
					}
				}
				if len(searchCtx.Querys) > 0 {
					searchContexts = append(searchContexts, searchCtx)
					if searchCtx.ArxivCategory != "" {
						// Also search across all arXiv categories to increase recall.
						backupCtx := searchCtx
						backupCtx.ArxivCategory = ""
						searchContexts = append(searchContexts, backupCtx)
					}
				}
			}
			if len(searchContexts) == 0 {
				log.Errorf("no valid search contexts found")
				proxywasm.ResumeHttpRequest()
				return
			}
			// executeSearch returning ActionContinue means it did not take
			// ownership of resuming the request, so resume it here.
			if types.ActionContinue == executeSearch(ctx, config, queryIndex, body, searchContexts, log) {
				proxywasm.ResumeHttpRequest()
			}
		}, searchRewrite.timeoutMillisecond)
	if err != nil {
		log.Errorf("search rewrite call llm service failed:%s", err)
		// After a rewrite failure, no further search is performed, thus quickly identifying the failure.
		return types.ActionContinue
	}
	return types.ActionPause
}