assessment/utils/mysql_ddl.go (160 lines of code) (raw):

// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package utils import ( "fmt" "sort" "strconv" "strings" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" ) func PrintColumnDef(col schema.Column) string { var columnDef strings.Builder columnDef.WriteString(quote(col.Name)) columnDef.WriteString(" ") columnDef.WriteString(col.Type.Name) if len(col.Type.Mods) > 0 { columnDef.WriteString("(") for i, mod := range col.Type.Mods { columnDef.WriteString(strconv.FormatInt(mod, 10)) if i < len(col.Type.Mods)-1 { columnDef.WriteString(", ") } } columnDef.WriteString(")") } if len(col.Type.ArrayBounds) > 0 { for _, bound := range col.Type.ArrayBounds { columnDef.WriteString(fmt.Sprintf("[%d]", bound)) } } if col.NotNull { columnDef.WriteString(" NOT NULL") } if col.AutoGen.Name != "" && col.AutoGen.GenerationType == constants.AUTO_INCREMENT { columnDef.WriteString(" AUTO_INCREMENT") // Basic auto increment, adjust for others as needed. } if col.DefaultValue.IsPresent { columnDef.WriteString(" DEFAULT ") columnDef.WriteString(col.DefaultValue.Value.Statement) } return columnDef.String() } // PrintCreateTable unparses a CREATE TABLE statement. func PrintCreateTable(ct schema.Table) string { var col []string var keys []string for _, colId := range ct.ColIds { s := PrintColumnDef(ct.ColDefs[colId]) col = append(col, s) } orderedPks := []schema.Key{} orderedPks = append(orderedPks, ct.PrimaryKeys...) sort.Slice(orderedPks, func(i, j int) bool { return orderedPks[i].Order < orderedPks[j].Order }) for _, key := range orderedPks { colName := quote(ct.ColDefs[key.ColId].Name) if key.Desc { colName = colName + " DESC" } keys = append(keys, colName) } var checkString string if len(ct.CheckConstraints) > 0 { checkString = FormatCheckConstraints(ct.CheckConstraints) } else { checkString = "" } if len(keys) == 0 { return fmt.Sprintf("CREATE TABLE %s (\n%s%s);", quote(ct.Name), strings.Join(col, ", "), checkString) } return fmt.Sprintf("CREATE TABLE %s (\n%s%s, PRIMARY KEY (%s));", quote(ct.Name), strings.Join(col, ", "), checkString, strings.Join(keys, ", ")) } // PrintCreateIndex unparses a CREATE INDEX statement. func PrintCreateIndex(index schema.Index, ct schema.Table) string { var createIndex strings.Builder createIndex.WriteString("CREATE ") if index.Unique { createIndex.WriteString("UNIQUE ") } createIndex.WriteString("INDEX ") createIndex.WriteString(index.Name) createIndex.WriteString(" ON ") createIndex.WriteString(ct.Name) createIndex.WriteString(" (") // Sort keys by order sort.Slice(index.Keys, func(i, j int) bool { return index.Keys[i].Order < index.Keys[j].Order }) for i, key := range index.Keys { colName := ct.ColDefs[key.ColId].Name createIndex.WriteString(colName) if key.Desc { createIndex.WriteString(" DESC") } if i < len(index.Keys)-1 { createIndex.WriteString(", ") } } createIndex.WriteString(");") return createIndex.String() } // PrintForeignKeyAlterTable unparses the foreign keys using ALTER TABLE. func PrintForeignKeyAlterTable(fk schema.ForeignKey, tableId string, srcSchema map[string]schema.Table) string { var alterTable strings.Builder tableName := srcSchema[tableId].Name alterTable.WriteString(fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT ", quote(tableName))) if fk.Name != "" { alterTable.WriteString(fmt.Sprintf("%s ", quote(fk.Name))) } alterTable.WriteString("FOREIGN KEY (") // Add columns in the current table for i, colId := range fk.ColIds { alterTable.WriteString(fmt.Sprintf("%s", quote(srcSchema[tableId].ColDefs[colId].Name))) if i < len(fk.ColIds)-1 { alterTable.WriteString(", ") } } alterTable.WriteString(fmt.Sprintf(") REFERENCES %s (", quote(srcSchema[fk.ReferTableId].Name))) // Add referenced columns for i, refColId := range fk.ReferColumnIds { refColName := srcSchema[fk.ReferTableId].ColDefs[refColId].Name alterTable.WriteString(fmt.Sprintf("%s", quote(refColName))) if i < len(fk.ReferColumnIds)-1 { alterTable.WriteString(", ") } } alterTable.WriteString(")") // Add ON DELETE and ON UPDATE actions if fk.OnDelete != "" { alterTable.WriteString(fmt.Sprintf(" ON DELETE %s", strings.ToUpper(fk.OnDelete))) } if fk.OnUpdate != "" { alterTable.WriteString(fmt.Sprintf(" ON UPDATE %s", strings.ToUpper(fk.OnUpdate))) } alterTable.WriteString(";") return alterTable.String() } // FormatCheckConstraints formats the check constraints in SQL syntax. func FormatCheckConstraints(cks []schema.CheckConstraint) string { var builder strings.Builder for _, col := range cks { if col.Name != "" { builder.WriteString(fmt.Sprintf(", CONSTRAINT %s CHECK (%s)", quote(col.Name), col.Expr)) } else { builder.WriteString(fmt.Sprintf(", CHECK (%s)", col.Expr)) } } return builder.String() } // GetDDL returns the string representation of MySQL schema represented by schema.Table struct. func GetDDL(tableSchema map[string]schema.Table) string { var ddl []string for tableId := range tableSchema { ddl = append(ddl, PrintCreateTable(tableSchema[tableId])) for _, index := range tableSchema[tableId].Indexes { ddl = append(ddl, PrintCreateIndex(index, tableSchema[tableId])) } } // Append foreign key constraints to DDL. for t := range tableSchema { for _, fk := range tableSchema[t].ForeignKeys { ddl = append(ddl, PrintForeignKeyAlterTable(fk, tableSchema[t].Id, tableSchema)) } } return strings.Join(ddl, "\n\n") } func quote(name string) string { name = "`" + name + "`" return name }