func onHttpRequestBody()

in plugins/wasm-go/extensions/ai-search/main.go [295:452]


// onHttpRequestBody decides whether the chat request in body needs a search and,
// if so, runs it before the request is forwarded upstream.
//
// Flow:
//  1. When config.defaultEnable is false, the plugin only acts if the request
//     body contains a "web_search_options" field.
//  2. The "content" of the last message whose role is "user" is used as the
//     query; if there is none (or it is empty) the request continues untouched.
//  3. If query rewriting is configured (config.searchRewrite != nil), the query
//     is sent to the rewrite LLM asynchronously and the request is paused. The
//     LLM reply is parsed into search contexts and passed to executeSearch; the
//     request is resumed whenever the callback gives up early or executeSearch
//     returns types.ActionContinue.
//  4. Otherwise executeSearch is called directly with the original query.
//
// Returns types.ActionPause while the rewrite call is in flight, the result of
// executeSearch when no rewrite is configured, and types.ActionContinue when
// nothing is done or the rewrite call cannot be issued.
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log wrapper.Log) types.Action {
	// Check if plugin should be enabled based on config and request
	webSearchOptions := gjson.GetBytes(body, "web_search_options")
	if !config.defaultEnable {
		// When defaultEnable is false, we need to check if web_search_options exists in the request
		if !webSearchOptions.Exists() {
			log.Debugf("Plugin disabled by config and no web_search_options in request")
			return types.ActionContinue
		}
		log.Debugf("Plugin enabled by web_search_options in request")
	}

	// Use the most recent user message as the search query.
	var queryIndex int
	var query string
	messages := gjson.GetBytes(body, "messages").Array()
	for i := len(messages) - 1; i >= 0; i-- {
		if messages[i].Get("role").String() == "user" {
			queryIndex = i
			query = messages[i].Get("content").String()
			break
		}
	}
	if query == "" {
		log.Errorf("not found user query in body:%s", body)
		return types.ActionContinue
	}

	searchRewrite := config.searchRewrite
	if searchRewrite == nil {
		// Execute search without rewrite
		return executeSearch(ctx, config, queryIndex, body, []engine.SearchContext{{
			Querys:   []string{query},
			Language: config.defaultLanguage,
		}}, log)
	}

	// config.searchRewrite is shared by every request this plugin instance
	// handles, so the per-request search_context_size override is applied to
	// local copies only. Writing it back would leak one request's setting into
	// all subsequent requests.
	maxCount := searchRewrite.maxCount
	prompt := searchRewrite.prompt
	if searchContextSize := webSearchOptions.Get("search_context_size").String(); searchContextSize != "" {
		switch searchContextSize {
		case "low":
			maxCount = 1
			log.Debugf("Setting maxCount to 1 based on search_context_size=low")
		case "medium":
			maxCount = 3
			log.Debugf("Setting maxCount to 3 based on search_context_size=medium")
		case "high":
			maxCount = 5
			log.Debugf("Setting maxCount to 5 based on search_context_size=high")
		default:
			log.Warnf("Unknown search_context_size value: %s, using configured maxCount: %d",
				searchContextSize, maxCount)
		}

		// If maxCount changed, regenerate the prompt from the template
		if maxCount != searchRewrite.maxCount && searchRewrite.promptTemplate != "" {
			prompt = strings.ReplaceAll(searchRewrite.promptTemplate, "{max_count}", fmt.Sprintf("%d", maxCount))
		}
	}

	startTime := time.Now()
	rewritePrompt := strings.Replace(prompt, "{question}", query, 1)
	// Both the model name and the prompt are set through sjson so they are
	// JSON-escaped; formatting them into the literal could yield invalid JSON.
	rewriteBody, err := sjson.SetBytes(
		[]byte(`{"stream":false,"max_tokens":4096,"messages":[{"role":"user","content":""}]}`),
		"model", searchRewrite.modelName)
	if err == nil {
		rewriteBody, err = sjson.SetBytes(rewriteBody, "messages.0.content", rewritePrompt)
	}
	if err != nil {
		log.Errorf("build search rewrite request body failed:%s", err)
		return types.ActionContinue
	}

	err = searchRewrite.client.Post(searchRewrite.url,
		[][2]string{
			{"Content-Type", "application/json"},
			{"Authorization", fmt.Sprintf("Bearer %s", searchRewrite.apiKey)},
		}, rewriteBody,
		func(statusCode int, responseHeaders http.Header, responseBody []byte) {
			if statusCode != http.StatusOK {
				log.Errorf("search rewrite failed, status: %d", statusCode)
				// After a rewrite failure, no further search is performed, thus quickly identifying the failure.
				proxywasm.ResumeHttpRequest()
				return
			}

			content := gjson.GetBytes(responseBody, "choices.0.message.content").String()
			log.Infof("LLM rewritten query response: %s (took %v), original search query:%s",
				strings.ReplaceAll(content, "\n", `\n`), time.Since(startTime), query)
			// NOTE(review): substring match also fires on words such as
			// "nonetheless" inside a rewritten query; kept as-is because the
			// exact "none" contract is defined by the prompt, not visible here.
			if strings.Contains(content, "none") {
				log.Debugf("no search required")
				proxywasm.ResumeHttpRequest()
				return
			}

			searchContexts := parseRewriteContent(content, config)
			if len(searchContexts) == 0 {
				log.Errorf("no valid search contexts found")
				proxywasm.ResumeHttpRequest()
				return
			}
			if types.ActionContinue == executeSearch(ctx, config, queryIndex, body, searchContexts, log) {
				proxywasm.ResumeHttpRequest()
			}
		}, searchRewrite.timeoutMillisecond)
	if err != nil {
		log.Errorf("search rewrite call llm service failed:%s", err)
		// After a rewrite failure, no further search is performed, thus quickly identifying the failure.
		return types.ActionContinue
	}
	return types.ActionPause
}

