service/embedding_service.go (175 lines of code) (raw):

package service import ( "encoding/json" "fmt" "sync" "time" "github.com/alibaba/pairec/v2/algorithm" "github.com/alibaba/pairec/v2/algorithm/eas" "github.com/alibaba/pairec/v2/algorithm/response" "github.com/alibaba/pairec/v2/context" "github.com/alibaba/pairec/v2/datasource/datahub" "github.com/alibaba/pairec/v2/log" plog "github.com/alibaba/pairec/v2/log" "github.com/alibaba/pairec/v2/module" "github.com/alibaba/pairec/v2/recconf" "github.com/alibaba/pairec/v2/service/feature" "github.com/alibaba/pairec/v2/service/rank" "github.com/alibaba/pairec/v2/utils" ) var ( User_Embedding_Module = "user" Item_Embedding_Module = "item" ) type EmbeddingService struct { RecommendService module string // user or item contextFeatures []string featureService *feature.FeatureService } func NewEmbeddingService() *EmbeddingService { service := EmbeddingService{ featureService: feature.DefaultFeatureService(), } return &service } func (r *EmbeddingService) Recommend(context *context.RecommendContext) ([]float32, error) { var ( user *module.User items []*module.Item ) uid := context.GetParameter("uid") if uid == nil || uid.(string) == "" { user = module.NewUserWithContext("default_user", context) } else { userId := r.GetUID(context) user = module.NewUserWithContext(userId, context) features := context.GetParameter("user_features") if features != nil { user.AddProperties(features.(map[string]any)) } r.module = User_Embedding_Module } item_id := context.GetParameter("item_id") if item_id != nil && item_id.(string) != "" { item := module.NewItem(item_id.(string)) features := context.GetParameter("item_features") if features != nil { item.AddProperties(features.(map[string]any)) } for k := range features.(map[string]any) { r.contextFeatures = append(r.contextFeatures, k) } items = append(items, item) r.module = Item_Embedding_Module } // load features items = r.featureService.LoadFeatures(user, items, context) embeddings, err := r.Rank(user, items, context) go r.recordLog(user, items, context, embeddings) return embeddings, err } func (r *EmbeddingService) Rank(user *module.User, items []*module.Item, context *context.RecommendContext) (embeddings []float32, err error) { start := time.Now() if context.Debug { data, _ := json.Marshal(user) fmt.Printf("requestId=%s\tuser=%s\n", context.RecommendId, string(data)) data, _ = json.Marshal(items) fmt.Printf("requestId=%s\titems=%s\n", context.RecommendId, string(data)) } var rankConfig recconf.RankConfig scene_name := context.GetParameter("scene").(string) embeddingConfig, ok := context.Config.EmbeddingConfs[scene_name] if ok { rankConfig = embeddingConfig.RankConf } if len(rankConfig.RankAlgoList) == 0 { return } algoGenerator := rank.CreateAlgoDataGenerator(rankConfig.Processor, r.contextFeatures) algoGenerator.SetItemFeatures(rankConfig.ItemFeatures) if r.module == User_Embedding_Module { userFeatures := user.MakeUserFeatures2() algoGenerator.AddFeatures(nil, nil, userFeatures) } else if r.module == Item_Embedding_Module { for _, item := range items { features := item.GetFeatures() algoGenerator.AddFeatures(item, features, nil) } } algoData := algoGenerator.GeneratorAlgoData() var wg sync.WaitGroup for _, algoName := range rankConfig.RankAlgoList { wg.Add(1) go func(algo string) { defer wg.Done() // run 返回原始的值,然后处理返回数据// 注册配置 ret, err := algorithm.Run(algo, algoData.GetFeatures()) if err != nil { log.Error(fmt.Sprintf("requestId=%s\terror=run algorithm error(%v)", context.RecommendId, err)) algoData.SetError(err) } else { if result, ok := ret.([]response.AlgoResponse); ok { algoData.SetAlgoResult(algo, result) } } }(algoName) } wg.Wait() if algoData.Error() != nil { return nil, algoData.Error() } for _, algoResults := range algoData.GetAlgoResult() { if len(algoResults) > 0 { if embeddingReponse, ok := algoResults[0].(*eas.TorchrecEmbeddingResponse); ok { embeddings = embeddingReponse.GetEmbedding() } } } if len(embeddings) == 0 { return nil, fmt.Errorf("embeddings is empty") } log.Info(fmt.Sprintf("requestId=%s\tmodule=EmbeddingRank\tcost=%d", context.RecommendId, utils.CostTime(start))) return } func (r *EmbeddingService) recordLog(user *module.User, items []*module.Item, context *context.RecommendContext, embeddings []float32) { scene_name := context.GetParameter("scene").(string) embeddingConfig, ok := context.Config.EmbeddingConfs[scene_name] if !ok { return } if embeddingConfig.DataSource.Name == "" || embeddingConfig.DataSource.Type == "" { return } log := make(map[string]any) log["request_id"] = context.RecommendId log["scene"] = scene_name log["request_time"] = time.Now().Unix() log["module"] = r.module if r.module == User_Embedding_Module { log["user_id"] = string(user.Id) features := user.MakeUserFeatures2() j, _ := json.Marshal(features) log["user_features"] = string(j) } else if r.module == Item_Embedding_Module { log["item_id"] = string(items[0].Id) features := items[0].GetFeatures() j, _ := json.Marshal(features) log["item_features"] = string(j) } j, _ := json.Marshal(embeddings) log["embeddings"] = string(j) var err error if embeddingConfig.DataSource.Type == recconf.DataSource_Type_Datahub { err = r.recordToDatahub(embeddingConfig.DataSource.Name, []map[string]any{log}) } if err != nil { plog.Error(fmt.Sprintf("requestId=%s\tmodule=recordLog\terror=%v", context.RecommendId, err)) } } func (r *EmbeddingService) recordToDatahub(name string, messages []map[string]interface{}) error { p, error := datahub.GetDatahub(name) if error != nil { return error } p.SendMessage(messages) return nil }