assessment/collectors/embeddings/vector_search.go (132 lines of code) (raw):
/*
Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/
package assessment
import (
"context"
"encoding/json"
"fmt"
"log"
"math"
"os"
"sort"
aiplatform "cloud.google.com/go/aiplatform/apiv1"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
"google.golang.org/api/option"
"google.golang.org/protobuf/types/known/structpb"
)
const (
	// EmbeddingModel names the Vertex AI publisher model that embedTexts
	// requests embeddings from.
	EmbeddingModel = "text-embedding-preview-0815"
)
type (
	// MysqlConceptDb is an in-memory collection of MySQL migration concepts
	// keyed by concept ID. Search ranks its records by the cosine similarity
	// between their stored embeddings and the embeddings of the search terms.
	MysqlConceptDb struct {
		// data maps a concept's ID to the concept itself.
		data map[string]MySqlMigrationConcept
	}
)
// NewMysqlConceptDb builds a MysqlConceptDb from the concepts produced by
// createEmbededTextsFromFile for the given project, location and language.
// Concepts are indexed by their ID; a later concept with a duplicate ID
// replaces an earlier one. Any error from loading the concepts is returned
// unchanged together with a nil database.
func NewMysqlConceptDb(projectId, location, language string) (*MysqlConceptDb, error) {
	concepts, loadErr := createEmbededTextsFromFile(projectId, location, language)
	if loadErr != nil {
		return nil, loadErr
	}
	byID := make(map[string]MySqlMigrationConcept)
	for _, c := range concepts {
		byID[c.ID] = c
	}
	return &MysqlConceptDb{data: byID}, nil
}
// NewExampleDb loads a MysqlConceptDb from the JSON file at filePath.
//
// The file must contain a JSON array of MySqlMigrationConcept records. Records
// are indexed by their ID field; a later record with a duplicate ID replaces an
// earlier one. No network calls are made: records, including their Embedding
// values, are used exactly as stored in the file.
//
// Errors from opening or decoding the file are wrapped with the file path, so
// callers can still match them with errors.Is (e.g. os.ErrNotExist).
func NewExampleDb(filePath string) (*MysqlConceptDb, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return nil, fmt.Errorf("opening example db %q: %w", filePath, err)
	}
	defer file.Close()

	var records []MySqlMigrationConcept
	if err := json.NewDecoder(file).Decode(&records); err != nil {
		return nil, fmt.Errorf("decoding example db %q: %w", filePath, err)
	}

	db := &MysqlConceptDb{data: make(map[string]MySqlMigrationConcept, len(records))}
	for _, record := range records {
		db.data[record.ID] = record
	}
	return db, nil
}
// cosineSimilarity returns the cosine of the angle between vectors a and b,
// nominally in the range [-1, 1].
//
// It returns 0 when the vectors have different lengths, are empty, or either
// has zero magnitude, because the similarity is undefined in those cases.
// The sums are accumulated in float64 to limit rounding error on
// high-dimensional vectors; only the final result is narrowed to float32.
func cosineSimilarity(a, b []float32) float32 {
	// Guard against mismatched dimensions: indexing b by a's indices would
	// panic when b is shorter and compare only a prefix when b is longer.
	if len(a) != len(b) {
		return 0
	}
	var dot, normA, normB float64
	for i := range a {
		x, y := float64(a[i]), float64(b[i])
		dot += x * y
		normA += x * x
		normB += y * y
	}
	if normA == 0 || normB == 0 {
		return 0
	}
	return float32(dot / (math.Sqrt(normA) * math.Sqrt(normB)))
}
// embedTexts requests one embedding vector per element of texts from the
// Vertex AI publisher model named by EmbeddingModel, using the
// SEMANTIC_SIMILARITY task type.
//
// A new prediction client is created for the regional endpoint
// "<location>-aiplatform.googleapis.com:443" on every call and closed before
// returning. Vectors are returned in the order the service returns its
// predictions (expected to match the order of texts).
//
// It returns (nil, nil) for an empty texts slice without contacting the
// service. It returns an error if project or location is empty, if the client
// cannot be created, if the Predict call fails, or if the response does not
// contain exactly one non-empty embedding per text.
func embedTexts(project, location string, texts []string) ([][]float32, error) {
	if len(texts) == 0 {
		return nil, nil
	}
	// Empty values would produce an invalid host ("-aiplatform...") or
	// resource name; fail early with a clear message instead.
	if project == "" || location == "" {
		return nil, fmt.Errorf("embedTexts: project and location must be non-empty (project=%q, location=%q)", project, location)
	}

	ctx := context.Background()
	client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(location+"-aiplatform.googleapis.com:443"))
	if err != nil {
		return nil, fmt.Errorf("creating prediction client for location %q: %w", location, err)
	}
	defer client.Close()

	endpoint := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", project, location, EmbeddingModel)
	instances := make([]*structpb.Value, len(texts))
	for i, text := range texts {
		instances[i] = structpb.NewStructValue(&structpb.Struct{
			Fields: map[string]*structpb.Value{
				"content":   structpb.NewStringValue(text),
				"task_type": structpb.NewStringValue("SEMANTIC_SIMILARITY"),
			},
		})
	}

	resp, err := client.Predict(ctx, &aiplatformpb.PredictRequest{
		Endpoint:  endpoint,
		Instances: instances,
	})
	if err != nil {
		return nil, fmt.Errorf("predicting embeddings with %s: %w", endpoint, err)
	}

	predictions := resp.GetPredictions()
	if len(predictions) != len(texts) {
		return nil, fmt.Errorf("embedding response has %d predictions for %d texts", len(predictions), len(texts))
	}

	embeddings := make([][]float32, 0, len(predictions))
	for i, prediction := range predictions {
		// The Get* accessors are nil-safe, so a malformed prediction yields an
		// empty slice here (reported below) instead of a nil-pointer panic.
		values := prediction.GetStructValue().GetFields()["embeddings"].GetStructValue().GetFields()["values"].GetListValue().GetValues()
		if len(values) == 0 {
			return nil, fmt.Errorf("prediction %d has no embeddings.values", i)
		}
		vector := make([]float32, len(values))
		for j, value := range values {
			vector[j] = float32(value.GetNumberValue())
		}
		embeddings = append(embeddings, vector)
	}
	return embeddings, nil
}
// Search embeds each search term with embedTexts and returns up to topK
// records whose cosine similarity to at least one search term is
// >= 1-distance.
//
// Each record appears at most once. When it matches several search terms, its
// best (highest) similarity is used. Matches are ranked by similarity,
// descending, with ties broken by record ID so the selection is deterministic.
// Records whose embedding length differs from a search embedding's length are
// not compared against that search embedding.
//
// The returned map is keyed by record ID. Each value holds:
//   - "distance": 1 - best similarity (float32)
//   - "example":  the record's Example field
//   - "rewrite":  the record's Rewrite field encoded with json.MarshalIndent
//
// It returns nil when searchTerms is empty or when the embeddings cannot be
// fetched; in the latter case the error is logged. A non-positive topK yields
// an empty map.
func (db *MysqlConceptDb) Search(searchTerms []string, project, location string, distance float32, topK int) map[string]map[string]interface{} {
	if len(searchTerms) == 0 {
		return nil
	}
	searchEmbeddings, err := embedTexts(project, location, searchTerms)
	if err != nil {
		// Library code must not terminate the process; report and return
		// no matches, the same value as the empty-input case.
		log.Printf("Failed to get embeddings: %v", err)
		return nil
	}
	targetSimilarity := 1 - distance

	// Keep only the best similarity per record, so a record matched by
	// several search terms neither uses up several topK slots nor has its
	// reported distance overwritten by a weaker match.
	best := make(map[string]float32)
	for _, record := range db.data {
		for _, searchEmbedding := range searchEmbeddings {
			if len(searchEmbedding) != len(record.Embedding) {
				// Vectors of different dimensionality are not comparable.
				continue
			}
			similarity := cosineSimilarity(searchEmbedding, record.Embedding)
			if similarity < targetSimilarity {
				continue
			}
			if prev, ok := best[record.ID]; !ok || similarity > prev {
				best[record.ID] = similarity
			}
		}
	}

	type match struct {
		ID         string
		Similarity float32
	}
	matches := make([]match, 0, len(best))
	for id, similarity := range best {
		matches = append(matches, match{ID: id, Similarity: similarity})
	}
	sort.Slice(matches, func(i, j int) bool {
		if matches[i].Similarity != matches[j].Similarity {
			return matches[i].Similarity > matches[j].Similarity
		}
		return matches[i].ID < matches[j].ID
	})

	output := make(map[string]map[string]interface{})
	for i := 0; i < topK && i < len(matches); i++ {
		record := db.data[matches[i].ID]
		rewrite, err := json.MarshalIndent(record.Rewrite, "", "")
		if err != nil {
			// Keep the match but surface the problem; "rewrite" stays "".
			log.Printf("Failed to encode rewrite for %q: %v", record.ID, err)
		}
		output[record.ID] = map[string]interface{}{
			"distance": 1 - matches[i].Similarity,
			"example":  record.Example,
			"rewrite":  string(rewrite),
		}
	}
	return output
}
// Sample usage (replace the project ID and location with real values; Search
// needs them to call the Vertex AI embedding endpoint):
//
// func main() {
// db, err := NewExampleDb("output.json")
// if err != nil {
// log.Fatalf("Failed to load database: %v", err)
// }
// searchResults := db.Search([]string{
// "How to migrate from `AUTO_INCREMENT` in MySQL to Spanner?",
// }, "my-project-id", "us-central1", 0.25, 3)
// resultJSON, _ := json.MarshalIndent(searchResults, "", " ")
// logger.Log.Debug(string(resultJSON))
// }