service/recall/opensearch_recall.go (131 lines of code) (raw):
package recall
import (
"errors"
"fmt"
"strings"
"time"
"github.com/alibaba/pairec/v2/context"
"github.com/alibaba/pairec/v2/datasource/opensearch"
"github.com/alibaba/pairec/v2/log"
"github.com/alibaba/pairec/v2/module"
"github.com/alibaba/pairec/v2/recconf"
"github.com/alibaba/pairec/v2/utils"
"github.com/alibabacloud-go/tea/tea"
)
// OpenSearchRecall is a recall module that retrieves candidate items by issuing a
// search request against an OpenSearch application (see GetCandidateItems).
type OpenSearchRecall struct {
*BaseRecall
// openSearchClient is the client obtained by name via opensearch.GetOpenSearchClient.
openSearchClient *opensearch.OpenSearchClient
// AppName is the OpenSearch application name; it is interpolated into the
// request path "/v3/openapi/apps/<AppName>/search".
AppName string
// ItemId is the name of the hit field whose value becomes the item id; hits
// that do not contain this field are skipped.
ItemId string
// RequestParams are the base query parameters sent with every request. String
// values may contain $1..$N placeholders that are filled from Params.
RequestParams map[string]any
// Params lists value paths resolved per request (user.<prop>, context.<param>,
// or context.features.<key>); the k-th entry (1-based) fills placeholder $k.
Params []string
}
// NewOpenSearchRecall builds an OpenSearchRecall from a recall config, wiring in the
// OpenSearch client named by config.OpenSearchConf.OpenSearchName.
//
// It panics if that client cannot be obtained, since the recall is unusable without it.
func NewOpenSearchRecall(config recconf.RecallConfig) *OpenSearchRecall {
	osConf := config.OpenSearchConf

	client, err := opensearch.GetOpenSearchClient(osConf.OpenSearchName)
	if err != nil {
		panic(err)
	}

	return &OpenSearchRecall{
		BaseRecall:       NewBaseRecall(config),
		openSearchClient: client,
		AppName:          osConf.AppName,
		ItemId:           osConf.ItemId,
		RequestParams:    osConf.RequestParams,
		Params:           osConf.Params,
	}
}
// GetCandidateItems queries the configured OpenSearch application and converts its
// hits into candidate items.
//
// Request parameters come from getRequestParams. When recallCount > 0, a
// "config=start:0,hit:<recallCount>,format:fulljson" clause is appended to the
// "query" parameter. Every hit that contains the i.ItemId field becomes an item:
//   - all of the hit's fields are copied into the item's properties;
//   - RetrieveId is set to the module name;
//   - Score is taken from the first sort-expression value, when present.
//
// Any failure (building params, transport error, empty result, or a non-"OK"
// status) is logged and yields a nil slice.
func (i *OpenSearchRecall) GetCandidateItems(user *module.User, context *context.RecommendContext) (ret []*module.Item) {
	start := time.Now()

	requestParams, err := i.getRequestParams(user, context)
	if err != nil {
		log.Error(fmt.Sprintf("requestId=%s\tevent=OpenSearchRecall\terr=%s", context.RecommendId, err.Error()))
		return
	}

	if i.recallCount > 0 {
		// Normalize through ToString: formatting a missing or non-string value with
		// %s would inject junk such as "%!s(<nil>)" into the query.
		query := utils.ToString(requestParams["query"], "")
		requestParams["query"] = fmt.Sprintf("%s&&config=start:0,hit:%d,format:fulljson", query, i.recallCount)
	}

	if context.Debug {
		log.Info(fmt.Sprintf("requestId=%s\tevent=OpenSearchRecall\trequest_params=%v", context.RecommendId, requestParams))
	}

	path := "/v3/openapi/apps/" + i.AppName + "/search"
	result, err := i.openSearchClient.OpenSearchClient.Request(tea.String("GET"), tea.String(path), requestParams, nil, nil, i.openSearchClient.Runtime)
	if err != nil {
		log.Error(fmt.Sprintf("requestId=%s\tevent=OpenSearchRecall\terr=%s", context.RecommendId, err.Error()))
		return
	}
	if result == nil {
		log.Error(fmt.Sprintf("requestId=%s\tevent=OpenSearchRecall\terr=empty result", context.RecommendId))
		return
	}
	if result.Status != "OK" {
		log.Error(fmt.Sprintf("requestId=%s\tevent=OpenSearchRecall\terr=opensearch invoke error(%v)", context.RecommendId, result.Errors))
		return
	}

	for _, osItem := range result.Result.Items {
		itemId, ok := osItem.Fields[i.ItemId]
		if !ok {
			// Hits without the configured id field cannot become items.
			continue
		}
		properties := make(map[string]interface{}, len(osItem.Fields))
		for k, v := range osItem.Fields {
			properties[k] = v
		}
		item := module.NewItemWithProperty(itemId, properties)
		item.RetrieveId = i.modelName
		if len(osItem.SortExprValues) > 0 {
			item.Score = utils.ToFloat(osItem.SortExprValues[0], 0)
		}
		ret = append(ret, item)
	}

	log.Info(fmt.Sprintf("requestId=%s\tmodule=OpenSearchRecall\tname=%s\tcount=%d\tcost=%d", context.RecommendId, i.modelName, len(ret), utils.CostTime(start)))
	return
}
// getRequestParams returns a fresh copy of i.RequestParams in which every string
// value has its $1..$N placeholders replaced by the values resolved from i.Params.
// The k-th entry of i.Params fills $k.
//
// The first failing param aborts the call, and nil is returned together with the error.
func (i *OpenSearchRecall) getRequestParams(user *module.User, context *context.RecommendContext) (map[string]any, error) {
	paramResults := make([]string, len(i.Params))
	for x, param := range i.Params {
		val, err := i.resolveRequestParam(user, context, param)
		if err != nil {
			return nil, err
		}
		paramResults[x] = val
	}

	requestParams := make(map[string]any, len(i.RequestParams))
	for k, v := range i.RequestParams {
		if str, ok := v.(string); ok {
			v = replaceOpenSearchPlaceholders(str, paramResults)
		}
		requestParams[k] = v
	}
	return requestParams, nil
}

