sources/sqlserver/infoschema.go (387 lines of code) (raw):

// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sqlserver

import (
	"context"
	"database/sql"
	"fmt"
	"sort"
	"strings"

	sp "cloud.google.com/go/spanner"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/schema"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl"
)

// Source data type names that getSelectQuery matches against a column's
// type name in order to wrap the column in a CAST/CONVERT/STAsText
// expression instead of selecting it as-is.
const (
	uuidType           string = "uniqueidentifier"
	geographyType      string = "geography"
	geometryType       string = "geometry"
	timeType           string = "time"
	hierarchyIdType    string = "hierarchyid"
	timestampType      string = "timestamp"
	dateTimeType       string = "datetime"
	dateTime2Type      string = "datetime2"
	dateTimeOffsetType string = "datetimeoffset"
	smallDateTimeType  string = "smalldatetime"
	dateType           string = "date"
)

// InfoSchemaImpl carries the database name and the *sql.DB handle that the
// methods in this file use to query schema metadata and table data.
type InfoSchemaImpl struct {
	DbName string
	Db     *sql.DB
}

// GetToDdl implements the common.InfoSchema interface by returning a
// ToDdlImpl value.
func (isi InfoSchemaImpl) GetToDdl() common.ToDdl {
	return ToDdlImpl{}
}

