in sort/ssd_sort.go [346:486]
// SSDWithSlidingWindow re-ranks items with SSD (presumably Sliding Spectrum
// Decomposition) restricted to a sliding window of recent picks.
//
// Items are picked greedily. The first pick is the item with the highest
// (optionally normalized) relevance score. Every following pick maximizes
//
//	relevance + volume * ||residual embedding||
//
// where a candidate's residual embedding is its embedding with the projections
// onto recently selected items subtracted. Once a pick leaves the window (more
// than windowSize picks ago), its projection is added back to the remaining
// candidates. volume starts at gamma and, unless s.useSSDStar is set, is
// multiplied by the residual norm of every pick.
//
// Experiment parameters read from ctx:
//   - ssd_gamma: diversity weight (default s.gamma).
//   - ssd_window_size: window size; values <= 1 fall back to 5
//     (default s.windowSize).
//   - ssd_norm_quality_score: 0 = raw item scores, 1 = z-scores,
//     2 = min-max scaled into [1e-6, 1]. Modes 1 and 2 also record the
//     normalized score as algo score "ssd_quality_score" on every item, and
//     return items unchanged when all scores are equal.
//
// At most min(len(items), ctx.Size) items are returned; the top-relevance item
// is always included. Item embeddings are not modified: all projection updates
// are done on private copies. If anything panics (e.g. embeddings of different
// lengths), the panic is logged and the input items are returned unchanged.
func (s *SSDSort) SSDWithSlidingWindow(items []*module.Item, ctx *context.RecommendContext) (result []*module.Item) {
	defer func() {
		if r := recover(); r != nil {
			ctx.LogError(fmt.Sprintf("Recovered from panic in SSDWithSlidingWindow: %v", r))
			// Fall back to the incoming order rather than an empty (nil) list.
			result = items
		}
	}()
	N := len(items)
	if N == 0 {
		return items
	}
	params := ctx.ExperimentResult.GetExperimentParams()
	gamma := params.GetFloat("ssd_gamma", s.gamma)
	windowSize := params.GetInt("ssd_window_size", s.windowSize)
	if windowSize <= 1 {
		ctx.LogWarning("SSD sliding window size must > 1, set to 5")
		windowSize = 5
	}

	// Relevance scores; optional normalization keeps them positive and within
	// a small range so they are comparable with the diversity term.
	relevanceScore := make([]float64, N)
	for i, item := range items {
		relevanceScore[i] = item.Score
	}
	switch params.GetInt("ssd_norm_quality_score", 0) {
	case 1:
		mean, variance := stat.PopMeanVariance(relevanceScore, nil)
		// Zero variance (e.g. every score is 0 when the model fails) makes the
		// z-score undefined. A zero mean on its own is harmless.
		if variance == 0 {
			ctx.LogError("module=SSDSort\tall item score are zeros")
			return items
		}
		std := math.Sqrt(variance)
		for i, x := range relevanceScore {
			relevanceScore[i] = stat.StdScore(x, mean, std)
			items[i].AddAlgoScore("ssd_quality_score", relevanceScore[i])
		}
	case 2:
		// Use the real extremes; items are not guaranteed to be sorted by score.
		maxScore := floats.Max(relevanceScore)
		minScore := floats.Min(relevanceScore)
		scoreSpan := maxScore - minScore
		if scoreSpan == 0 { // all scores equal, e.g. all 0 when the model fails
			ctx.LogError("module=SSDSort\tall item score are zeros")
			return items
		}
		const epsilon = 1e-6
		for i, x := range relevanceScore {
			relevanceScore[i] = ((x-minScore)/scoreSpan)*(1-epsilon) + epsilon
			items[i].AddAlgoScore("ssd_quality_score", relevanceScore[i])
		}
	}

	T := utils.MinInt(N, ctx.Size)
	if T < 1 {
		// The top-relevance item is always returned; this also keeps the
		// make() capacities below non-negative when ctx.Size < 0.
		T = 1
	}

	// Private copies of the embeddings: the projection updates below rewrite
	// them in place and must not leak into the caller's items.
	embs := make([][]float64, N)
	for i, item := range items {
		embs[i] = append([]float64(nil), item.Embedding...)
	}

	volume := gamma
	// multiplyVolume folds the residual norm of a pick into volume
	// (plain SSD only; SSD* keeps volume fixed at gamma).
	// NOTE(review): norms of picks that have left the window are never divided
	// back out of volume; confirm this matches the intended SSD formulation.
	multiplyVolume := func(pick int) {
		if s.useSSDStar {
			return
		}
		l2norm := floats.Norm(embs[pick], 2)
		if math.IsNaN(l2norm) || math.IsInf(l2norm, 0) {
			ctx.LogError(fmt.Sprintf("module=SSDSort\tinvalid embedding of item %s: %v",
				items[pick].Id, embs[pick]))
			return
		}
		volume *= l2norm
	}

	idx := floats.MaxIdx(relevanceScore)
	dim := len(embs[idx])
	selected := make([]bool, N)
	selected[idx] = true
	indices := make([]int, 0, T)
	indices = append(indices, idx)
	multiplyVolume(idx)

	B := utils.NewCycleQueue(windowSize) // picks currently inside the window
	P := utils.NewCycleQueue(windowSize) // per-candidate projection coefficients of each pick in B
	qualities := make([]float64, N)
	for t := 1; t < T; t++ {
		// Scratch buffer for coefficient * embedding, reused for every candidate.
		scaled := mat.NewVecDense(dim, nil)

		if t > windowSize {
			// The oldest pick leaves the window: add its projections back to
			// every candidate that is still unselected.
			oldest := B.Pop().(int)
			oldestVec := mat.NewVecDense(dim, embs[oldest])
			projections := P.Pop().([]float64)
			for j := 0; j < N; j++ {
				if selected[j] {
					continue
				}
				scaled.ScaleVec(projections[j], oldestVec)
				floats.Add(embs[j], scaled.RawVector().Data)
			}
		}

		if !B.Push(idx) {
			ctx.LogError(fmt.Sprintf("module=SSDSort\tpush index %d into queue failed", idx))
		} else {
			ctx.LogDebug(fmt.Sprintf("module=SSDSort\tpush index %d into queue", idx))
		}

		// Remove from every unselected candidate its component along the latest pick.
		pickEmb := embs[idx]
		pickVec := mat.NewVecDense(dim, pickEmb)
		pickSqNorm := floats.Dot(pickEmb, pickEmb)
		projections := make([]float64, N)
		for j := 0; j < N; j++ {
			if selected[j] {
				continue
			}
			projections[j] = floats.Dot(embs[j], pickEmb) / pickSqNorm
			if math.IsNaN(projections[j]) || math.IsInf(projections[j], 0) {
				projections[j] = 1.0
				ctx.LogWarning(fmt.Sprintf("module=SSDSort\tinvalid projection of item %s on item %s",
					items[j].Id, items[idx].Id))
			}
			scaled.ScaleVec(projections[j], pickVec)
			floats.Sub(embs[j], scaled.RawVector().Data)
		}
		if !P.Push(projections) {
			ctx.LogError(fmt.Sprintf("module=SSDSort\tpush projections %d into queue failed", idx))
		}

		// Score candidates: relevance plus volume-weighted residual norm.
		// Already-selected items are excluded with the lowest possible score.
		for i, r := range relevanceScore {
			if selected[i] {
				qualities[i] = -math.MaxFloat64
				continue
			}
			l2norm := floats.Norm(embs[i], 2)
			if math.IsNaN(l2norm) || math.IsInf(l2norm, 0) {
				ctx.LogError(fmt.Sprintf("module=SSDSort\tinvalid embedding of item %s: %v",
					items[i].Id, embs[i]))
				qualities[i] = r + volume*0.5 // fixed fallback residual norm for broken embeddings
			} else {
				qualities[i] = r + volume*l2norm
			}
		}
		idx = floats.MaxIdx(qualities)
		selected[idx] = true
		indices = append(indices, idx)
		multiplyVolume(idx)
	}

	ranked := make([]*module.Item, 0, len(indices))
	for _, i := range indices {
		ranked = append(ranked, items[i])
	}
	return ranked
}