sort/ssd_sort.go (468 lines of code) (raw):
package sort
import (
"database/sql"
"errors"
"fmt"
"github.com/alibaba/pairec/v2/abtest"
"github.com/alibaba/pairec/v2/context"
"github.com/alibaba/pairec/v2/log"
"github.com/alibaba/pairec/v2/module"
"github.com/alibaba/pairec/v2/persist/holo"
"github.com/alibaba/pairec/v2/recconf"
"github.com/alibaba/pairec/v2/utils"
"github.com/goburrow/cache"
"github.com/huandu/go-sqlbuilder"
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/stat"
"math"
"math/rand"
gosort "sort"
"strconv"
"strings"
"sync"
"time"
)
type SSDSort struct {
db *sql.DB
tableName string
suffixParam string
keyField string
embeddingField string
embSeparator string
gamma float64
useSSDStar bool
dbStmt *sql.Stmt
mu sync.RWMutex
embCache cache.Cache
lastTableSuffixParam string
normalizeEmb bool
windowSize int
abortRunCnt int
candidateCnt int
minScorePercent float64
embMissThreshold float64
filterRetrieveIds []string
ensurePosSimilarity bool
condition *BoostScoreCondition
}
func NewSSDSort(config recconf.SSDSortConfig) *SSDSort {
hologres, err := holo.GetPostgres(config.DaoConf.HologresName)
if err != nil {
panic(err)
}
cacheTime := time.Duration(360)
if config.CacheTimeInMinutes > 0 {
cacheTime = time.Duration(config.CacheTimeInMinutes)
}
ssd := SSDSort{
db: hologres.DB,
tableName: config.TableName,
suffixParam: config.TableSuffixParam,
keyField: config.TablePKey,
embeddingField: config.EmbeddingColumn,
embSeparator: config.EmbeddingSeparator,
gamma: 0.25,
useSSDStar: config.UseSSDStar,
embCache: cache.New(cache.WithMaximumSize(500000), cache.WithExpireAfterAccess(cacheTime*time.Minute)),
lastTableSuffixParam: "",
normalizeEmb: true,
windowSize: config.WindowSize,
abortRunCnt: config.AbortRunCount,
candidateCnt: config.CandidateCount,
minScorePercent: config.MinScorePercent,
embMissThreshold: 0.5,
filterRetrieveIds: config.FilterRetrieveIds,
ensurePosSimilarity: true,
}
if config.Gamma > 0 {
ssd.gamma = config.Gamma
}
if ssd.windowSize <= 0 {
ssd.windowSize = 5
}
if ssd.embSeparator == "" {
ssd.embSeparator = ","
}
if strings.ToLower(config.NormalizeEmb) == "false" {
ssd.normalizeEmb = false
}
if strings.ToLower(config.EnsurePositiveSim) == "false" {
ssd.ensurePosSimilarity = false
}
if config.EmbMissedThreshold > 0 {
ssd.embMissThreshold = config.EmbMissedThreshold
}
if config.Condition != nil {
condition, err := NewBoostScoreCondition(config.Condition)
if err != nil {
log.Error(fmt.Sprintf("SSD Sort BoostScoreCondition error:%v", err))
} else {
ssd.condition = condition
}
}
return &ssd
}
func (s *SSDSort) Sort(sortData *SortData) error {
candidates, ok := sortData.Data.([]*module.Item)
if !ok {
return errors.New("sort data type error")
}
if len(candidates) == 0 {
return nil
}
ctx := sortData.Context
if s.condition != nil {
userProperties := sortData.User.MakeUserFeatures2()
itemProperties := make(map[string]interface{})
if flag, err := s.condition.filterParam.EvaluateByDomain(userProperties, itemProperties); err == nil && !flag {
gosort.Sort(gosort.Reverse(ItemScoreSlice(candidates)))
sortData.Data = candidates
ctx.LogInfo("module=SSDSort\tcondition eval failed, skip")
return nil
}
}
if s.abortRunCnt > 0 && len(candidates) <= s.abortRunCnt {
gosort.Sort(gosort.Reverse(ItemScoreSlice(candidates)))
sortData.Data = candidates
ctx.LogInfo(fmt.Sprintf("module=SSDSort\tcandidate cnt=%d, abort run cnt=%d", len(candidates), s.abortRunCnt))
return nil
}
params := ctx.ExperimentResult.GetExperimentParams()
names := params.Get("ssd_filter_retrieve_ids", nil)
filterRetrieveIds := make([]string, 0)
if names != nil {
if values, ok := names.([]interface{}); ok {
for _, v := range values {
if name, okay := v.(string); okay {
filterRetrieveIds = append(filterRetrieveIds, name)
}
}
}
}
if len(filterRetrieveIds) == 0 {
filterRetrieveIds = s.filterRetrieveIds
} else {
ctx.LogInfo(fmt.Sprintf("[ssd] filter retrieve ids = %v", filterRetrieveIds))
}
start := time.Now()
var result []*module.Item
if filterRetrieveIds != nil && len(filterRetrieveIds) > 0 {
backup := make([]*module.Item, 0)
selected := make([]*module.Item, 0, len(candidates))
for _, item := range candidates {
if utils.IndexOf(filterRetrieveIds, item.RetrieveId) >= 0 {
backup = append(backup, item)
} else {
selected = append(selected, item)
}
}
result = s.doSort(selected, ctx)
if len(backup) > 0 {
result = append(result, backup...)
}
} else {
result = s.doSort(candidates, ctx)
}
sortData.Data = result
ctx.LogInfo(fmt.Sprintf("module=SSDSort\tcount=%d\tcost_time=%d",
len(result), utils.CostTime(start)))
return nil
}
func (s *SSDSort) loadEmbeddingCache(ctx *context.RecommendContext, items []*module.Item) error {
client := abtest.GetExperimentClient()
tableSuffix := ""
if s.suffixParam != "" && client != nil {
scene, _ := ctx.GetParameter("scene").(string)
tableSuffix = client.GetSceneParams(scene).GetString(s.suffixParam, "")
}
if tableSuffix != s.lastTableSuffixParam {
s.mu.Lock()
if tableSuffix != s.lastTableSuffixParam {
s.embCache.InvalidateAll()
s.lastTableSuffixParam = tableSuffix
}
s.mu.Unlock()
}
absentItemIds := make([]interface{}, 0)
embedSize := 0
lenAbsentItems := 0
itemMap := make(map[string]*module.Item)
for _, item := range items {
if embI, ok := s.embCache.GetIfPresent(string(item.Id)); !ok {
absentItemIds = append(absentItemIds, string(item.Id))
itemMap[string(item.Id)] = item
} else {
item.Embedding = embI.([]float64)
if embedSize == 0 {
embedSize = len(item.Embedding)
} else if embedSize != len(item.Embedding) {
ctx.LogError(fmt.Sprintf("module=SSDSort\titem %s embedding size do not match, got %d, expect %d",
item.Id, len(item.Embedding), embedSize))
return errors.New("item embedding size do not match")
}
}
}
if len(absentItemIds) > 0 {
table := s.tableName + tableSuffix
builder := sqlbuilder.PostgreSQL.NewSelectBuilder()
builder.Select(s.keyField, s.embeddingField)
builder.From(table)
builder.Where(builder.In(s.keyField, absentItemIds...))
sqlQuery, args := builder.Build()
ctx.LogDebug("module=SSDSort\tsqlquery=" + sqlQuery)
rows, err := s.db.Query(sqlQuery, args...)
if err != nil {
ctx.LogError(fmt.Sprintf("module=SSDSort\terror=%v", err))
return err
}
defer rows.Close()
rowNum := 0
itemID := &sql.NullString{}
itemEmb := &sql.NullString{}
for rows.Next() {
if err := rows.Scan(itemID, itemEmb); err != nil {
ctx.LogError(fmt.Sprintf("module=Scan SSDSort\terror=%v\tProductID=%s",
err, itemID.String))
continue
}
elements := strings.Split(strings.Trim(itemEmb.String, "{}"), s.embSeparator)
vector := make([]float64, len(elements), len(elements)+1)
for i, e := range elements {
if val, err := strconv.ParseFloat(e, 64); err != nil {
ctx.LogError(fmt.Sprintf("parse embedding value failed\terr=%v", err))
} else {
vector[i] = val
}
}
if s.normalizeEmb {
normV := floats.Norm(vector, 2)
floats.Scale(1/normV, vector)
}
if s.ensurePosSimilarity {
vector = append(vector, 1)
}
if embedSize == 0 {
embedSize = len(vector)
} else if embedSize != len(vector) {
ctx.LogError(fmt.Sprintf("module=SSDSort\titem %s embedding size do not match, got %d, expect %d",
itemID.String, len(vector), embedSize))
return errors.New("item embedding size do not match")
}
s.embCache.Put(itemID.String, vector)
if item, ok := itemMap[itemID.String]; ok {
item.Embedding = vector
} else {
return errors.New("item id is not in map")
}
rowNum = rowNum + 1
}
lenAbsentItems = len(absentItemIds) - rowNum
if (float64(lenAbsentItems) / float64(len(items))) > s.embMissThreshold {
return errors.New("the number of items missing embedding is above threshold")
}
if lenAbsentItems > 0 {
if embedSize == 0 {
return errors.New("no embedding detected")
}
for id, item := range itemMap {
if len(item.Embedding) == 0 {
ctx.LogWarning(fmt.Sprintf("not find embedding of item id:%s", id))
item.Embedding = make([]float64, 0, embedSize)
for i := 0; i < embedSize; i++ {
item.Embedding = append(item.Embedding, rand.NormFloat64())
}
normV := floats.Norm(item.Embedding, 2)
floats.Scale(1/normV, item.Embedding)
}
}
}
}
if ctx.Debug {
ctx.LogDebug(fmt.Sprintf("ctx_size=%d\tlen_items=%d\tlen_absent_items=%d\tlen_emb=%d",
ctx.Size, len(items), lenAbsentItems, embedSize))
}
return nil
}
func (s *SSDSort) doSort(items []*module.Item, ctx *context.RecommendContext) []*module.Item {
if len(items) == 0 {
return items
}
gosort.Sort(gosort.Reverse(ItemScoreSlice(items)))
params := ctx.ExperimentResult.GetExperimentParams()
gamma := params.GetFloat("ssd_gamma", s.gamma)
if gamma == 0 {
ctx.LogDebug("ssd gamma=0, skip")
return items
}
candidateCnt := params.GetInt("ssd_candidate_count", s.candidateCnt)
minScorePercent := params.GetFloat("ssd_min_score_percent", s.minScorePercent)
if (candidateCnt > 0 || minScorePercent > 0) && len(items) > ctx.Size {
if candidateCnt > 0 {
cnt := utils.MaxInt(ctx.Size, candidateCnt)
if cnt < len(items) {
items = items[:cnt]
}
}
if minScorePercent > 0 && len(items) > ctx.Size {
idx := ctx.Size
maxScore := items[0].Score
for ; idx < len(items); idx++ {
percent := items[idx].Score / maxScore
if percent < minScorePercent {
break
}
}
items = items[:idx]
}
ctx.LogInfo(fmt.Sprintf("module=SSDSort\tcandidate count=%d", len(items)))
}
if len(s.tableName) > 0 {
if err := s.loadEmbeddingCache(ctx, items); err != nil {
ctx.LogError(fmt.Sprintf("load embedding table cache failed %v", err))
return items
}
return s.SSDWithSlidingWindow(items, ctx)
} else {
ctx.LogWarning("no embedding table and hooks")
}
return items
}
// SSDWithSlidingWindow paper: https://arxiv.org/pdf/2107.05204
func (s *SSDSort) SSDWithSlidingWindow(items []*module.Item, ctx *context.RecommendContext) []*module.Item {
defer func() {
if r := recover(); r != nil {
ctx.LogError(fmt.Sprintf("Recovered from panic in SSDWithSlidingWindow: %v", r))
}
}()
params := ctx.ExperimentResult.GetExperimentParams()
gamma := params.GetFloat("ssd_gamma", s.gamma)
windowSize := params.GetInt("ssd_window_size", s.windowSize)
if windowSize <= 1 {
ctx.LogWarning("SSD sliding window size must > 1, set to 5")
windowSize = 5
}
N := len(items)
// ensure all relevance score are positive and not in a large range
relevanceScore := make([]float64, N)
for i, item := range items {
relevanceScore[i] = item.Score
}
doNorm := params.GetInt("ssd_norm_quality_score", 0)
if doNorm == 1 {
mean, variance := stat.PopMeanVariance(relevanceScore, nil)
if mean == 0 || variance == 0 { // 模型出错时分数都是0
ctx.LogError("module=SSDSort\tall item score are zeros")
return items
}
std := math.Sqrt(variance)
for i, x := range relevanceScore {
relevanceScore[i] = stat.StdScore(x, mean, std)
items[i].AddAlgoScore("ssd_quality_score", relevanceScore[i])
}
} else if doNorm == 2 {
maxScore := relevanceScore[0]
minScore := relevanceScore[len(items)-1]
scoreSpan := maxScore - minScore
if scoreSpan == 0 { // 模型出错时分数都是0
ctx.LogError("module=SSDSort\tall item score are zeros")
return items
}
epsilon := 1e-6
for i, x := range relevanceScore {
relevanceScore[i] = ((x-minScore)/scoreSpan)*(1-epsilon) + epsilon
items[i].AddAlgoScore("ssd_quality_score", relevanceScore[i])
}
}
t := 1
idx := floats.MaxIdx(relevanceScore)
T := utils.MinInt(N, ctx.Size)
dim := len(items[idx].Embedding)
selected := make(map[int]bool, T)
selected[idx] = true
indices := make([]int, 0, T)
indices = append(indices, idx)
volume := gamma
if !s.useSSDStar {
l2norm := floats.Norm(items[idx].Embedding, 2)
if math.IsNaN(l2norm) || math.IsInf(l2norm, 0) {
ctx.LogError(fmt.Sprintf("module=SSDSort\tinvalid embedding of item %s: %v",
items[idx].Id, items[idx].Embedding))
} else {
volume *= l2norm
}
}
B := utils.NewCycleQueue(windowSize)
P := utils.NewCycleQueue(windowSize)
for t < T {
if t > windowSize {
i := B.Pop().(int)
embI := mat.NewVecDense(dim, items[i].Embedding)
projections := P.Pop().([]float64)
for j := 0; j < N; j++ {
if _, ok := selected[j]; ok {
continue
}
scaledEmbI := mat.NewVecDense(dim, nil)
scaledEmbI.ScaleVec(projections[j], embI)
floats.Add(items[j].Embedding, scaledEmbI.RawVector().Data)
}
}
if !B.Push(idx) {
ctx.LogError(fmt.Sprintf("module=SSDSort\tpush index %d into queue failed", idx))
} else {
ctx.LogDebug(fmt.Sprintf("module=SSDSort\tpush index %d into queue", idx))
}
projections := make([]float64, N)
embI := mat.NewVecDense(dim, items[idx].Embedding)
for j := 0; j < N; j++ {
if _, ok := selected[j]; ok {
continue
}
projections[j] = floats.Dot(items[j].Embedding, items[idx].Embedding)
projections[j] /= floats.Dot(items[idx].Embedding, items[idx].Embedding)
if math.IsNaN(projections[j]) || math.IsInf(projections[j], 0) {
projections[j] = 1.0
ctx.LogWarning(fmt.Sprintf("module=SSDSort\tinvalid projection of item %s on item %x",
items[j].Id, items[idx].Id))
}
scaledEmbI := mat.NewVecDense(dim, nil)
scaledEmbI.ScaleVec(projections[j], embI)
floats.Sub(items[j].Embedding, scaledEmbI.RawVector().Data)
}
if !P.Push(projections) {
ctx.LogError(fmt.Sprintf("module=SSDSort\tpush projections %d into queue failed", idx))
}
t++
qualities := make([]float64, len(relevanceScore))
for i, r := range relevanceScore {
if _, ok := selected[i]; ok {
qualities[i] = -math.MaxFloat64
} else {
l2norm := floats.Norm(items[i].Embedding, 2)
if math.IsNaN(l2norm) || math.IsInf(l2norm, 0) {
ctx.LogError(fmt.Sprintf("module=SSDSort\tinvalid embedding of item %s: %v",
items[i].Id, items[i].Embedding))
qualities[i] = r + volume*0.5
} else {
qualities[i] = r + volume*l2norm
}
}
}
idx = floats.MaxIdx(qualities)
selected[idx] = true
indices = append(indices, idx)
if !s.useSSDStar {
l2norm := floats.Norm(items[idx].Embedding, 2)
if math.IsNaN(l2norm) || math.IsInf(l2norm, 0) {
ctx.LogError(fmt.Sprintf("module=SSDSort\tinvalid embedding of item %s: %v",
items[idx].Id, items[idx].Embedding))
} else {
volume *= l2norm
}
}
}
result := make([]*module.Item, 0, T)
for _, index := range indices {
result = append(result, items[index])
}
return result
}