// resolveRequestParam resolves a single Params entry to its string value.
//
// Supported forms:
//   - "user.<prop>": the user's string property.
//   - "context.<name>": a request context parameter.
//   - "context.features.<key>": a key of the context "features" map.
//
// A two-part path with any other prefix resolves to "" without an error
// (this preserves the historical behavior).
func (i *OpenSearchRecall) resolveRequestParam(user *module.User, context *context.RecommendContext, param string) (string, error) {
	paramArr := strings.Split(param, ".")
	switch len(paramArr) {
	case 2:
		switch paramArr[0] {
		case "user":
			return user.StringProperty(paramArr[1]), nil
		case "context":
			return utils.ToString(context.GetParameter(paramArr[1]), ""), nil
		}
		return "", nil
	case 3: // get value from context features
		if paramArr[0] != "context" || paramArr[1] != "features" {
			return "", fmt.Errorf("Params(%s) only support context.features.xxx", param)
		}
		features := context.GetParameter("features")
		if features == nil {
			return "", errors.New("context.feature is null")
		}
		// Comma-ok assertion: a features value of any other type must surface as an
		// error instead of panicking the request.
		featureMap, ok := features.(map[string]any)
		if !ok {
			return "", fmt.Errorf("Params(%s): context.features has unexpected type %T", param, features)
		}
		return utils.ToString(featureMap[paramArr[2]], ""), nil
	default:
		return "", fmt.Errorf("Params(%s) type is error, its type should be a.b or a.b.c", param)
	}
}

// replaceOpenSearchPlaceholders substitutes the 1-based placeholders $1..$N in s
// with values[0..N-1]. It makes a single left-to-right pass, so substituted values
// are never re-scanned for placeholders.
//
// At each '$' the longest digit run naming a valid index wins. With 12 values,
// "$10" refers to value 10 rather than value 1 followed by a literal "0". As
// before, only the first occurrence of each placeholder is replaced; later
// occurrences, "$0", and leading-zero forms such as "$01" are left as-is.
func replaceOpenSearchPlaceholders(s string, values []string) string {
	if len(values) == 0 || !strings.Contains(s, "$") {
		return s
	}
	used := make([]bool, len(values))
	var b strings.Builder
	b.Grow(len(s))
	for pos := 0; pos < len(s); {
		if s[pos] != '$' {
			b.WriteByte(s[pos])
			pos++
			continue
		}
		// Longest prefix of the digit run whose value is within 1..len(values).
		// Indices are monotonic because a leading '0' stops the scan immediately.
		idx, best, end := 0, 0, 0
		for j := pos + 1; j < len(s) && s[j] >= '0' && s[j] <= '9'; j++ {
			idx = idx*10 + int(s[j]-'0')
			if idx == 0 || idx > len(values) {
				break
			}
			best, end = idx, j+1
		}
		if best > 0 && !used[best-1] {
			used[best-1] = true
			b.WriteString(values[best-1])
			pos = end
			continue
		}
		b.WriteByte('$')
		pos++
	}
	return b.String()
}