// We leave the 2 functions below empty to be able to pass this as an infoSchema interface. We don't need these for now.
func (isi InfoSchemaImpl) StartChangeDataCapture(ctx context.Context, conv *internal.Conv) (map[string]interface{}, error) { return nil, nil } func (isi InfoSchemaImpl) StartStreamingMigration(ctx context.Context, migrationProjectId string, client *sp.Client, conv *internal.Conv, streamingInfo map[string]interface{}) (internal.DataflowOutput, error) { return internal.DataflowOutput{}, nil } // GetTableName returns table name. func (isi InfoSchemaImpl) GetTableName(schema string, tableName string) string { if schema == "dbo" { // Drop 'dbo' prefix. return tableName } return fmt.Sprintf("%s.%s", schema, tableName) } // ProcessDataRows performs data conversion for source database // 'db'. For each table, we extract data using a "SELECT *" query, // convert the data to Spanner data (based on the source and Spanner // schemas), and write it to Spanner. If we can't get/process data // for a table, we skip that table and process the remaining tables. // // Note that the database/sql library has a somewhat complex model for // returning data from rows.Scan. Scalar values can be returned using // the native value used by the underlying driver (by passing // *interface{} to rows.Scan), or they can be converted to specific go // types. // We choose to do all type conversions explicitly ourselves so that // we can generate more targeted error messages: hence we pass // *interface{} parameters to row.Scan. 
func (isi InfoSchemaImpl) ProcessData(conv *internal.Conv, tableId string, srcSchema schema.Table, commonColIds []string, spSchema ddl.CreateTable, additionalAttributes internal.AdditionalDataAttributes) error { srcTableName := conv.SrcSchema[tableId].Name rowsInterface, err := isi.GetRowsFromTable(conv, tableId) if err != nil { conv.Unexpected(fmt.Sprintf("Couldn't get data for table %s : err = %s", srcTableName, err)) return err } rows := rowsInterface.(*sql.Rows) defer rows.Close() srcCols, _ := rows.Columns() v, scanArgs := buildVals(len(srcCols)) colNameIdMap := internal.GetSrcColNameIdMap(conv.SrcSchema[tableId]) for rows.Next() { // get RawBytes from data. err := rows.Scan(scanArgs...) if err != nil { conv.Unexpected(fmt.Sprintf("Couldn't process sql data row: %s", err)) // Scan failed, so we don't have any data to add to bad rows. conv.StatsAddBadRow(srcTableName, conv.DataMode()) continue } values := valsToStrings(v) newValues, err := common.PrepareValues(conv, tableId, colNameIdMap, commonColIds, srcCols, values) if err != nil { conv.Unexpected(fmt.Sprintf("Error while converting data: %s\n", err)) conv.StatsAddBadRow(srcTableName, conv.DataMode()) conv.CollectBadRow(srcTableName, srcCols, values) continue } ProcessDataRow(conv, tableId, commonColIds, srcSchema, spSchema, newValues) } return nil } // GetRowsFromTable returns a sql Rows object for a table. 
func (isi InfoSchemaImpl) GetRowsFromTable(conv *internal.Conv, tableId string) (interface{}, error) { tbl := conv.SrcSchema[tableId] //To get only the table name by removing the schema name prefix tblName := strings.Replace(tbl.Name, tbl.Schema+".", "", 1) q := getSelectQuery(isi.DbName, tbl.Schema, tblName, tbl.ColIds, tbl.ColDefs) rows, err := isi.Db.Query(q) if err != nil { return nil, err } return rows, err } func getSelectQuery(srcDb string, schemaName string, tableName string, colIds []string, colDefs map[string]schema.Column) string { var selects = make([]string, len(colIds)) for i, colId := range colIds { cn := colDefs[colId].Name var s string switch colDefs[colId].Type.Name { case geometryType, geographyType: s = fmt.Sprintf("[%s].STAsText() AS %s", cn, cn) case uuidType: s = fmt.Sprintf("CAST([%s] AS VARCHAR(36)) AS %s", cn, cn) case hierarchyIdType: s = fmt.Sprintf("CAST([%s] AS VARCHAR(4000)) AS %s", cn, cn) case timeType: s = fmt.Sprintf("CAST([%s] AS VARCHAR(12)) AS %s", cn, cn) case timestampType: s = fmt.Sprintf("CAST([%s] AS BIGINT) AS %s", cn, cn) case smallDateTimeType, dateTimeType, dateTime2Type, dateTimeOffsetType: s = fmt.Sprintf("CONVERT(VARCHAR(33), [%s], 126) AS %s", cn, cn) case dateType: s = fmt.Sprintf("CONVERT(VARCHAR(10), [%s], 23) AS %s", cn, cn) default: s = fmt.Sprintf("[%s]", cn) } selects[i] = s } return fmt.Sprintf("SELECT %s FROM [%s].[%s].[%s]", strings.Join(selects, ", "), srcDb, schemaName, tableName) } // buildVals contructs interface{} value containers to scan row // results into. Returns both the underlying containers (as a slice) // as well as an interface{} of pointers to containers to pass to // rows.Scan. func buildVals(n int) (v []interface{}, iv []interface{}) { v = make([]interface{}, n) for i := range v { iv = append(iv, &v[i]) } return v, iv } // GetRowCount with number of rows in each table. 
func (isi InfoSchemaImpl) GetRowCount(table common.SchemaAndName) (int64, error) { q := fmt.Sprintf(`SELECT COUNT(1) FROM [%s].[%s].[%s];`, isi.DbName, table.Schema, table.Name) rows, err := isi.Db.Query(q) if err != nil { return 0, err } defer rows.Close() var count int64 if rows.Next() { err := rows.Scan(&count) return count, err } return 0, nil } // GetTables return list of tables in the selected database. func (isi InfoSchemaImpl) GetTables() ([]common.SchemaAndName, error) { q := ` SELECT SCH.name AS table_schema, TBL.name AS table_name FROM sys.tables AS TBL INNER JOIN sys.schemas AS SCH ON SCH.schema_id = TBL.schema_id WHERE TBL.type = 'U' AND TBL.is_ms_shipped = 0 AND TBL.name <> 'sysdiagrams' ` rows, err := isi.Db.Query(q) if err != nil { return nil, fmt.Errorf("couldn't get tables: %w", err) } defer rows.Close() var tableSchema, tableName string var tables []common.SchemaAndName for rows.Next() { rows.Scan(&tableSchema, &tableName) tables = append(tables, common.SchemaAndName{Schema: tableSchema, Name: tableName}) } return tables, nil } // GetColumns returns a list of Column objects and names func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAndName, constraints map[string][]string, primaryKeys []string) (map[string]schema.Column, []string, error) { q := ` SELECT column_name, data_type, is_nullable, column_default, character_maximum_length, numeric_precision, numeric_scale FROM information_schema.COLUMNS WHERE table_schema = @p1 and table_name = @p2 ORDER BY ordinal_position; ` cols, err := isi.Db.Query(q, table.Schema, table.Name) if err != nil { return nil, nil, fmt.Errorf("couldn't get schema for table %s.%s: %s", table.Schema, table.Name, err) } defer cols.Close() colDefs := make(map[string]schema.Column) var colIds []string var colName, dataType string var isNullable string var colDefault sql.NullString // elementDataType var charMaxLen, numericPrecision, numericScale sql.NullInt64 for cols.Next() { err := 
cols.Scan(&colName, &dataType, &isNullable, &colDefault, &charMaxLen, &numericPrecision, &numericScale) if err != nil { conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) continue } ignored := schema.Ignored{} for _, c := range constraints[colName] { // c can be UNIQUE, PRIMARY KEY, FOREIGN KEY, // or CHECK (based on msql, sql server, postgres docs). // We've already filtered out PRIMARY KEY. switch c { case "CHECK": ignored.Check = true case "FOREIGN KEY", "PRIMARY KEY", "UNIQUE": // Nothing to do here -- these are handled elsewhere. } } ignored.Default = colDefault.Valid colId := internal.GenerateColumnId() c := schema.Column{ Id: colId, Name: colName, Type: toType(dataType, charMaxLen, numericPrecision, numericScale), NotNull: strings.ToUpper(isNullable) == "NO", Ignored: ignored, } colDefs[colId] = c colIds = append(colIds, colId) } return colDefs, colIds, nil } // GetConstraints returns a list of primary keys and by-column map of // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. 
func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { q := ` SELECT k.COLUMN_NAME, t.CONSTRAINT_TYPE FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA WHERE k.TABLE_SCHEMA = @p1 AND k.TABLE_NAME = @p2 ORDER BY k.ordinal_position; ` rows, err := isi.Db.Query(q, table.Schema, table.Name) if err != nil { return nil, nil, nil, err } defer rows.Close() var primaryKeys []string var col, constraint string m := make(map[string][]string) for rows.Next() { err := rows.Scan(&col, &constraint) if err != nil { conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) continue } if col == "" || constraint == "" { conv.Unexpected("Got empty col or constraint") continue } switch constraint { case "PRIMARY KEY": primaryKeys = append(primaryKeys, col) default: m[col] = append(m[col], constraint) } } return primaryKeys, nil, m, nil } // GetForeignKeys returns a list of all the foreign key constraints. 
func (isi InfoSchemaImpl) GetForeignKeys(conv *internal.Conv, table common.SchemaAndName) (foreignKeys []schema.ForeignKey, err error) { q := ` SELECT OBJECT_SCHEMA_NAME (FK.referenced_object_id) AS [schema_name], OBJECT_NAME (FK.referenced_object_id) AS [referenced_table], COL_NAME(FKC.parent_object_id, FKC.parent_column_id) AS [column], COL_NAME(FKC.referenced_object_id, FKC.referenced_column_id) AS [referenced_column], FK.name AS [foreign_key_name] FROM sys.foreign_keys AS FK INNER JOIN sys.foreign_key_columns AS FKC ON FK.object_id = FKC.constraint_object_id WHERE FK.parent_object_id = OBJECT_ID(@p1); ` rows, err := isi.Db.Query(q, fmt.Sprintf("%s.%s", table.Schema, table.Name)) if err != nil { return nil, err } defer rows.Close() var refTable common.SchemaAndName var col, refCol, fKeyName string fKeys := make(map[string]common.FkConstraint) var keyNames []string for rows.Next() { err := rows.Scan(&refTable.Schema, &refTable.Name, &col, &refCol, &fKeyName) if err != nil { conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) continue } tableName := isi.GetTableName(refTable.Schema, refTable.Name) if _, found := fKeys[fKeyName]; found { fk := fKeys[fKeyName] fk.Cols = append(fk.Cols, col) fk.Refcols = append(fk.Refcols, refCol) fKeys[fKeyName] = fk continue } fKeys[fKeyName] = common.FkConstraint{Name: fKeyName, Table: tableName, Refcols: []string{refCol}, Cols: []string{col}} keyNames = append(keyNames, fKeyName) } sort.Strings(keyNames) for _, k := range keyNames { foreignKeys = append(foreignKeys, schema.ForeignKey{ Id: internal.GenerateForeignkeyId(), Name: fKeys[k].Name, ColumnNames: fKeys[k].Cols, ReferTableName: fKeys[k].Table, ReferColumnNames: fKeys[k].Refcols}) } return foreignKeys, nil } // GetIndexes return a list of all indexes for the specified table. 
func (isi InfoSchemaImpl) GetIndexes(conv *internal.Conv, table common.SchemaAndName, colNameIdMap map[string]string) ([]schema.Index, error) { q2 := ` SELECT IX.name, COL_NAME(IX.object_id, IXC.column_id) as [Column Name], IX.is_unique, IXC.is_descending_key, IXC.is_included_column FROM sys.indexes IX INNER JOIN sys.index_columns IXC ON IX.object_id = IXC.object_id AND IX.index_id = IXC.index_id INNER JOIN sys.tables TAB ON IX.object_id = TAB.object_id WHERE IX.is_primary_key = 0 AND IX.is_unique_constraint = 0 AND TAB.is_ms_shipped = 0 AND TAB.name=@p1 AND TAB.schema_id = SCHEMA_ID(@p2) AND IX.type != 5 -- type=5 for clustered columnstore indexes ORDER BY IX.name ; ` rows, err := isi.Db.Query(q2, table.Name, table.Schema) if err != nil { return nil, err } defer rows.Close() var name, column, isUnique, collation, isStored string indexMap := make(map[string]schema.Index) var indexNames []string var indexes []schema.Index for rows.Next() { if err := rows.Scan(&name, &column, &isUnique, &collation, &isStored); err != nil { conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) continue } if _, found := indexMap[name]; !found { indexNames = append(indexNames, name) indexMap[name] = schema.Index{ Id: internal.GenerateIndexesId(), Name: name, Unique: (isUnique == "true")} } index := indexMap[name] if isStored == "false" { index.Keys = append(index.Keys, schema.Key{ ColId: colNameIdMap[column], Desc: (collation == "DESC")}) } else { index.StoredColumnIds = append(index.StoredColumnIds, colNameIdMap[column]) } indexMap[name] = index } for _, k := range indexNames { indexes = append(indexes, indexMap[k]) } return indexes, nil } func toType(dataType string, charLen sql.NullInt64, numericPrecision, numericScale sql.NullInt64) schema.Type { switch { case charLen.Valid: return schema.Type{Name: dataType, Mods: []int64{charLen.Int64}} case numericPrecision.Valid && numericScale.Valid && numericScale.Int64 != 0: return schema.Type{Name: dataType, Mods: 
[]int64{numericPrecision.Int64, numericScale.Int64}} case numericPrecision.Valid: return schema.Type{Name: dataType, Mods: []int64{numericPrecision.Int64}} default: return schema.Type{Name: dataType} } } func valsToStrings(vals []interface{}) []string { toString := func(val interface{}) string { if val == nil { return "NULL" } switch v := val.(type) { case []uint8: val = string([]byte(v)) case *interface{}: val = *v } return fmt.Sprintf("%v", val) } var s []string for _, v := range vals { s = append(s, toString(v)) } return s }