// parseRewriteContent converts the rewrite LLM's reply into search contexts.
//
// Each non-empty line is split at its first ':' into "<engine>: <queries>":
//   - "internet": the whole (trimmed) text after the colon is one query;
//   - "private": comma-separated queries;
//   - anything else: treated as an arxiv category with comma-separated queries,
//     and a second context without the category is added to increase recall.
//
// Lines without a colon, and lines left with no non-empty query after trimming,
// are skipped. Every context uses config.defaultLanguage.
func parseRewriteContent(content string, config Config) []engine.SearchContext {
	// splitQueries splits a comma-separated list, trims each entry and drops empties.
	splitQueries := func(s string) []string {
		var out []string
		for _, q := range strings.Split(s, ",") {
			if q = strings.TrimSpace(q); q != "" {
				out = append(out, q)
			}
		}
		return out
	}

	var searchContexts []engine.SearchContext
	for _, line := range strings.Split(content, "\n") {
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}

		parts := strings.SplitN(line, ":", 2)
		if len(parts) != 2 {
			continue
		}

		engineType := strings.TrimSpace(parts[0])
		queryStr := strings.TrimSpace(parts[1])

		var searchCtx engine.SearchContext
		searchCtx.Language = config.defaultLanguage

		switch engineType {
		case "internet":
			searchCtx.EngineType = engineType
			if queryStr != "" {
				searchCtx.Querys = []string{queryStr}
			}
		case "private":
			searchCtx.EngineType = engineType
			searchCtx.Querys = splitQueries(queryStr)
		default:
			// Arxiv category
			searchCtx.EngineType = "arxiv"
			searchCtx.ArxivCategory = engineType
			searchCtx.Querys = splitQueries(queryStr)
		}

		if len(searchCtx.Querys) == 0 {
			continue
		}
		searchContexts = append(searchContexts, searchCtx)
		if searchCtx.ArxivCategory != "" {
			// Conduct inquiries in all areas to increase recall.
			backupCtx := searchCtx
			backupCtx.ArxivCategory = ""
			searchContexts = append(searchContexts, backupCtx)
		}
	}
	return searchContexts
}