service/recall/hologres_vector_recall.go (186 lines of code) (raw):

package recall import ( "database/sql" "errors" "fmt" "strconv" "strings" "sync" "time" "github.com/goburrow/cache" "github.com/alibaba/pairec/v2/context" "github.com/alibaba/pairec/v2/log" "github.com/alibaba/pairec/v2/module" "github.com/alibaba/pairec/v2/persist/holo" "github.com/alibaba/pairec/v2/recconf" "github.com/alibaba/pairec/v2/utils" ) var ( hologres_vector_sql = "SELECT %s, pm_approx_inner_product_distance(%s,$1) as distance FROM %s %s ORDER BY distance desc limit %d" ) type HologresVectorRecall struct { *BaseRecall db *sql.DB dao module.VectorDao table string vectorEmbeddingField string vectorKeyField string where string sql string mu sync.RWMutex dbStmt *sql.Stmt userVectorCache cache.Cache timeInterval int } func NewHologresVectorRecall(config recconf.RecallConfig) *HologresVectorRecall { hologres, err := holo.GetPostgres(config.VectorDaoConf.HologresName) if err != nil { panic(err) } recall := &HologresVectorRecall{ BaseRecall: NewBaseRecall(config), db: hologres.DB, dao: module.NewVectorDao(config), table: config.HologresVectorConf.VectorTable, vectorEmbeddingField: config.HologresVectorConf.VectorEmbeddingField, vectorKeyField: config.HologresVectorConf.VectorKeyField, where: config.HologresVectorConf.WhereClause, timeInterval: config.HologresVectorConf.TimeInterval, } createTime := time.Now().Unix() - int64(recall.timeInterval) recall.where = strings.ReplaceAll(recall.where, "${time}", strconv.FormatInt(createTime, 10)) if recall.where != "" { recall.where = "WHERE " + recall.where } if recall.cacheTime <= 0 { recall.cacheTime = 1800 } recall.userVectorCache = cache.New( cache.WithMaximumSize(10000), cache.WithExpireAfterAccess(time.Duration(recall.cacheTime+100)*time.Second), ) go func(recall *HologresVectorRecall) { partition := "{partition}" for { hologresName := config.VectorDaoConf.HologresName table := config.VectorDaoConf.PartitionInfoTable field := config.VectorDaoConf.PartitionInfoField if config.RecallType == "HologresVectorRecall" && table != "" && field != "" { newPartition := module.FetchPartition(hologresName, table, field) if newPartition != "" && newPartition != partition { recall.table = strings.Replace(recall.table, partition, newPartition, -1) partition = newPartition recall.mu.Lock() recall.dbStmt = nil recall.mu.Unlock() } time.Sleep(time.Minute) } else { break } } }(recall) recall.sql = fmt.Sprintf(hologres_vector_sql, recall.vectorKeyField, recall.vectorEmbeddingField, recall.table, recall.where, recall.recallCount) return recall } func (r *HologresVectorRecall) GetCandidateItems(user *module.User, context *context.RecommendContext) (ret []*module.Item) { start := time.Now() var userEmbedding string userEmbKey := r.cachePrefix + string(user.Id) if value, ok := r.userVectorCache.GetIfPresent(userEmbKey); ok { userEmbedding = value.(string) user.AddProperty(r.modelName+"_embedding", userEmbedding) } else { emb, err := r.dao.VectorString(string(user.Id)) if err != nil { if !errors.Is(err, module.VectoryEmptyError) { context.LogError(fmt.Sprintf("get user vector failed. %s, err=%v", r.modelName, err)) } } else if emb != "" { user.AddProperty(r.modelName+"_embedding", emb) r.userVectorCache.Put(userEmbKey, emb) userEmbedding = emb } } ret = make([]*module.Item, 0, r.recallCount) if r.cache != nil { key := r.cachePrefix + string(user.Id) cacheRet := r.cache.Get(key) if itemStr, ok := cacheRet.([]uint8); ok { itemIds := strings.Split(string(itemStr), ",") for _, id := range itemIds { var item *module.Item if strings.Contains(id, ":") { vars := strings.Split(id, ":") item = module.NewItem(vars[0]) f, _ := strconv.ParseFloat(vars[2], 64) // item.AddProperty(vars[1], f) item.Score = f } else { item = module.NewItem(id) } item.RetrieveId = r.modelName item.ItemType = r.itemType ret = append(ret, item) } context.LogInfo(fmt.Sprintf("module=HologresVectorRecall\tname=%s\thit cache\tcount=%d\tcost=%d", r.modelName, len(ret), utils.CostTime(start))) return } } if userEmbedding == "" { return } if r.dbStmt == nil { r.mu.Lock() if r.dbStmt == nil { r.sql = fmt.Sprintf(hologres_vector_sql, r.vectorKeyField, r.vectorEmbeddingField, r.table, r.where, r.recallCount) if context.Debug { context.LogInfo("module=HologresVectorRecall\tsql=" + r.sql) } stmt, err := r.db.Prepare(r.sql) if err != nil { log.Error(fmt.Sprintf("requestId=%s\tmodule=HologresVectorRecall\tname=%s\terr=%v", context.RecommendId, r.modelName, err)) r.mu.Unlock() return } r.dbStmt = stmt r.mu.Unlock() } else { r.mu.Unlock() } } rows, err := r.dbStmt.Query(userEmbedding) if err != nil { log.Error(fmt.Sprintf("requestId=%s\tmodule=HologresVectorRecall\tname=%s\tsql=%s\tuser_embedding=%s\terr=%v", context.RecommendId, r.modelName, r.sql, userEmbedding, err)) return } defer rows.Close() for rows.Next() { var itemId string var distance float64 if err := rows.Scan(&itemId, &distance); err != nil { continue } item := module.NewItem(itemId) item.RetrieveId = r.modelName item.ItemType = r.itemType item.Score = distance ret = append(ret, item) } if r.cache != nil && len(ret) > 0 { go func() { key := r.cachePrefix + string(user.Id) var itemIds string for _, item := range ret { itemIds += fmt.Sprintf("%s:%s:%v", string(item.Id), r.modelName, item.Score) + "," } itemIds = itemIds[:len(itemIds)-1] if err2 := r.cache.Put(key, itemIds, time.Duration(r.cacheTime)*time.Second); err2 != nil { context.LogError(fmt.Sprintf("module=HologresVectorRecall\terror=%v", err2)) } }() } log.Info(fmt.Sprintf("requestId=%s\tmodule=HologresVectorRecall\tname=%s\tcount=%d\tcost=%d", context.RecommendId, r.modelName, len(ret), utils.CostTime(start))) return }