service/recall/be_x2i_recall.go (232 lines of code) (raw):

package recall

import (
	"encoding/json"
	"fmt"
	"strconv"
	"strings"
	"sync"

	be "github.com/aliyun/aliyun-be-go-sdk"

	"github.com/alibaba/pairec/v2/context"
	"github.com/alibaba/pairec/v2/datasource/beengine"
	"github.com/alibaba/pairec/v2/log"
	"github.com/alibaba/pairec/v2/module"
	"github.com/alibaba/pairec/v2/recconf"
	"github.com/alibaba/pairec/v2/service/recall/berecall"
	"github.com/alibaba/pairec/v2/utils"
)

// BeX2IRecall recalls items from a BE engine using a trigger supplied by a
// berecall.TriggerKey. It is configured from a single recconf.BeRecallParam.
type BeX2IRecall struct {
	// returnCount is passed to be.NewReadRequest and sent as the return count
	// query parameter.
	returnCount int
	// bizName is passed to be.NewReadRequest and written to log lines.
	bizName string
	// recallName prefixes the intermediate query parameter keys
	// (for example "<recallName>_list").
	recallName string
	// scorerClause is copied from the configuration; it is not read in this file.
	scorerClause string
	// itemIdName is the response field that holds the item id.
	itemIdName string
	// triggerIdName is combined with the trigger's distinct parameter to form
	// "<recallName>_distinct_param".
	triggerIdName string
	// recallTableName, when non-empty, is sent as "<recallName>_table".
	recallTableName string
	// diversityParam, when non-empty, is sent as "<recallName>_diversity_param".
	diversityParam string
	// customParams are appended verbatim (stringified) to the query parameters.
	customParams map[string]interface{}
	// triggerKey produces the trigger value for a user and request context.
	triggerKey berecall.TriggerKey
	// beFilterNames lists the BE filters whose parameters are merged into the
	// request.
	beFilterNames []string
	// beClient is the SDK client taken from client.BeClient.
	beClient *be.Client
	client   *beengine.BeClient
	// cloneInstances caches clones created by CloneWithConfig, keyed by the MD5
	// of the normalized clone configuration.
	cloneInstances map[string]*BeX2IRecall
}

// NewBeX2IRecall builds a BeX2IRecall from conf. It returns nil unless conf
// contains exactly one entry in BeRecallParams.
func NewBeX2IRecall(client *beengine.BeClient, conf recconf.BeConfig) *BeX2IRecall {
	if len(conf.BeRecallParams) != 1 {
		return nil
	}

	param := &conf.BeRecallParams[0]
	return &BeX2IRecall{
		bizName:         conf.BizName,
		returnCount:     param.Count,
		scorerClause:    param.ScorerClause,
		itemIdName:      param.ItemIdName,
		recallName:      param.RecallName,
		triggerIdName:   param.TriggerIdName,
		recallTableName: param.RecallTableName,
		diversityParam:  param.DiversityParam,
		customParams:    param.CustomParams,
		beFilterNames:   conf.BeFilterNames,
		triggerKey:      berecall.NewTriggerKey(param, client),
		beClient:        client.BeClient,
		client:          client,
		cloneInstances:  make(map[string]*BeX2IRecall),
	}
}

// buildRequest assembles the raw BE read request for user.
//
// The recall-specific "<recallName>_list" and "<recallName>_return_count"
// entries produced by BuildQueryParams are renamed to "trigger_list" and
// "return_count". When BuildQueryParams did not produce them, the renamed
// entries are set to the empty string. Parameters from the configured filters
// are merged in afterwards. The request URI is logged when context.Debug is
// set.
func (r *BeX2IRecall) buildRequest(user *module.User, context *context.RecommendContext) *be.ReadRequest {
	x2iReadRequest := be.NewReadRequest(r.bizName, r.returnCount)
	x2iReadRequest.IsRawRequest = true

	params := r.BuildQueryParams(user, context)
	params["user_id"] = string(user.Id)

	// trigger_list
	triggerKey := fmt.Sprintf("%s_list", r.recallName)
	params["trigger_list"] = params[triggerKey]
	delete(params, triggerKey)

	// return_count
	countKey := fmt.Sprintf("%s_return_count", r.recallName)
	params["return_count"] = params[countKey]
	delete(params, countKey)

	r.mergeFilterParams(params, user, context)

	x2iReadRequest.SetQueryParams(params)

	if context.Debug {
		uri := x2iReadRequest.BuildUri()
		log.Info(fmt.Sprintf("requestId=%s\tbizName=%s\turl=%s", context.RecommendId, r.bizName, uri.RequestURI()))
	}

	return x2iReadRequest
}

