service/recall/online_vector_recall.go (125 lines of code) (raw):

package recall import ( "fmt" "strconv" "strings" "sync" "time" "github.com/goburrow/cache" "github.com/alibaba/pairec/v2/algorithm" "github.com/alibaba/pairec/v2/algorithm/eas" "github.com/alibaba/pairec/v2/algorithm/response" "github.com/alibaba/pairec/v2/context" "github.com/alibaba/pairec/v2/log" "github.com/alibaba/pairec/v2/module" "github.com/alibaba/pairec/v2/recconf" "github.com/alibaba/pairec/v2/service/feature" "github.com/alibaba/pairec/v2/service/rank" "github.com/alibaba/pairec/v2/utils" ) const ( VectorAlgoType_EasyRec = "easyrec" VectorAlgoType_TorchRec_TDM = "torchrec_tdm" VectorAlgoType_TorchRec = "torchrec_vector" ) type OnlineVectorRecall struct { *BaseRecall features []*feature.Feature recallAlgoType string vectorAlgoType string userVectorCache cache.Cache } func NewOnlineVectorRecall(config recconf.RecallConfig) *OnlineVectorRecall { recall := &OnlineVectorRecall{ BaseRecall: NewBaseRecall(config), recallAlgoType: eas.Eas_Processor_EASYREC, vectorAlgoType: config.VectorAlgoType, } if recall.cacheTime <= 0 && recall.cache != nil { recall.cacheTime = 1800 } var features []*feature.Feature for _, conf := range config.UserFeatureConfs { f := feature.LoadWithConfig(conf) features = append(features, f) } recall.features = features return recall } func (r *OnlineVectorRecall) loadUserFeatures(user *module.User, context *context.RecommendContext) { var wg sync.WaitGroup for _, fea := range r.features { wg.Add(1) go func(fea *feature.Feature) { defer wg.Done() fea.LoadFeatures(user, nil, context) }(fea) } wg.Wait() } func (r *OnlineVectorRecall) GetCandidateItems(user *module.User, context *context.RecommendContext) (ret []*module.Item) { start := time.Now() if r.cache != nil { key := r.cachePrefix + string(user.Id) cacheRet := r.cache.Get(key) if itemStr, ok := cacheRet.([]uint8); ok { itemIds := strings.Split(string(itemStr), ",") for _, id := range itemIds { var item *module.Item if strings.Contains(id, ":") { vars := strings.Split(id, ":") item = module.NewItem(vars[0]) f, _ := strconv.ParseFloat(vars[1], 64) item.Score = f } else { item = module.NewItem(id) } item.RetrieveId = r.modelName ret = append(ret, item) } context.LogInfo(fmt.Sprintf("module=OnlineVectorRecall\tname=%s\thit cache\tcount=%d\tcost=%d", r.modelName, len(ret), utils.CostTime(start))) return } } r.loadUserFeatures(user, context) // second invoke eas model algoGenerator := rank.CreateAlgoDataGenerator(r.recallAlgoType, nil) algoGenerator.SetItemFeatures(nil) algoGenerator.AddFeatures(nil, nil, user.MakeUserFeatures2()) algoData := algoGenerator.GeneratorAlgoData() algoRet, err := algorithm.Run(r.recallAlgo, algoData.GetFeatures()) if err != nil { context.LogError(fmt.Sprintf("requestId=%s\tmodule=OnlineVectorRecall\tname=%s\terr=%v", context.RecommendId, r.modelName, err)) } else { // eas model invoke success if result, ok := algoRet.([]response.AlgoResponse); ok && len(result) > 0 { if r.vectorAlgoType == VectorAlgoType_TorchRec_TDM || r.vectorAlgoType == VectorAlgoType_TorchRec { if userEmbResponse, ok := result[0].(*eas.TorchrecEmbeddingItemsResponse); ok { embeddingInfos := userEmbResponse.GetEmbeddingItems() ret = make([]*module.Item, 0, len(embeddingInfos)) for _, info := range embeddingInfos { item := module.NewItem(info.ItemId) item.Score = info.Score item.RetrieveId = r.modelName ret = append(ret, item) } } } } } if r.cache != nil && len(ret) > 0 { go func() { key := r.cachePrefix + string(user.Id) var itemIds string for _, item := range ret { itemIds += fmt.Sprintf("%s:%v", string(item.Id), item.Score) + "," } itemIds = itemIds[:len(itemIds)-1] if err2 := r.cache.Put(key, itemIds, time.Duration(r.cacheTime)*time.Second); err2 != nil { context.LogError(fmt.Sprintf("requestId=%s\tmodule=OnlineVectorRecall\terror=%v", context.RecommendId, err2)) } }() } log.Info(fmt.Sprintf("requestId=%s\tmodule=OnlineVectorRecall\tname=%s\tcount=%d\tcost=%d", context.RecommendId, r.modelName, len(ret), utils.CostTime(start))) return }