assessment/collectors/embeddings/generate_embedding.go (92 lines of code) (raw):

/* Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. */ package assessment import ( "context" _ "embed" "encoding/json" "fmt" "io/ioutil" aiplatform "cloud.google.com/go/aiplatform/apiv1" "cloud.google.com/go/aiplatform/apiv1/aiplatformpb" "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" "go.uber.org/zap" "google.golang.org/api/option" "google.golang.org/protobuf/types/known/structpb" ) //go:embed go_concept_examples.json var goMysqlMigrationConcept []byte //go:embed java_concept_examples.json var javaMysqlMigrationConcept []byte type MySqlMigrationConcept struct { ID string `json:"id"` Example string `json:"example"` Rewrite struct { Theory string `json:"theory"` Options []struct { MySQLCode string `json:"mysql_code"` SpannerCode string `json:"spanner_code"` } `json:"options"` } `json:"rewrite"` Embedding []float32 `json:"embedding,omitempty"` } func createEmbededTextsFromFile(project, location, language string) ([]MySqlMigrationConcept, error) { ctx := context.Background() apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location) model := "text-embedding-preview-0815" client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint)) if err != nil { return nil, err } defer client.Close() // Read the JSON file var data []byte switch language { case "go": data = goMysqlMigrationConcept case "java": data = javaMysqlMigrationConcept default: panic("Unsupported language") } var mysqlMigrationConcepts []MySqlMigrationConcept if err := json.Unmarshal(data, &mysqlMigrationConcepts); err != nil { return nil, err } instances := make([]*structpb.Value, len(mysqlMigrationConcepts)) for i, concept := range mysqlMigrationConcepts { instances[i] = structpb.NewStructValue(&structpb.Struct{ Fields: map[string]*structpb.Value{ "content": structpb.NewStringValue(concept.Example), "task_type": structpb.NewStringValue("SEMANTIC_SIMILARITY"), }, }) } req := &aiplatformpb.PredictRequest{ Endpoint: fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", project, location, model), Instances: instances, } resp, err := client.Predict(ctx, req) if err != nil { return nil, err } for i, prediction := range resp.Predictions { values := prediction.GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values embeddings := make([]float32, len(values)) for j, value := range values { embeddings[j] = float32(value.GetNumberValue()) } mysqlMigrationConcepts[i].Embedding = embeddings } return mysqlMigrationConcepts, nil } func embedTextsFromFile(project, location, inputPath, outputPath string) error { mysqlMigrationConcepts, err := createEmbededTextsFromFile(project, location, "java") if err != nil { return err } // Save updated data to a new JSON file outputData, err := json.MarshalIndent(mysqlMigrationConcepts, "", " ") if err != nil { return err } if err := ioutil.WriteFile(outputPath, outputData, 0644); err != nil { return err } logger.Log.Debug("Embeddings saved to", zap.String("fkStmt", outputPath)) return nil } // Sample Usage // func main() { // if err := embedTextsFromFile("", "", "go_concept_examples.json", "output.json"); err != nil { // logger.Log.Debug("Error:", err) // } // }