// mergeFilterParams copies the query parameters of every configured BE filter
// into params. Filter names that berecall.GetFilter cannot resolve are skipped.
// A single filter is evaluated inline; several filters are evaluated in
// parallel goroutines and merged under a mutex.
func (r *BeX2IRecall) mergeFilterParams(params map[string]string, user *module.User, context *context.RecommendContext) {
	switch len(r.beFilterNames) {
	case 0:
		return
	case 1:
		filter, err := berecall.GetFilter(r.beFilterNames[0])
		if err != nil {
			return
		}
		for k, v := range filter.BuildQueryParams(user, context) {
			params[k] = v
		}
		return
	}

	var (
		wg sync.WaitGroup
		mu sync.Mutex
	)
	for _, name := range r.beFilterNames {
		filter, err := berecall.GetFilter(name)
		if err != nil {
			continue
		}

		wg.Add(1)
		// The filter is handed over as an argument so the goroutine works on
		// its own value instead of a variable captured from the loop.
		go func(f berecall.IBeFilter) {
			defer wg.Done()

			filterParams := f.BuildQueryParams(user, context)

			mu.Lock()
			defer mu.Unlock()
			for k, v := range filterParams {
				params[k] = v
			}
		}(filter)
	}
	wg.Wait()
}

// GetItems queries the BE engine and converts the matched rows into items.
//
// On a read error the request URI is logged and the error is returned. A nil
// slice and nil error are returned when the response has no rows, or when the
// response lacks either the configured item id field or the "__score__" field.
// Every other response field is attached to the item as a property.
func (r *BeX2IRecall) GetItems(user *module.User, context *context.RecommendContext) (ret []*module.Item, err error) {
	x2iReadRequest := r.buildRequest(user, context)
	x2iReadResponse, err := r.beClient.Read(*x2iReadRequest)
	if err != nil {
		uri := x2iReadRequest.BuildUri()
		log.Error(fmt.Sprintf("requestId=%s\tbizName=%s\turl=%s", context.RecommendId, r.bizName, uri.RequestURI()))
		return nil, err
	}

	matchItems := x2iReadResponse.Result.MatchItems
	if matchItems == nil || len(matchItems.FieldValues) == 0 {
		return nil, nil
	}

	itemIndex, scoreIndex := -1, -1
	for i, name := range matchItems.FieldNames {
		if name == r.itemIdName {
			itemIndex = i
		}
		if name == "__score__" {
			scoreIndex = i
		}
		if itemIndex != -1 && scoreIndex != -1 {
			break
		}
	}
	if itemIndex < 0 || scoreIndex < 0 {
		return nil, nil
	}

	ret = make([]*module.Item, 0, len(matchItems.FieldValues))
	for _, values := range matchItems.FieldValues {
		// Declared per row so a short row cannot inherit the id or score of
		// the previous row.
		var (
			itemId string
			score  float64
		)
		properties := make(map[string]interface{}, len(values))
		for i, value := range values {
			switch {
			case i == itemIndex:
				itemId = utils.ToString(value, "")
			case i == scoreIndex:
				// A checked conversion: an unexpected score type must not
				// panic the whole request.
				score = beX2IScoreToFloat(value)
			case i < len(matchItems.FieldNames):
				properties[matchItems.FieldNames[i]] = value
			}
		}

		item := module.NewItem(itemId)
		item.Score = score
		item.AddProperties(properties)
		ret = append(ret, item)
	}

	return ret, nil
}

// beX2IScoreToFloat converts a raw score value to float64. Numeric types,
// json.Number and numeric strings are converted; anything else yields 0.
func beX2IScoreToFloat(value interface{}) float64 {
	switch v := value.(type) {
	case float64:
		return v
	case float32:
		return float64(v)
	case int:
		return float64(v)
	case int32:
		return float64(v)
	case int64:
		return float64(v)
	case json.Number:
		if f, err := v.Float64(); err == nil {
			return f
		}
	case string:
		if f, err := strconv.ParseFloat(v, 64); err == nil {
			return f
		}
	}
	return 0
}

