assessment/assessment_engine.go

/*
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package assessment

import (
	"context"
	"fmt"
	"strings"
	"sync"

	assessment "github.com/GoogleCloudPlatform/spanner-migration-tool/assessment/collectors"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/assessment/sources/mysql"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/assessment/utils"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/task"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/profiles"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl"
	"go.uber.org/zap"
)

// assessmentCollectors bundles the collectors that feed an assessment run.
// sampleCollector and infoSchemaCollector are mandatory; appAssessmentCollector
// stays nil unless app-code configuration is supplied.
type assessmentCollectors struct {
	sampleCollector        *assessment.SampleCollector
	infoSchemaCollector    *assessment.InfoSchemaCollector
	appAssessmentCollector *assessment.MigrationCodeSummarizer
}

// assessmentTaskInput pairs a task name (used in logs) with the function that
// produces one slice of the overall assessment output.
type assessmentTaskInput struct {
	taskName string
	taskFunc func(ctx context.Context, c assessmentCollectors) (utils.AssessmentOutput, error)
}
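// A sketch of how the task pattern below extends to the other assessment
// types named in PerformAssessment's design comment (cost, query,
// performance): each new type is just one more assessmentTaskInput entry.
// The cost task here is hypothetical and returns an empty result.
var exampleCostTask = assessmentTaskInput{
	taskName: "costAssessment",
	taskFunc: func(ctx context.Context, c assessmentCollectors) (utils.AssessmentOutput, error) {
		// A real implementation would derive cost signals from the collectors in c.
		return utils.AssessmentOutput{}, nil
	},
}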
// PerformAssessment initializes the collectors and then runs the individual
// assessments (currently schema and app code) in parallel, merging their
// results into a single AssessmentOutput.
func PerformAssessment(conv *internal.Conv, sourceProfile profiles.SourceProfile, assessmentConfig map[string]string, projectId string) (utils.AssessmentOutput, error) {
	logger.Log.Info("performing assessment")
	logger.Log.Info(fmt.Sprintf("assessment config %+v", assessmentConfig))
	logger.Log.Info(fmt.Sprintf("project id %+v", projectId))

	ctx := context.Background()

	output := utils.AssessmentOutput{}
	// Initialize collectors.
	c, err := initializeCollectors(conv, sourceProfile, assessmentConfig, projectId, ctx)
	if err != nil {
		logger.Log.Error("unable to initialize collectors", zap.Error(err))
		return output, err
	}

	// Perform each type of assessment in parallel - cost, schema, app code, query,
	// performance. Within each type of assessment, invoke each collector to fetch
	// information from the relevant rules. Iterate over assessment rules, order
	// output by the confidence of each element, and merge outputs where required.
	// Select the highest-confidence output for each attribute and populate the
	// assessment struct.
	parallelTaskRunner := &task.RunParallelTasksImpl[assessmentTaskInput, utils.AssessmentOutput]{}
	assessmentTasksInput := []assessmentTaskInput{
		{
			taskName: "schemaAssessment",
			taskFunc: func(ctx context.Context, c assessmentCollectors) (utils.AssessmentOutput, error) {
				result, err := performSchemaAssessment(ctx, c)
				return utils.AssessmentOutput{SchemaAssessment: result}, err
			},
		},
		{
			taskName: "appAssessment",
			taskFunc: func(ctx context.Context, c assessmentCollectors) (utils.AssessmentOutput, error) {
				result, err := performAppAssessment(ctx, c)
				return utils.AssessmentOutput{AppCodeAssessment: result}, err
			},
		},
	}
	assessmentResults, err := parallelTaskRunner.RunParallelTasks(assessmentTasksInput, 2, func(input assessmentTaskInput, mutex *sync.Mutex) task.TaskResult[utils.AssessmentOutput] {
		result, err := input.taskFunc(ctx, c)
		if err != nil {
			logger.Log.Error(fmt.Sprintf("could not complete %s: ", input.taskName), zap.Error(err))
		}
		return task.TaskResult[utils.AssessmentOutput]{Result: result, Err: err}
	}, false)
	if err != nil {
		// Handle any error from the parallel task runner itself.
		return output, err
	}
	// Each task fills in only its own part of the output; merge the non-nil parts.
	for _, result := range assessmentResults {
		if result.Result.SchemaAssessment != nil {
			output.SchemaAssessment = result.Result.SchemaAssessment
		}
		if result.Result.AppCodeAssessment != nil {
			output.AppCodeAssessment = result.Result.AppCodeAssessment
		}
	}
	return output, nil
}

// initializeCollectors creates the collectors. This is where the decision is
// taken on which collectors are mandatory and which are optional: the sample
// and info-schema collectors are always built, while the app assessment
// collector is built only when the app-code keys are present in assessmentConfig.
func initializeCollectors(conv *internal.Conv, sourceProfile profiles.SourceProfile, assessmentConfig map[string]string, projectId string, ctx context.Context) (assessmentCollectors, error) {
	c := assessmentCollectors{}
	sampleCollector, err := assessment.CreateSampleCollector()
	if err != nil {
		return c, err
	}
	c.sampleCollector = &sampleCollector
	infoSchemaCollector, err := assessment.CreateInfoSchemaCollector(conv, sourceProfile)
	if infoSchemaCollector.IsEmpty() {
		// Nothing could be collected from the source; propagate whatever error was hit.
		return c, err
	}
	c.infoSchemaCollector = &infoSchemaCollector

	// Initialize the app assessment collector. Each lookup tracks its own
	// presence so that a missing key cannot be masked by a later one.
	language, languageExists := assessmentConfig["language"]
	sourceFramework, sourceFrameworkExists := assessmentConfig["sourceFramework"]
	targetFramework, targetFrameworkExists := assessmentConfig["targetFramework"]
	codeDirectory, codeDirectoryExists := assessmentConfig["codeDirectory"]
	if languageExists && sourceFrameworkExists && targetFrameworkExists && codeDirectoryExists {
		logger.Log.Info("initializing app collector")

		mysqlSchema := utils.GetDDL(conv.SrcSchema)
		spannerSchema := strings.Join(
			ddl.GetDDL(
				ddl.Config{Comments: true, ProtectIds: false, Tables: true, ForeignKeys: true, SpDialect: conv.SpDialect, Source: "mysql"},
				conv.SpSchema,
				conv.SpSequences),
			"\n")

		logger.Log.Debug("mysqlSchema", zap.String("schema", mysqlSchema))
		logger.Log.Debug("spannerSchema", zap.String("schema", spannerSchema))

		summarizer, err := assessment.NewMigrationCodeSummarizer(
			ctx, nil, projectId, assessmentConfig["location"], mysqlSchema, spannerSchema, codeDirectory, language, sourceFramework, targetFramework)
		if err != nil {
			logger.Log.Error("error initiating migration summarizer", zap.Error(err))
			return c, err
		}
		c.appAssessmentCollector = summarizer
		logger.Log.Info("initialized app collector")
	} else {
		logger.Log.Info("app code info unavailable")
	}
	return c, err
}
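// A minimal sketch of driving the engine end to end. The config values below
// are placeholders, not values this package prescribes; only the keys read by
// initializeCollectors are set, and conv/sourceProfile are assumed to have
// been built by the usual schema-conversion flow.
func examplePerformAssessment(conv *internal.Conv, sourceProfile profiles.SourceProfile) {
	assessmentConfig := map[string]string{
		"language":        "java",         // placeholder
		"sourceFramework": "jdbc",         // placeholder
		"targetFramework": "spanner-jdbc", // placeholder
		"codeDirectory":   "/path/to/app/code",
		"location":        "us-central1",
	}
	output, err := PerformAssessment(conv, sourceProfile, assessmentConfig, "my-project-id")
	if err != nil {
		logger.Log.Error("assessment failed", zap.Error(err))
		return
	}
	if output.SchemaAssessment != nil {
		logger.Log.Info(fmt.Sprintf("assessed %d tables", len(output.SchemaAssessment.TableAssessmentOutput)))
	}
}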
// performSchemaAssessment walks the source and Spanner schemas reported by the
// info-schema collector and builds per-table assessments covering columns,
// indexes, charset compatibility, and expected size change.
func performSchemaAssessment(ctx context.Context, collectors assessmentCollectors) (*utils.SchemaAssessmentOutput, error) {
	logger.Log.Info("starting schema assessment...")
	schemaOut := &utils.SchemaAssessmentOutput{}

	srcTableDefs, spTableDefs := collectors.infoSchemaCollector.ListTables()
	srcColDefs, spColDefs := collectors.infoSchemaCollector.ListColumnDefinitions()
	srcIndexes, spIndexes := collectors.infoSchemaCollector.ListIndexes()

	tableAssessments := []utils.TableAssessment{}
	for tableId, srcTableDef := range srcTableDefs {
		srcTableDef := srcTableDef // Copy: the assessment stores pointers, which must not alias the loop variable (pre-Go 1.22 semantics).
		spTableDef := spTableDefs[tableId]
		tableSizeDiff := tableSizeDiffBytes(&srcTableDef, &spTableDef)
		columnAssessments := []utils.ColumnAssessment{}
		// Populate column info.
		for id, srcColumn := range srcColDefs {
			if srcColumn.TableId != tableId {
				// Column does not belong to the current table.
				continue
			}
			srcColumn := srcColumn // Copy, for the same aliasing reason as above.
			spColumn := spColDefs[id]
			// TODO: make generic when more sources are added.
			isTypeCompatible := mysql.SourceSpecificComparisonImpl{}.IsDataTypeCodeCompatible(srcColumn, spColumn)
			sizeIncreaseInBytes := getSpColSizeBytes(spColumn) - srcColumn.MaxColumnSize
			colAssessment := utils.ColumnAssessment{SourceColDef: &srcColumn, SpannerColDef: &spColumn, CompatibleDataType: isTypeCompatible, SizeIncreaseInBytes: int(sizeIncreaseInBytes)}
			columnAssessments = append(columnAssessments, colAssessment)
		}
		// Populate the indexes belonging to the current table.
		tableSrcIndexes := []utils.SrcIndexDetails{}
		for _, srcIndex := range srcIndexes {
			if srcIndex.TableId == tableId {
				tableSrcIndexes = append(tableSrcIndexes, srcIndex)
			}
		}
		tableSpIndexes := []utils.SpIndexDetails{}
		for _, spIndex := range spIndexes {
			if spIndex.TableId == tableId {
				tableSpIndexes = append(tableSpIndexes, spIndex)
			}
		}
		tableAssessment := utils.TableAssessment{
			SourceTableDef:      &srcTableDef,
			SpannerTableDef:     &spTableDef,
			Columns:             columnAssessments,
			SourceIndexDef:      tableSrcIndexes,
			SpannerIndexDef:     tableSpIndexes,
			CompatibleCharset:   isCharsetCompatible(srcTableDef.Charset),
			SizeIncreaseInBytes: tableSizeDiff,
		}
		tableAssessments = append(tableAssessments, tableAssessment)
	}
	schemaOut.TableAssessmentOutput = tableAssessments
	schemaOut.TriggerAssessmentOutput = collectors.infoSchemaCollector.ListTriggers()
	schemaOut.StoredProcedureAssessmentOutput = collectors.infoSchemaCollector.ListStoredProcedures()
	schemaOut.FunctionAssessmentOutput = collectors.infoSchemaCollector.ListFunctions()
	schemaOut.ViewAssessmentOutput = collectors.infoSchemaCollector.ListViews()
	schemaOut.SpSequences = collectors.infoSchemaCollector.ListSpannerSequences()
	logger.Log.Info("schema assessment completed successfully.")
	return schemaOut, nil
}

// performAppAssessment runs the migration code summarizer over the configured
// code directory; it is a no-op (nil, nil) when the app collector was never
// initialized.
func performAppAssessment(ctx context.Context, collectors assessmentCollectors) (*utils.AppCodeAssessmentOutput, error) {
	if collectors.appAssessmentCollector == nil {
		logger.Log.Info("not proceeding with app assessment as app collector was not initialized")
		return nil, nil
	}
	logger.Log.Info("starting app assessment...")
	codeAssessment, err := collectors.appAssessmentCollector.AnalyzeProject(ctx)
	if err != nil {
		logger.Log.Error("error analyzing project", zap.Error(err))
		return nil, err
	}
	logger.Log.Debug("snippets: ", zap.Any("codeAssessment.Snippets", codeAssessment.Snippets))
	logger.Log.Info("app assessment completed successfully.")
	return &utils.AppCodeAssessmentOutput{
		Language:     codeAssessment.Language,
		Framework:    codeAssessment.Framework,
		TotalLoc:     codeAssessment.TotalLoc,
		TotalFiles:   codeAssessment.TotalFiles,
		CodeSnippets: codeAssessment.Snippets,
	}, nil
}

// isCharsetCompatible reports whether the source charset maps cleanly onto
// Spanner's UTF-8 strings; utf8 variants (utf8, utf8mb4, ...) are the
// compatible ones.
func isCharsetCompatible(srcCharset string) bool {
	// TODO: add charset-level comparisons - per source.
	return strings.Contains(srcCharset, "utf8")
}
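// One possible shape for the per-source charset comparison that the TODO in
// isCharsetCompatible calls for, kept as an unused sketch: an explicit
// allowlist per source instead of a substring match. The entries below are
// illustrative MySQL charset names, not a vetted compatibility matrix.
var mysqlSpannerCompatibleCharsets = map[string]bool{
	"utf8":    true,
	"utf8mb3": true,
	"utf8mb4": true,
}

func isCharsetCompatiblePerSource(srcCharset string) bool {
	return mysqlSpannerCompatibleCharsets[strings.ToLower(srcCharset)]
}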
// tableSizeDiffBytes returns the expected growth, in bytes, of a table when
// migrated to Spanner.
// TODO: if no Spanner table exists, return nil.
// TODO: currently a dummy implementation that assumes Spanner will always be
// bigger; calculate based on charset and column size differences instead.
func tableSizeDiffBytes(srcTableDef *utils.SrcTableDetails, spTableDef *utils.SpTableDetails) int {
	return 1
}

// getSpColSizeBytes estimates the storage footprint of a Spanner column:
// 8 bytes of per-column overhead plus the size of the value, with
// variable-size types (ARRAY, JSON, STRUCT) pinned at 10 MiB.
// TODO: move to spanner interface?
func getSpColSizeBytes(spCol utils.SpColumnDetails) int64 {
	var size int64
	switch strings.ToUpper(spCol.Datatype) {
	case "ARRAY":
		return 10 * 1024 * 1024
	case "BOOL":
		size = 1
	case "BYTES":
		size = spCol.Len
	case "DATE":
		size = 4
	case "FLOAT32":
		size = 4
	case "FLOAT64":
		size = 8
	case "INT64":
		size = 8
	case "JSON":
		return 10 * 1024 * 1024
	case "NUMERIC":
		size = 22
	case "PROTO":
		size = spCol.Len
	case "STRING":
		size = spCol.Len
	case "STRUCT":
		return 10 * 1024 * 1024
	case "TIMESTAMP":
		return 12
	default:
		return 8
	}
	return 8 + size // Overhead per column plus the size of the value.
}
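// A small worked example of the sizing above, kept as an unused sketch. It
// assumes only the Datatype and Len fields of utils.SpColumnDetails matter
// for sizing, as in getSpColSizeBytes itself.
func exampleColumnSizes() {
	strCol := utils.SpColumnDetails{Datatype: "STRING", Len: 100}
	jsonCol := utils.SpColumnDetails{Datatype: "JSON"}
	fmt.Println(getSpColSizeBytes(strCol))  // 108: 8 bytes overhead + declared length 100
	fmt.Println(getSpColSizeBytes(jsonCol)) // 10485760: pinned at 10 MiB
}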