in service/recall/hologres_vector_recall_v2.go [96:210]
// GetCandidateItems returns the candidate items recalled for user by running a
// vector query against Hologres with the user's embedding.
//
// The embedding is looked up in r.userVectorCache first and fetched through
// r.dao.VectorString on a miss; whenever one is found it is also attached to
// the user as the "<modelName>_embedding" property. If r.cache is configured
// and holds a []uint8 entry for the user, the items are decoded from that entry
// (comma-separated "id:model:score" entries or bare ids) and returned without
// touching the database. Otherwise the lazily prepared statement is executed
// with the embedding and every row becomes an item whose Score is the second
// scanned column. A non-empty database result is written back to r.cache in the
// background.
//
// Failures are logged and yield an empty or partial result; no error is
// returned to the caller.
func (r *HologresVectorRecallV2) GetCandidateItems(user *module.User, context *context.RecommendContext) (ret []*module.Item) {
	start := time.Now()
	// The same key is used for the embedding cache and for the item cache.
	cacheKey := r.cachePrefix + string(user.Id)

	var userEmbedding string
	if value, ok := r.userVectorCache.GetIfPresent(cacheKey); ok {
		userEmbedding = value.(string)
		user.AddProperty(r.modelName+"_embedding", userEmbedding)
	} else {
		emb, err := r.dao.VectorString(string(user.Id))
		if err != nil {
			// A missing vector is an expected condition and is not logged.
			if !errors.Is(err, module.VectoryEmptyError) {
				context.LogError(fmt.Sprintf("get user vector failed. %s, err=%v", r.modelName, err))
			}
		} else if emb != "" {
			user.AddProperty(r.modelName+"_embedding", emb)
			r.userVectorCache.Put(cacheKey, emb)
			userEmbedding = emb
		}
	}

	ret = make([]*module.Item, 0, r.recallCount)
	if r.cache != nil {
		cacheRet := r.cache.Get(cacheKey)
		if itemStr, ok := cacheRet.([]uint8); ok {
			for _, id := range strings.Split(string(itemStr), ",") {
				if id == "" {
					// Splitting an empty payload (or a trailing comma) yields an
					// empty entry; it does not identify an item.
					continue
				}
				var item *module.Item
				if strings.Contains(id, ":") {
					vars := strings.Split(id, ":")
					item = module.NewItem(vars[0])
					// Entries are written as "id:model:score". Only read the score
					// when the third field exists so a malformed entry cannot panic.
					if len(vars) >= 3 {
						// A parse failure deliberately leaves the score at the value
						// ParseFloat returns (0 for a syntax error).
						f, _ := strconv.ParseFloat(vars[2], 64)
						item.Score = f
					}
				} else {
					item = module.NewItem(id)
				}
				item.RetrieveId = r.modelName
				item.ItemType = r.itemType
				ret = append(ret, item)
			}
			context.LogInfo(fmt.Sprintf("requestId=%s\tmodule=HologresVectorRecallV2\tname=%s\thit cache\tcount=%d\tcost=%d",
				context.RecommendId, r.modelName, len(ret), utils.CostTime(start)))
			return
		}
	}

	if userEmbedding == "" {
		return
	}

	// r.dbStmt and r.sql are written under r.mu, so they are read under it as
	// well; reading them without the lock is a data race.
	r.mu.Lock()
	if r.dbStmt == nil {
		r.sql = fmt.Sprintf(hologres_vector_sql_v2, r.vectorKeyField, r.vectorEmbeddingField, r.table, r.where, r.recallCount)
		if context.Debug {
			context.LogInfo("module=HologresVectorRecallV2\tsql=" + r.sql)
		}
		prepared, err := r.db.Prepare(r.sql)
		if err != nil {
			log.Error(fmt.Sprintf("requestId=%s\tmodule=HologresVectorRecallV2\tname=%s\terr=%v", context.RecommendId, r.modelName, err))
			r.mu.Unlock()
			return
		}
		r.dbStmt = prepared
	}
	stmt, query := r.dbStmt, r.sql
	r.mu.Unlock()

	rows, err := stmt.Query(userEmbedding)
	if err != nil {
		// Embeddings can be long; keep the log line bounded.
		emb := userEmbedding
		if len(userEmbedding) > 500 {
			emb = userEmbedding[:500]
		}
		log.Error(fmt.Sprintf("requestId=%s\tmodule=HologresVectorRecallV2\tname=%s\tsql=%s\tuser_embedding=%s\terr=%v", context.RecommendId, r.modelName, query, emb, err))
		return
	}
	defer rows.Close()

	for rows.Next() {
		var itemId string
		var distance float64
		if err := rows.Scan(&itemId, &distance); err != nil {
			// Skip the unreadable row but leave a trace instead of dropping it silently.
			log.Error(fmt.Sprintf("requestId=%s\tmodule=HologresVectorRecallV2\tname=%s\terr=%v", context.RecommendId, r.modelName, err))
			continue
		}
		item := module.NewItem(itemId)
		item.RetrieveId = r.modelName
		item.ItemType = r.itemType
		item.Score = distance
		ret = append(ret, item)
	}
	// Next returning false may mean the iteration failed part-way; the rows read
	// so far are still returned, but the failure is reported.
	if err := rows.Err(); err != nil {
		log.Error(fmt.Sprintf("requestId=%s\tmodule=HologresVectorRecallV2\tname=%s\terr=%v", context.RecommendId, r.modelName, err))
	}

	if r.cache != nil && len(ret) > 0 {
		// Serialize before returning: once this function returns, the caller owns
		// ret and its items, so the background goroutine must not read them.
		var b strings.Builder
		for i, item := range ret {
			if i > 0 {
				b.WriteByte(',')
			}
			fmt.Fprintf(&b, "%s:%s:%v", string(item.Id), r.modelName, item.Score)
		}
		itemIds := b.String()
		go func() {
			if err2 := r.cache.Put(cacheKey, itemIds, time.Duration(r.cacheTime)*time.Second); err2 != nil {
				context.LogError(fmt.Sprintf("module=HologresVectorRecall\terror=%v", err2))
			}
		}()
	}

	log.Info(fmt.Sprintf("requestId=%s\tmodule=HologresVectorRecallV2\tname=%s\tcount=%d\tcost=%d",
		context.RecommendId, r.modelName, len(ret), utils.CostTime(start)))
	return
}