in sort/ssd_sort.go [180:296]
// loadEmbeddingCache fills item.Embedding for every item in items.
//
// Embeddings are first looked up in s.embCache, keyed by item id. Cache misses
// are fetched with a single query against the table s.tableName + suffix. The
// suffix is read from the A/B experiment scene parameter named by
// s.suffixParam; it is empty when that parameter name is unset or no
// experiment client is available. When the suffix differs from the one used by
// the previous call, the whole cache is invalidated first, because cached
// vectors were loaded from the previous table.
//
// A fetched embedding string has any surrounding "{" / "}" trimmed and is split
// on s.embSeparator. Components that fail to parse are logged and left as 0.
// The vector is then optionally L2-normalized (s.normalizeEmb; an all-zero
// vector is left as is). When s.ensurePosSimilarity is set, a trailing constant
// 1 is appended.
//
// Items whose embedding is still missing after the query (id not returned, or
// NULL embedding column) receive a random, L2-normalized Gaussian vector of the
// common embedding size. If the fraction of such items exceeds
// s.embMissThreshold, an error is returned instead.
//
// An error is returned when:
//   - the query or row iteration fails;
//   - embeddings of different lengths are encountered;
//   - the database returns an id that was not requested;
//   - too many embeddings are missing;
//   - no embedding was found at all, so the fallback size cannot be inferred.
func (s *SSDSort) loadEmbeddingCache(ctx *context.RecommendContext, items []*module.Item) error {
	client := abtest.GetExperimentClient()
	tableSuffix := ""
	if s.suffixParam != "" && client != nil {
		scene, _ := ctx.GetParameter("scene").(string)
		tableSuffix = client.GetSceneParams(scene).GetString(s.suffixParam, "")
	}
	// lastTableSuffixParam is written under s.mu, so it must also be read under
	// s.mu: an unlocked pre-check (double-checked locking) is a data race.
	s.mu.Lock()
	if tableSuffix != s.lastTableSuffixParam {
		s.embCache.InvalidateAll()
		s.lastTableSuffixParam = tableSuffix
	}
	s.mu.Unlock()

	embedSize := 0
	// checkSize records the first embedding length seen and rejects any later
	// vector whose length differs from it.
	checkSize := func(id string, n int) error {
		if embedSize == 0 {
			embedSize = n
			return nil
		}
		if n != embedSize {
			ctx.LogError(fmt.Sprintf("module=SSDSort\titem %s embedding size do not match, got %d, expect %d",
				id, n, embedSize))
			return errors.New("item embedding size do not match")
		}
		return nil
	}
	// normalize scales v to unit L2 norm in place. A zero vector is left
	// untouched rather than being turned into NaNs by dividing by zero.
	normalize := func(v []float64) {
		if n := floats.Norm(v, 2); n > 0 {
			floats.Scale(1/n, v)
		}
	}

	// Several items may share an id. Each missing id therefore maps to all of
	// its items, but is queried only once.
	missing := make(map[string][]*module.Item)
	absentItemIds := make([]interface{}, 0)
	for _, item := range items {
		id := string(item.Id)
		if embI, ok := s.embCache.GetIfPresent(id); ok {
			item.Embedding = embI.([]float64)
			if err := checkSize(id, len(item.Embedding)); err != nil {
				return err
			}
			continue
		}
		if _, seen := missing[id]; !seen {
			absentItemIds = append(absentItemIds, id)
		}
		missing[id] = append(missing[id], item)
	}

	lenAbsentItems := 0
	if len(absentItemIds) > 0 {
		table := s.tableName + tableSuffix
		builder := sqlbuilder.PostgreSQL.NewSelectBuilder()
		builder.Select(s.keyField, s.embeddingField)
		builder.From(table)
		builder.Where(builder.In(s.keyField, absentItemIds...))
		sqlQuery, args := builder.Build()
		ctx.LogDebug("module=SSDSort\tsqlquery=" + sqlQuery)
		rows, err := s.db.Query(sqlQuery, args...)
		if err != nil {
			ctx.LogError(fmt.Sprintf("module=SSDSort\terror=%v", err))
			return err
		}
		defer rows.Close()

		// found holds the ids that actually received a vector from the database.
		found := make(map[string]struct{}, len(absentItemIds))
		itemID := &sql.NullString{}
		itemEmb := &sql.NullString{}
		for rows.Next() {
			if err := rows.Scan(itemID, itemEmb); err != nil {
				ctx.LogError(fmt.Sprintf("module=Scan SSDSort\terror=%v\tProductID=%s",
					err, itemID.String))
				continue
			}
			// A NULL embedding cannot be parsed. Treat the id as missing so its
			// items get the fallback vector instead of a bogus 1-element vector.
			if !itemEmb.Valid {
				ctx.LogWarning(fmt.Sprintf("module=SSDSort\tnull embedding of item id:%s", itemID.String))
				continue
			}
			elements := strings.Split(strings.Trim(itemEmb.String, "{}"), s.embSeparator)
			// Reserve one extra slot for the optional positive-similarity term.
			vector := make([]float64, len(elements), len(elements)+1)
			for i, e := range elements {
				val, err := strconv.ParseFloat(e, 64)
				if err != nil {
					// Best effort: the unparsable component stays 0.
					ctx.LogError(fmt.Sprintf("parse embedding value failed\terr=%v", err))
					continue
				}
				vector[i] = val
			}
			if s.normalizeEmb {
				normalize(vector)
			}
			if s.ensurePosSimilarity {
				vector = append(vector, 1)
			}
			if err := checkSize(itemID.String, len(vector)); err != nil {
				return err
			}
			s.embCache.Put(itemID.String, vector)
			targets, ok := missing[itemID.String]
			if !ok {
				return errors.New("item id is not in map")
			}
			for _, item := range targets {
				item.Embedding = vector
			}
			found[itemID.String] = struct{}{}
		}
		// rows.Next returning false may indicate an error rather than the end of
		// the result set. Do not mask it as "missing embeddings".
		if err := rows.Err(); err != nil {
			ctx.LogError(fmt.Sprintf("module=SSDSort\terror=%v", err))
			return err
		}

		// Count items (not ids), so the ratio below is consistent with len(items).
		for id, group := range missing {
			if _, ok := found[id]; !ok {
				lenAbsentItems += len(group)
			}
		}
		if (float64(lenAbsentItems) / float64(len(items))) > s.embMissThreshold {
			return errors.New("the number of items missing embedding is above threshold")
		}
		if lenAbsentItems > 0 {
			if embedSize == 0 {
				return errors.New("no embedding detected")
			}
			// NOTE(review): when ensurePosSimilarity is set, embedSize includes the
			// constant trailing 1, but the random fallback does not reproduce it.
			// Confirm this is intended.
			for id, group := range missing {
				if _, ok := found[id]; ok {
					continue
				}
				for _, item := range group {
					if len(item.Embedding) != 0 {
						continue
					}
					ctx.LogWarning(fmt.Sprintf("not find embedding of item id:%s", id))
					emb := make([]float64, embedSize)
					for i := range emb {
						emb[i] = rand.NormFloat64()
					}
					normalize(emb)
					item.Embedding = emb
				}
			}
		}
	}
	if ctx.Debug {
		ctx.LogDebug(fmt.Sprintf("ctx_size=%d\tlen_items=%d\tlen_absent_items=%d\tlen_emb=%d",
			ctx.Size, len(items), lenAbsentItems, embedSize))
	}
	return nil
}