service/recall/online_hologres_vector_recall.go (209 lines of code) (raw):

package recall import ( "database/sql" "fmt" "strconv" "strings" "sync" "time" "github.com/goburrow/cache" "github.com/alibaba/pairec/v2/algorithm" "github.com/alibaba/pairec/v2/algorithm/eas" "github.com/alibaba/pairec/v2/algorithm/response" "github.com/alibaba/pairec/v2/context" "github.com/alibaba/pairec/v2/log" "github.com/alibaba/pairec/v2/module" "github.com/alibaba/pairec/v2/persist/holo" "github.com/alibaba/pairec/v2/recconf" "github.com/alibaba/pairec/v2/service/feature" "github.com/alibaba/pairec/v2/service/rank" "github.com/alibaba/pairec/v2/utils" ) var ( online_hologres_vector_sql = "SELECT %s, pm_approx_inner_product_distance(%s,$1) as distance FROM %s %s ORDER BY distance desc limit $2" ) type OnlineHologresVectorRecall struct { *BaseRecall db *sql.DB table string vectorEmbeddingField string vectorKeyField string where string sql string features []*feature.Feature recallAlgoType string mu sync.RWMutex dbStmt *sql.Stmt userVectorCache cache.Cache timeInterval int } func NewOnlineHologresVectorRecall(config recconf.RecallConfig) *OnlineHologresVectorRecall { hologres, err := holo.GetPostgres(config.HologresVectorConf.HologresName) if err != nil { panic(err) } recall := &OnlineHologresVectorRecall{ BaseRecall: NewBaseRecall(config), db: hologres.DB, table: config.HologresVectorConf.VectorTable, vectorEmbeddingField: config.HologresVectorConf.VectorEmbeddingField, vectorKeyField: config.HologresVectorConf.VectorKeyField, where: config.HologresVectorConf.WhereClause, timeInterval: config.HologresVectorConf.TimeInterval, recallAlgoType: eas.Eas_Processor_EASYREC, } createTime := time.Now().Unix() - int64(recall.timeInterval) recall.where = strings.ReplaceAll(recall.where, "${time}", strconv.FormatInt(createTime, 10)) if recall.where != "" { recall.where = "WHERE " + recall.where } if recall.cacheTime <= 0 && recall.cache != nil { recall.cacheTime = 1800 } recall.userVectorCache = cache.New( cache.WithMaximumSize(10000), cache.WithExpireAfterAccess(time.Duration(recall.cacheTime+10)*time.Second), ) recall.sql = fmt.Sprintf(online_hologres_vector_sql, recall.vectorKeyField, recall.vectorEmbeddingField, recall.table, recall.where) var features []*feature.Feature for _, conf := range config.UserFeatureConfs { f := feature.LoadWithConfig(conf) features = append(features, f) } recall.features = features return recall } func (r *OnlineHologresVectorRecall) loadUserFeatures(user *module.User, context *context.RecommendContext) { var wg sync.WaitGroup for _, fea := range r.features { wg.Add(1) go func(fea *feature.Feature) { defer wg.Done() fea.LoadFeatures(user, nil, context) }(fea) } wg.Wait() } var ( mockItem = module.NewItem("mock") ) func (r *OnlineHologresVectorRecall) GetCandidateItems(user *module.User, context *context.RecommendContext) (ret []*module.Item) { start := time.Now() var userEmbedding string userEmbKey := r.cachePrefix + string(user.Id) if value, ok := r.userVectorCache.GetIfPresent(userEmbKey); ok { userEmbedding = value.(string) //user.AddProperty(r.modelName+"_embedding", userEmbedding) } else { // get user emb from eas model //first get user features r.loadUserFeatures(user, context) // second invoke eas model algoGenerator := rank.CreateAlgoDataGenerator(r.recallAlgoType, nil) algoGenerator.SetItemFeatures(nil) algoGenerator.AddFeatures(mockItem, nil, user.MakeUserFeatures2()) algoData := algoGenerator.GeneratorAlgoData() algoRet, err := algorithm.Run(r.recallAlgo, algoData.GetFeatures()) if err != nil { context.LogError(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\terr=%v", context.RecommendId, r.modelName, err)) } else { // eas model invoke success if result, ok := algoRet.([]response.AlgoResponse); ok && len(result) > 0 { if userEmbResponse, ok := result[0].(*eas.EasyrecUserEmbResponse); ok { userEmbedding = userEmbResponse.GetUserEmb() } } //user.AddProperty(r.modelName+"_embedding", userEmbedding) r.userVectorCache.Put(userEmbKey, userEmbedding) } } ret = make([]*module.Item, 0, r.recallCount) if r.cache != nil { key := r.cachePrefix + string(user.Id) cacheRet := r.cache.Get(key) if itemStr, ok := cacheRet.([]uint8); ok { itemIds := strings.Split(string(itemStr), ",") for _, id := range itemIds { var item *module.Item if strings.Contains(id, ":") { vars := strings.Split(id, ":") item = module.NewItem(vars[0]) f, _ := strconv.ParseFloat(vars[2], 64) item.Score = f } else { item = module.NewItem(id) } item.RetrieveId = r.modelName ret = append(ret, item) } context.LogInfo(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\thit cache\tcount=%d\tcost=%d", context.RecommendId, r.modelName, len(ret), utils.CostTime(start))) return } } if userEmbedding == "" { return } if r.dbStmt == nil { r.mu.Lock() if r.dbStmt == nil { if context.Debug { context.LogInfo(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tsql=%s", context.RecommendId, r.sql)) } stmt, err := r.db.Prepare(r.sql) if err != nil { log.Error(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\terr=%v", context.RecommendId, r.modelName, err)) r.mu.Unlock() return } r.dbStmt = stmt r.mu.Unlock() } else { r.mu.Unlock() } } userEmbeddingList := strings.Split(userEmbedding, "|") var wg sync.WaitGroup ch := make(chan []*module.Item, len(userEmbeddingList)) for _, userEmb := range userEmbeddingList { wg.Add(1) go func(userEmb string) { defer wg.Done() items := make([]*module.Item, 0, r.recallCount/len(userEmbeddingList)) rows, err := r.dbStmt.Query(fmt.Sprintf("{%s}", userEmb), r.recallCount/len(userEmbeddingList)) if err != nil { log.Error(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\terr=%v", context.RecommendId, r.modelName, err)) ch <- items return } defer rows.Close() for rows.Next() { var itemId string var distance float64 if err := rows.Scan(&itemId, &distance); err != nil { continue } item := module.NewItem(itemId) item.RetrieveId = r.modelName item.Score = distance items = append(items, item) } ch <- items }(userEmb) } wg.Wait() for i := 0; i < len(userEmbeddingList); i++ { items := <-ch ret = append(ret, items...) } if r.cache != nil && len(ret) > 0 { go func() { key := r.cachePrefix + string(user.Id) var itemIds string for _, item := range ret { itemIds += fmt.Sprintf("%s::%v", string(item.Id), item.Score) + "," } itemIds = itemIds[:len(itemIds)-1] if err2 := r.cache.Put(key, itemIds, time.Duration(r.cacheTime)*time.Second); err2 != nil { context.LogError(fmt.Sprintf("module=OnlineHologresVectorRecall\terror=%v", err2)) } }() } log.Info(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\tcount=%d\tcost=%d", context.RecommendId, r.modelName, len(ret), utils.CostTime(start))) return }