in service/recall/online_hologres_vector_recall.go [105:241]
// GetCandidateItems recalls candidate items for user through a vector query
// against the prepared statement built from r.sql.
//
// The method proceeds in four steps:
//
//  1. Resolve the user embedding: take it from r.userVectorCache when present,
//     otherwise load the user features, run r.recallAlgo and read the embedding
//     from the first response. Only a non-empty embedding is cached, so a
//     malformed model response does not suppress recall for later requests.
//  2. When r.cache is configured and holds an entry for the user, decode the
//     comma-separated "id::score" list and return it directly.
//  3. Otherwise lazily prepare r.sql (guarded by r.mu) and run one query per
//     "|"-separated embedding concurrently, each limited to
//     r.recallCount/len(embeddings) rows, then merge the results.
//  4. Store the merged result in r.cache asynchronously.
//
// Errors are logged and never returned: when no embedding is available or the
// statement cannot be prepared, the result is an empty, non-nil slice.
func (r *OnlineHologresVectorRecall) GetCandidateItems(user *module.User, context *context.RecommendContext) (ret []*module.Item) {
	start := time.Now()
	// The same key is used for both the embedding cache and the item cache.
	cacheKey := r.cachePrefix + string(user.Id)

	var userEmbedding string
	if value, ok := r.userVectorCache.GetIfPresent(cacheKey); ok {
		userEmbedding, _ = value.(string)
	} else {
		// Cache miss: load the user features first, then ask the model for the embedding.
		r.loadUserFeatures(user, context)
		algoGenerator := rank.CreateAlgoDataGenerator(r.recallAlgoType, nil)
		algoGenerator.SetItemFeatures(nil)
		algoGenerator.AddFeatures(mockItem, nil, user.MakeUserFeatures2())
		algoData := algoGenerator.GeneratorAlgoData()
		algoRet, err := algorithm.Run(r.recallAlgo, algoData.GetFeatures())
		if err != nil {
			context.LogError(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\terr=%v", context.RecommendId, r.modelName, err))
		} else {
			if result, ok := algoRet.([]response.AlgoResponse); ok && len(result) > 0 {
				if userEmbResponse, ok := result[0].(*eas.EasyrecUserEmbResponse); ok {
					userEmbedding = userEmbResponse.GetUserEmb()
				}
			}
			// Do not cache an empty embedding: it would make every later
			// request for this user return nothing until the entry expires.
			if userEmbedding != "" {
				r.userVectorCache.Put(cacheKey, userEmbedding)
			}
		}
	}

	ret = make([]*module.Item, 0, r.recallCount)
	if r.cache != nil {
		if itemStr, ok := r.cache.Get(cacheKey).([]uint8); ok {
			for _, id := range strings.Split(string(itemStr), ",") {
				if id == "" {
					continue
				}
				// Entries are written below as "id::score", which splits into
				// [id, "", score]; a bare id yields a single element.
				vars := strings.Split(id, ":")
				item := module.NewItem(vars[0])
				if len(vars) >= 3 {
					// An unparsable score leaves the item's default score in place.
					if f, err := strconv.ParseFloat(vars[2], 64); err == nil {
						item.Score = f
					}
				}
				item.RetrieveId = r.modelName
				ret = append(ret, item)
			}
			context.LogInfo(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\thit cache\tcount=%d\tcost=%d",
				context.RecommendId, r.modelName, len(ret), utils.CostTime(start)))
			return
		}
	}

	if userEmbedding == "" {
		return
	}

	// Prepare the statement on first use. Both the check and the read happen
	// under r.mu so that concurrent requests never observe r.dbStmt without
	// synchronization; the goroutines below use the local copy.
	r.mu.Lock()
	if r.dbStmt == nil {
		if context.Debug {
			context.LogInfo(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tsql=%s", context.RecommendId, r.sql))
		}
		prepared, err := r.db.Prepare(r.sql)
		if err != nil {
			r.mu.Unlock()
			log.Error(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\terr=%v", context.RecommendId, r.modelName, err))
			return
		}
		r.dbStmt = prepared
	}
	stmt := r.dbStmt
	r.mu.Unlock()

	// strings.Split always returns at least one element, so the division is safe.
	userEmbeddingList := strings.Split(userEmbedding, "|")
	perEmbeddingLimit := r.recallCount / len(userEmbeddingList)

	var wg sync.WaitGroup
	// Buffered to the number of producers so that no goroutine blocks on send.
	ch := make(chan []*module.Item, len(userEmbeddingList))
	for _, userEmb := range userEmbeddingList {
		wg.Add(1)
		go func(userEmb string) {
			defer wg.Done()
			items := make([]*module.Item, 0, perEmbeddingLimit)
			rows, err := stmt.Query(fmt.Sprintf("{%s}", userEmb), perEmbeddingLimit)
			if err != nil {
				log.Error(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\terr=%v", context.RecommendId, r.modelName, err))
				ch <- items
				return
			}
			defer rows.Close()

			var (
				scanFailures int
				scanErr      error
			)
			for rows.Next() {
				var itemId string
				var distance float64
				if err := rows.Scan(&itemId, &distance); err != nil {
					// Best effort: skip the row, but remember the failure so
					// it is reported once instead of being silently dropped.
					scanFailures++
					scanErr = err
					continue
				}
				item := module.NewItem(itemId)
				item.RetrieveId = r.modelName
				item.Score = distance
				items = append(items, item)
			}
			if scanFailures > 0 {
				log.Error(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\tscanFailures=%d\terr=%v", context.RecommendId, r.modelName, scanFailures, scanErr))
			}
			// rows.Next returning false may mean an iteration error rather than
			// the end of the result set; keep the rows read so far but log it.
			if err := rows.Err(); err != nil {
				log.Error(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\terr=%v", context.RecommendId, r.modelName, err))
			}
			ch <- items
		}(userEmb)
	}
	wg.Wait()
	for i := 0; i < len(userEmbeddingList); i++ {
		ret = append(ret, <-ch...)
	}

	if r.cache != nil && len(ret) > 0 {
		// Serialize before returning: ret is handed to the caller, which may
		// modify the items (e.g. their scores) while the background write runs.
		var b strings.Builder
		for i, item := range ret {
			if i > 0 {
				b.WriteByte(',')
			}
			fmt.Fprintf(&b, "%s::%v", string(item.Id), item.Score)
		}
		itemIds := b.String()
		go func() {
			if err2 := r.cache.Put(cacheKey, itemIds, time.Duration(r.cacheTime)*time.Second); err2 != nil {
				context.LogError(fmt.Sprintf("module=OnlineHologresVectorRecall\terror=%v", err2))
			}
		}()
	}

	log.Info(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\tcount=%d\tcost=%d",
		context.RecommendId, r.modelName, len(ret), utils.CostTime(start)))
	return
}