// BuildQueryParams builds the recall-specific query parameters for user.
//
// An empty map is returned when the trigger key yields no trigger item; in
// that case the diversity, table and custom parameters are not added either.
// The keys written depend on the concrete trigger type:
//   - UserRealtimeEmbeddingMindTrigger: the trigger item is split on "|"; for a
//     released product each comma-separated element is tagged with its group
//     index and joined into "<recallName>_list", otherwise each group is
//     written to "<recallName>_<index>_list".
//   - UserEmbeddingDssmO2OTrigger: the trigger item goes to "<recallName>_qinfo".
//   - any other trigger: the trigger item goes to "<recallName>_list".
func (r *BeX2IRecall) BuildQueryParams(user *module.User, context *context.RecommendContext) (ret map[string]string) {
	ret = make(map[string]string)

	triggerResult := r.triggerKey.GetTriggerKey(user, context)
	if triggerResult.TriggerItem == "" {
		return ret
	}

	listKey := fmt.Sprintf("%s_list", r.recallName)
	countKey := fmt.Sprintf("%s_return_count", r.recallName)
	count := strconv.Itoa(r.returnCount)

	switch r.triggerKey.(type) {
	case *berecall.UserRealtimeEmbeddingMindTrigger:
		triggerItems := strings.Split(triggerResult.TriggerItem, "|")
		if r.client.IsProductReleased() {
			var items []string
			for i, trigger := range triggerItems {
				for _, item := range strings.Split(trigger, ",") {
					items = append(items, fmt.Sprintf("%s:%d", item, i))
				}
			}
			ret[listKey] = strings.Join(items, ",")
		} else {
			for i, trigger := range triggerItems {
				ret[fmt.Sprintf("%s_%d_list", r.recallName, i)] = trigger
			}
		}

		ret["mind_embedding_return_count"] = count
		ret[countKey] = count
		if triggerResult.DistinctParam != "" && triggerResult.DistinctParamName != "" {
			ret[triggerResult.DistinctParamName] = triggerResult.DistinctParam
		}
	case *berecall.UserEmbeddingDssmO2OTrigger:
		ret[fmt.Sprintf("%s_qinfo", r.recallName)] = triggerResult.TriggerItem
		ret[countKey] = count
	default:
		ret[listKey] = triggerResult.TriggerItem
		ret[countKey] = count
	}

	// An explicit diversity parameter takes precedence over the distinct
	// parameter derived from the trigger.
	if r.diversityParam != "" {
		ret[fmt.Sprintf("%s_diversity_param", r.recallName)] = r.diversityParam
	} else if r.triggerIdName != "" && triggerResult.DistinctParam != "" {
		ret[fmt.Sprintf("%s_distinct_param", r.recallName)] = fmt.Sprintf("%s:%s", r.triggerIdName, triggerResult.DistinctParam)
	}

	if r.recallTableName != "" {
		ret[fmt.Sprintf("%s_table", r.recallName)] = r.recallTableName
	}

	for k, v := range r.customParams {
		ret[k] = utils.ToString(v, "")
	}

	return ret
}

// CloneWithConfig returns a recall that shares this recall's clients, business
// name, recall name and filter names, with the remaining settings taken from
// params. params is decoded into a recconf.BeRecallParam via JSON.
//
// Clones are cached on the receiver, keyed by the MD5 of the re-encoded
// configuration, so repeated calls with equivalent params return the same
// instance. The receiver itself is returned when params cannot be encoded or
// decoded.
//
// NOTE(review): the cache map is read and written without synchronization.
// Confirm that callers never invoke this concurrently on the same recall.
func (r *BeX2IRecall) CloneWithConfig(params map[string]interface{}) BeBaseRecall {
	j, err := json.Marshal(params)
	if err != nil {
		log.Error(fmt.Sprintf("event=CloneWithConfig\terror=%v", err))
		return r
	}

	recallParams := recconf.BeRecallParam{}
	if err := json.Unmarshal(j, &recallParams); err != nil {
		log.Error(fmt.Sprintf("event=CloneWithConfig\terror=%v", err))
		return r
	}

	// Re-encode the typed struct so the cache key does not depend on keys in
	// params that the struct does not know about.
	d, err := json.Marshal(recallParams)
	if err != nil {
		log.Error(fmt.Sprintf("event=CloneWithConfig\terror=%v", err))
		return r
	}

	md5 := utils.Md5(string(d))
	if recall, ok := r.cloneInstances[md5]; ok {
		return recall
	}

	recall := &BeX2IRecall{
		bizName:         r.bizName,
		beClient:        r.beClient,
		client:          r.client,
		beFilterNames:   r.beFilterNames,
		returnCount:     recallParams.Count,
		itemIdName:      recallParams.ItemIdName,
		recallName:      r.recallName,
		triggerIdName:   recallParams.TriggerIdName,
		recallTableName: recallParams.RecallTableName,
		diversityParam:  recallParams.DiversityParam,
		customParams:    recallParams.CustomParams,
		triggerKey:      berecall.NewTriggerKey(&recallParams, r.client),
		// Initialized so that calling CloneWithConfig on a clone does not
		// panic on a write to a nil map.
		cloneInstances: make(map[string]*BeX2IRecall),
	}

	r.cloneInstances[md5] = recall
	return recall
}