func()

in service/recall/online_hologres_vector_recall.go [105:241]


// GetCandidateItems returns vector-recall candidates for user.
//
// The user embedding is read from r.userVectorCache (key: r.cachePrefix+user.Id).
// On a miss it is computed by loading the user's features and running
// r.recallAlgo, and only a non-empty result is cached. If r.cache is configured
// and holds a previous recall result for the user, that result is returned as-is.
// Otherwise the embedding is split on "|" and each non-empty segment is queried
// in parallel through the prepared r.sql statement, with a limit of
// r.recallCount/len(segments). The merged items are written back to r.cache
// asynchronously.
func (r *OnlineHologresVectorRecall) GetCandidateItems(user *module.User, context *context.RecommendContext) (ret []*module.Item) {
	start := time.Now()
	cacheKey := r.cachePrefix + string(user.Id)

	var userEmbedding string
	if value, ok := r.userVectorCache.GetIfPresent(cacheKey); ok {
		userEmbedding = value.(string)
	} else {
		// Cache miss: load the user features, then invoke the model for the embedding.
		r.loadUserFeatures(user, context)
		algoGenerator := rank.CreateAlgoDataGenerator(r.recallAlgoType, nil)
		algoGenerator.SetItemFeatures(nil)
		algoGenerator.AddFeatures(mockItem, nil, user.MakeUserFeatures2())
		algoData := algoGenerator.GeneratorAlgoData()
		algoRet, err := algorithm.Run(r.recallAlgo, algoData.GetFeatures())
		if err != nil {
			context.LogError(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\terr=%v", context.RecommendId, r.modelName, err))
		} else {
			if result, ok := algoRet.([]response.AlgoResponse); ok && len(result) > 0 {
				if userEmbResponse, ok := result[0].(*eas.EasyrecUserEmbResponse); ok {
					userEmbedding = userEmbResponse.GetUserEmb()
				}
			}
			// Never cache an empty embedding. An unexpected response shape would
			// otherwise disable recall for this user until the entry expires.
			if userEmbedding != "" {
				r.userVectorCache.Put(cacheKey, userEmbedding)
			}
		}
	}

	ret = make([]*module.Item, 0, r.recallCount)
	if r.cache != nil {
		if itemStr, ok := r.cache.Get(cacheKey).([]uint8); ok {
			for _, entry := range strings.Split(string(itemStr), ",") {
				if entry == "" {
					continue
				}
				var item *module.Item
				if strings.Contains(entry, ":") {
					// Entries are written below as "id::score". Splitting on ":" puts
					// the score at index 2. Fall back to the last field so that a
					// single-colon "id:score" entry cannot panic on vars[2].
					vars := strings.Split(entry, ":")
					item = module.NewItem(vars[0])
					scoreStr := vars[len(vars)-1]
					if len(vars) > 2 {
						scoreStr = vars[2]
					}
					// Best effort: a malformed score is stored as whatever ParseFloat
					// returns (0 on a syntax error).
					f, _ := strconv.ParseFloat(scoreStr, 64)
					item.Score = f
				} else {
					item = module.NewItem(entry)
				}

				item.RetrieveId = r.modelName
				ret = append(ret, item)
			}
			context.LogInfo(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\thit cache\tcount=%d\tcost=%d",
				context.RecommendId, r.modelName, len(ret), utils.CostTime(start)))
			return
		}
	}

	if userEmbedding == "" {
		return
	}

	// Lazily prepare the statement. The check happens under r.mu: reading
	// r.dbStmt without the lock while another request assigns it is a data race.
	r.mu.Lock()
	stmt := r.dbStmt
	if stmt == nil {
		if context.Debug {
			context.LogInfo(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tsql=%s", context.RecommendId, r.sql))
		}
		prepared, err := r.db.Prepare(r.sql)
		if err != nil {
			r.mu.Unlock()
			log.Error(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\terr=%v", context.RecommendId, r.modelName, err))
			return
		}
		r.dbStmt = prepared
		stmt = prepared
	}
	r.mu.Unlock()

	// The embedding may hold several vectors separated by "|". Empty segments
	// are skipped so they neither issue a "{}" query nor shrink the per-query limit.
	embeddings := make([]string, 0, strings.Count(userEmbedding, "|")+1)
	for _, emb := range strings.Split(userEmbedding, "|") {
		if emb != "" {
			embeddings = append(embeddings, emb)
		}
	}
	if len(embeddings) == 0 {
		return
	}
	limit := r.recallCount / len(embeddings)

	// Each goroutine writes only its own slot. wg.Wait orders those writes
	// before the merge below, and the final order follows the embedding order.
	results := make([][]*module.Item, len(embeddings))
	var wg sync.WaitGroup
	for i, emb := range embeddings {
		wg.Add(1)
		go func(i int, emb string) {
			defer wg.Done()
			// NOTE(review): "{...}" presumably forms a Postgres array literal
			// from a comma-separated vector; confirm against r.sql.
			rows, err := stmt.Query(fmt.Sprintf("{%s}", emb), limit)
			if err != nil {
				log.Error(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\terr=%v", context.RecommendId, r.modelName, err))
				return
			}
			defer rows.Close()

			items := make([]*module.Item, 0, limit)
			for rows.Next() {
				var itemId string
				var distance float64
				if err := rows.Scan(&itemId, &distance); err != nil {
					continue
				}

				item := module.NewItem(itemId)
				item.RetrieveId = r.modelName
				item.Score = distance
				items = append(items, item)
			}
			// rows.Next returning false can hide an iteration error.
			if err := rows.Err(); err != nil {
				log.Error(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\terr=%v", context.RecommendId, r.modelName, err))
			}
			results[i] = items
		}(i, emb)
	}
	wg.Wait()
	for _, items := range results {
		ret = append(ret, items...)
	}

	if r.cache != nil && len(ret) > 0 {
		// Serialize on the request goroutine. Once this function returns, the
		// caller owns ret and its items and may mutate them concurrently with
		// the background write.
		var sb strings.Builder
		for i, item := range ret {
			if i > 0 {
				sb.WriteByte(',')
			}
			fmt.Fprintf(&sb, "%s::%v", string(item.Id), item.Score)
		}
		itemIds := sb.String()
		go func() {
			if err2 := r.cache.Put(cacheKey, itemIds, time.Duration(r.cacheTime)*time.Second); err2 != nil {
				context.LogError(fmt.Sprintf("module=OnlineHologresVectorRecall\terror=%v", err2))
			}
		}()
	}
	log.Info(fmt.Sprintf("requestId=%s\tmodule=OnlineHologresVectorRecall\tname=%s\tcount=%d\tcost=%d",
		context.RecommendId, r.modelName, len(ret), utils.CostTime(start)))
	return
}