func()

in sort/ssd_sort.go [180:296]


// loadEmbeddingCache populates item.Embedding for every element of items.
//
// Embeddings are first looked up in s.embCache by item id. Ids that are not
// cached are fetched with a single SELECT ... WHERE key IN (...) query against
// s.tableName plus an optional experiment-controlled table suffix; fetched
// vectors are parsed, optionally L2-normalized (s.normalizeEmb), optionally
// extended with a trailing 1 (s.ensurePosSimilarity), stored in the cache and
// assigned to the item. Items for which no usable row came back receive a
// random unit-length vector, provided the share of such items does not exceed
// s.embMissThreshold.
//
// An error is returned when:
//   - two embeddings have different lengths,
//   - the query fails or row iteration ends with an error,
//   - a returned row's id was not among the requested ids,
//   - the fraction of items without an embedding is above s.embMissThreshold,
//   - some items lack an embedding and no embedding length is known.
func (s *SSDSort) loadEmbeddingCache(ctx *context.RecommendContext, items []*module.Item) error {
	client := abtest.GetExperimentClient()
	tableSuffix := ""
	if s.suffixParam != "" && client != nil {
		// A missing or non-string "scene" parameter yields the empty scene name.
		scene, _ := ctx.GetParameter("scene").(string)
		tableSuffix = client.GetSceneParams(scene).GetString(s.suffixParam, "")
	}
	// When the suffix changes, the cached vectors belong to another table and
	// must be dropped. The condition is re-checked while holding s.mu so only
	// one caller performs the invalidation.
	// NOTE(review): the first comparison reads s.lastTableSuffixParam without
	// holding s.mu — confirm this is acceptable under the race detector.
	if tableSuffix != s.lastTableSuffixParam {
		s.mu.Lock()
		if tableSuffix != s.lastTableSuffixParam {
			s.embCache.InvalidateAll()
			s.lastTableSuffixParam = tableSuffix
		}
		s.mu.Unlock()
	}

	absentItemIds := make([]interface{}, 0, len(items))
	embedSize := 0 // reference embedding length; 0 means "not known yet"
	lenAbsentItems := 0
	itemMap := make(map[string]*module.Item, len(items)) // cache misses, keyed by item id
	for _, item := range items {
		id := string(item.Id)
		embI, ok := s.embCache.GetIfPresent(id)
		if !ok {
			absentItemIds = append(absentItemIds, id)
			itemMap[id] = item
			continue
		}
		item.Embedding = embI.([]float64)
		if embedSize == 0 {
			embedSize = len(item.Embedding)
		} else if embedSize != len(item.Embedding) {
			ctx.LogError(fmt.Sprintf("module=SSDSort\titem %s embedding size do not match, got %d, expect %d",
				item.Id, len(item.Embedding), embedSize))
			return errors.New("item embedding size do not match")
		}
	}
	if len(absentItemIds) > 0 {
		table := s.tableName + tableSuffix
		builder := sqlbuilder.PostgreSQL.NewSelectBuilder()
		builder.Select(s.keyField, s.embeddingField)
		builder.From(table)
		builder.Where(builder.In(s.keyField, absentItemIds...))

		sqlQuery, args := builder.Build()
		ctx.LogDebug("module=SSDSort\tsqlquery=" + sqlQuery)
		rows, err := s.db.Query(sqlQuery, args...)
		if err != nil {
			ctx.LogError(fmt.Sprintf("module=SSDSort\terror=%v", err))
			return err
		}
		defer rows.Close()
		rowNum := 0 // rows that produced a usable embedding
		itemID := &sql.NullString{}
		itemEmb := &sql.NullString{}
		for rows.Next() {
			if err := rows.Scan(itemID, itemEmb); err != nil {
				ctx.LogError(fmt.Sprintf("module=Scan SSDSort\terror=%v\tProductID=%s",
					err, itemID.String))
				continue
			}
			if !itemEmb.Valid {
				// A NULL embedding carries no vector. Treat the item as
				// missing so it takes the random-embedding fallback below
				// instead of producing a bogus one-element vector.
				ctx.LogWarning(fmt.Sprintf("module=SSDSort\titem %s has NULL embedding", itemID.String))
				continue
			}
			elements := strings.Split(strings.Trim(itemEmb.String, "{}"), s.embSeparator)
			// One spare slot of capacity so the optional append below does not reallocate.
			vector := make([]float64, len(elements), len(elements)+1)
			for i, e := range elements {
				val, err := strconv.ParseFloat(e, 64)
				if err != nil {
					// Best effort: an unparsable component stays at zero.
					ctx.LogError(fmt.Sprintf("parse embedding value failed\terr=%v", err))
					continue
				}
				vector[i] = val
			}
			if s.normalizeEmb {
				// Skip scaling when the norm is zero (or NaN): 1/0 would turn
				// every component into NaN and that vector would be cached.
				if normV := floats.Norm(vector, 2); normV > 0 {
					floats.Scale(1/normV, vector)
				}
			}
			if s.ensurePosSimilarity {
				vector = append(vector, 1)
			}
			if embedSize == 0 {
				embedSize = len(vector)
			} else if embedSize != len(vector) {
				ctx.LogError(fmt.Sprintf("module=SSDSort\titem %s embedding size do not match, got %d, expect %d",
					itemID.String, len(vector), embedSize))
				return errors.New("item embedding size do not match")
			}
			s.embCache.Put(itemID.String, vector)
			item, ok := itemMap[itemID.String]
			if !ok {
				return errors.New("item id is not in map")
			}
			item.Embedding = vector
			rowNum++
		}
		// rows.Next returns false both at end of data and on failure; without
		// this check a broken result set would silently be treated as
		// "embeddings missing" and replaced by random vectors.
		if err := rows.Err(); err != nil {
			ctx.LogError(fmt.Sprintf("module=SSDSort\terror=%v", err))
			return err
		}
		lenAbsentItems = len(absentItemIds) - rowNum
		if (float64(lenAbsentItems) / float64(len(items))) > s.embMissThreshold {
			return errors.New("the number of items missing embedding is above threshold")
		}
		if lenAbsentItems > 0 {
			if embedSize == 0 {
				return errors.New("no embedding detected")
			}
			for id, item := range itemMap {
				if len(item.Embedding) != 0 {
					continue
				}
				// Fallback: a random Gaussian vector of the reference length,
				// always scaled to unit length regardless of s.normalizeEmb.
				ctx.LogWarning(fmt.Sprintf("not find embedding of item id:%s", id))
				item.Embedding = make([]float64, 0, embedSize)
				for i := 0; i < embedSize; i++ {
					item.Embedding = append(item.Embedding, rand.NormFloat64())
				}
				if normV := floats.Norm(item.Embedding, 2); normV > 0 {
					floats.Scale(1/normV, item.Embedding)
				}
			}
		}
	}
	if ctx.Debug {
		ctx.LogDebug(fmt.Sprintf("ctx_size=%d\tlen_items=%d\tlen_absent_items=%d\tlen_emb=%d",
			ctx.Size, len(items), lenAbsentItems, embedSize))
	}
	return nil
}