// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysql

import (
"context"
"database/sql"
"fmt"
"regexp"
"sort"
"strings"
sp "cloud.google.com/go/spanner"
_ "github.com/go-sql-driver/mysql" // The driver should be used via the database/sql package.
_ "github.com/lib/pq"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants"
"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
"github.com/GoogleCloudPlatform/spanner-migration-tool/profiles"
"github.com/GoogleCloudPlatform/spanner-migration-tool/schema"
"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common"
"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl"
"github.com/GoogleCloudPlatform/spanner-migration-tool/streaming"
)
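// collationRegex (built from constants.DB_COLLATION_REGEX) is used to strip
// the collation markers that MySQL embeds in CHECK constraint clauses before
// the clauses are mapped to Spanner.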
var collationRegex = regexp.MustCompile(constants.DB_COLLATION_REGEX)
// InfoSchemaImpl is MySQL specific implementation for InfoSchema.
type InfoSchemaImpl struct {
DbName string
Db *sql.DB
MigrationProjectId string
SourceProfile profiles.SourceProfile
TargetProfile profiles.TargetProfile
}
// GetToDdl implements the common.InfoSchema interface.
func (isi InfoSchemaImpl) GetToDdl() common.ToDdl {
return ToDdlImpl{}
}
// GetTableName returns the table name.
func (isi InfoSchemaImpl) GetTableName(dbName string, tableName string) string {
return tableName
}
// GetRowsFromTable returns a *sql.Rows object for the given table.
func (isi InfoSchemaImpl) GetRowsFromTable(conv *internal.Conv, tableId string) (interface{}, error) {
srcSchema := conv.SrcSchema[tableId]
srcCols := []string{}
for _, srcColId := range srcSchema.ColIds {
srcCols = append(srcCols, conv.SrcSchema[tableId].ColDefs[srcColId].Name)
}
if len(srcCols) == 0 {
conv.Unexpected(fmt.Sprintf("Couldn't get source columns for table %s ", srcSchema.Name))
return nil, nil
}
// MySQL schema and name can be arbitrary strings.
// Ideally we would pass schema/name as a query parameter,
// but MySQL doesn't support this. So we quote it instead.
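	// For example (illustrative database and table names):
	//
	//	SELECT `id`,`title` FROM `mydb`.`albums`;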
colNameList := buildColNameList(srcSchema, srcCols)
q := fmt.Sprintf("SELECT %s FROM `%s`.`%s`;", colNameList, isi.DbName, srcSchema.Name)
rows, err := isi.Db.Query(q)
return rows, err
}
// buildColNameList builds the comma-separated column list for the SELECT
// query. We avoid 'SELECT *' because MySQL spatial columns must be fetched
// in text form via ST_AsText(colName).
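// For example (illustrative column names), a table with an INT column `id`
// and a POINT column `location` yields:
//
//	`id`,ST_AsText(`location`) `location`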
func buildColNameList(srcSchema schema.Table, srcColName []string) string {
var srcColTypes []string
var colList, colTmpName string
for _, colName := range srcColName {
		// Quote the column name to handle reserved keywords and names
		// containing spaces.
		colTmpName = "`" + colName + "`"
srcColTypes = append(srcColTypes, srcSchema.ColDefs[colName].Type.Name)
for _, spatial := range MysqlSpatialDataTypes {
if strings.Contains(strings.ToLower(srcSchema.ColDefs[colName].Type.Name), spatial) {
colTmpName = "ST_AsText" + "(" + colTmpName + ")" + colTmpName
break
}
}
colList = colList + colTmpName + ","
}
	return colList[:len(colList)-1] // Trim the trailing comma.
}
// ProcessData performs data conversion for the source database.
func (isi InfoSchemaImpl) ProcessData(conv *internal.Conv, tableId string, srcSchema schema.Table, commonColIds []string, spSchema ddl.CreateTable, additionalAttributes internal.AdditionalDataAttributes) error {
srcTableName := conv.SrcSchema[tableId].Name
rowsInterface, err := isi.GetRowsFromTable(conv, tableId)
if err != nil {
conv.Unexpected(fmt.Sprintf("Couldn't get data for table %s : err = %s", srcTableName, err))
return err
}
	// GetRowsFromTable returns (nil, nil) when the table has no columns, so
	// guard the type assertion to avoid a panic.
	rows, ok := rowsInterface.(*sql.Rows)
	if !ok || rows == nil {
		return nil
	}
	defer rows.Close()
	srcCols, _ := rows.Columns()
v, scanArgs := buildVals(len(srcCols))
colNameIdMap := internal.GetSrcColNameIdMap(conv.SrcSchema[tableId])
for rows.Next() {
// get RawBytes from data.
err := rows.Scan(scanArgs...)
if err != nil {
conv.Unexpected(fmt.Sprintf("Couldn't process sql data row: %s", err))
// Scan failed, so we don't have any data to add to bad rows.
conv.StatsAddBadRow(srcTableName, conv.DataMode())
continue
}
values := valsToStrings(v)
newValues, err := common.PrepareValues(conv, tableId, colNameIdMap, commonColIds, srcCols, values)
if err != nil {
conv.Unexpected(fmt.Sprintf("Error while converting data: %s\n", err))
conv.StatsAddBadRow(srcTableName, conv.DataMode())
conv.CollectBadRow(srcTableName, srcCols, values)
continue
}
ProcessDataRow(conv, tableId, commonColIds, srcSchema, spSchema, newValues, additionalAttributes)
}
return nil
}
// GetRowCount returns the number of rows in the given table.
func (isi InfoSchemaImpl) GetRowCount(table common.SchemaAndName) (int64, error) {
// MySQL schema and name can be arbitrary strings.
// Ideally we would pass schema/name as a query parameter,
// but MySQL doesn't support this. So we quote it instead.
q := fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s`;", table.Schema, table.Name)
rows, err := isi.Db.Query(q)
if err != nil {
return 0, err
}
defer rows.Close()
var count int64
if rows.Next() {
err := rows.Scan(&count)
return count, err
}
	return 0, nil // COUNT(*) returned no rows; report zero.
}
// GetTables returns the list of tables in the selected database.
// Note that sql.DB already effectively has the dbName
// embedded within it (dbName is part of the DSN passed to sql.Open),
// but unfortunately there is no way to extract it from sql.DB.
func (isi InfoSchemaImpl) GetTables() ([]common.SchemaAndName, error) {
// In MySQL, schema is the same as database name.
q := "SELECT table_name FROM information_schema.tables where table_type = 'BASE TABLE' and table_schema=?"
rows, err := isi.Db.Query(q, isi.DbName)
if err != nil {
return nil, fmt.Errorf("couldn't get tables: %w", err)
}
defer rows.Close()
var tableName string
var tables []common.SchemaAndName
	for rows.Next() {
		if err := rows.Scan(&tableName); err != nil {
			return nil, fmt.Errorf("couldn't scan table name: %w", err)
		}
		tables = append(tables, common.SchemaAndName{Schema: isi.DbName, Name: tableName})
	}
return tables, nil
}
// GetColumns returns a map of Column objects keyed by column ID, along with
// the ordered list of column IDs, for the given table.
func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAndName, constraints map[string][]string, primaryKeys []string) (map[string]schema.Column, []string, error) {
q := `SELECT c.column_name, c.data_type, c.column_type, c.is_nullable, c.column_default, c.character_maximum_length, c.numeric_precision, c.numeric_scale, c.extra
FROM information_schema.COLUMNS c
where table_schema = ? and table_name = ? ORDER BY c.ordinal_position;`
cols, err := isi.Db.Query(q, table.Schema, table.Name)
if err != nil {
		return nil, nil, fmt.Errorf("couldn't get schema for table %s.%s: %w", table.Schema, table.Name, err)
}
defer cols.Close()
colDefs := make(map[string]schema.Column)
var colIds []string
var colName, dataType, isNullable, columnType string
var colDefault, colExtra sql.NullString
var charMaxLen, numericPrecision, numericScale sql.NullInt64
var colAutoGen ddl.AutoGenCol
for cols.Next() {
err := cols.Scan(&colName, &dataType, &columnType, &isNullable, &colDefault, &charMaxLen, &numericPrecision, &numericScale, &colExtra)
if err != nil {
conv.Unexpected(fmt.Sprintf("Can't scan: %v", err))
continue
}
ignored := schema.Ignored{}
ignored.Default = colDefault.Valid
colId := internal.GenerateColumnId()
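		// MySQL AUTO_INCREMENT columns are modeled as Spanner bit-reversed
		// sequences: createSequence registers a new sequence in
		// conv.SrcSequences, and the entry is updated below once the owning
		// column ID is known.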
if colExtra.String == "auto_increment" {
sequence := createSequence(conv)
colAutoGen = ddl.AutoGenCol{
Name: sequence.Name,
GenerationType: constants.AUTO_INCREMENT,
}
sequence.ColumnsUsingSeq = map[string][]string{
table.Id: {colId},
}
conv.SrcSequences[sequence.Id] = sequence
} else {
colAutoGen = ddl.AutoGenCol{}
}
defaultVal := ddl.DefaultValue{
IsPresent: colDefault.Valid,
Value: ddl.Expression{},
}
if colDefault.Valid {
ty := dataType
if conv.SpDialect == constants.DIALECT_POSTGRESQL {
ty = ddl.GetPGType(ddl.Type{Name: ty})
}
defaultVal.Value = ddl.Expression{
ExpressionId: internal.GenerateExpressionId(),
Statement: common.SanitizeDefaultValue(colDefault.String, ty, colExtra.String == constants.DEFAULT_GENERATED),
}
}
c := schema.Column{
Id: colId,
Name: colName,
Type: toType(dataType, columnType, charMaxLen, numericPrecision, numericScale),
NotNull: common.ToNotNull(conv, isNullable),
Ignored: ignored,
AutoGen: colAutoGen,
DefaultValue: defaultVal,
}
colDefs[colId] = c
colIds = append(colIds, colId)
}
return colDefs, colIds, nil
}
// GetConstraints returns a list of primary keys and by-column map of
// other constraints. Note: we need to preserve ordinal order of
// columns in primary key constraints.
// Note that foreign key constraints are handled in getForeignKeys.
func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) {
finalQuery, err := isi.getConstraintsDQL()
if err != nil {
return nil, nil, nil, err
}
rows, err := isi.Db.Query(finalQuery, table.Schema, table.Name)
if err != nil {
return nil, nil, nil, err
}
defer rows.Close()
var primaryKeys []string
var checkKeys []schema.CheckConstraint
m := make(map[string][]string)
for rows.Next() {
if err := isi.processRow(rows, conv, &primaryKeys, &checkKeys, m); err != nil {
conv.Unexpected(fmt.Sprintf("Can't scan constrants. error: %v", err))
continue
}
}
return primaryKeys, checkKeys, m, nil
}
// getConstraintsDQL returns the appropriate SQL query based on whether the INFORMATION_SCHEMA.CHECK_CONSTRAINTS table exists.
func (isi InfoSchemaImpl) getConstraintsDQL() (string, error) {
var tableExistsCount int
// check if CHECK_CONSTRAINTS table exists.
checkQuery := `SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA = 'information_schema' OR TABLE_SCHEMA = 'INFORMATION_SCHEMA') AND TABLE_NAME = 'CHECK_CONSTRAINTS';`
err := isi.Db.QueryRow(checkQuery).Scan(&tableExistsCount)
if err != nil {
return "", err
}
	// MySQL 8.0.16 and above have the CHECK_CONSTRAINTS table.
if tableExistsCount > 0 {
return `SELECT DISTINCT COALESCE(k.COLUMN_NAME,'') AS COLUMN_NAME,t.CONSTRAINT_NAME, t.CONSTRAINT_TYPE, COALESCE(c.CHECK_CLAUSE, '') AS CHECK_CLAUSE
FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t
LEFT JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k
ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME
AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA
AND t.TABLE_NAME = k.TABLE_NAME
LEFT JOIN INFORMATION_SCHEMA.CHECK_CONSTRAINTS AS c
ON t.CONSTRAINT_NAME = c.CONSTRAINT_NAME
AND t.TABLE_SCHEMA = c.CONSTRAINT_SCHEMA
WHERE t.TABLE_SCHEMA = ?
AND t.TABLE_NAME = ?;`, nil
}
return `SELECT k.COLUMN_NAME, t.CONSTRAINT_TYPE
FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t
INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k
ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME
AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA
AND t.TABLE_NAME = k.TABLE_NAME
WHERE t.TABLE_SCHEMA = ?
AND t.TABLE_NAME = ?
ORDER BY k.ORDINAL_POSITION;`, nil
}
// processRow handles scanning and processing of a database row for GetConstraints.
func (isi InfoSchemaImpl) processRow(
rows *sql.Rows, conv *internal.Conv, primaryKeys *[]string,
checkKeys *[]schema.CheckConstraint, m map[string][]string,
) error {
var col, constraintType, checkClause, constraintName string
var err error
cols, err := rows.Columns()
if err != nil {
conv.Unexpected(fmt.Sprintf("Failed to get columns: %v", err))
return err
}
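	// The pre-8.0.16 constraints query returns 2 columns; the 8.0.16+ query
	// also returns the constraint name and check clause (4 columns).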
switch len(cols) {
case 2:
err = rows.Scan(&col, &constraintType)
case 4:
err = rows.Scan(&col, &constraintName, &constraintType, &checkClause)
default:
conv.Unexpected(fmt.Sprintf("unexpected number of columns: %d", len(cols)))
return fmt.Errorf("unexpected number of columns: %d", len(cols))
}
if err != nil {
return err
}
if col == "" && constraintType == "" {
conv.Unexpected("Got empty column or constraint type")
return nil
}
switch constraintType {
case "PRIMARY KEY":
*primaryKeys = append(*primaryKeys, col)
	// Handle check constraints.
case "CHECK":
checkClause = collationRegex.ReplaceAllString(checkClause, "")
checkClause = checkAndAddParentheses(checkClause)
*checkKeys = append(*checkKeys, schema.CheckConstraint{Name: constraintName, Expr: checkClause, ExprId: internal.GenerateExpressionId(), Id: internal.GenerateCheckConstrainstId()})
default:
m[col] = append(m[col], constraintType)
}
return nil
}
// checkAndAddParentheses returns the check clause unchanged if it is already
// wrapped in parentheses; otherwise it wraps the clause in parentheses.
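// For example (illustrative clause), "age > 18" becomes "(age > 18)", while
// "(age > 18)" is returned as-is.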
func checkAndAddParentheses(checkClause string) string {
	if strings.HasPrefix(checkClause, "(") && strings.HasSuffix(checkClause, ")") {
		return checkClause
	}
	return `(` + checkClause + `)`
}
// GetForeignKeys returns a list of all the foreign key constraints.
// MySQL supports cross-database foreign key constraints. We ignore them
// because the Spanner migration tool works one database at a time (a given
// run focuses on a single database), so we can't handle them effectively.
func (isi InfoSchemaImpl) GetForeignKeys(conv *internal.Conv, table common.SchemaAndName) (foreignKeys []schema.ForeignKey, err error) {
q := `SELECT k.REFERENCED_TABLE_NAME,
k.COLUMN_NAME,
k.REFERENCED_COLUMN_NAME,
k.CONSTRAINT_NAME,
r.DELETE_RULE,
r.UPDATE_RULE
FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS AS r
INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k
ON r.CONSTRAINT_NAME = k.CONSTRAINT_NAME
AND r.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA
AND r.TABLE_NAME = k.TABLE_NAME
AND r.REFERENCED_TABLE_NAME = k.REFERENCED_TABLE_NAME
AND k.REFERENCED_TABLE_SCHEMA = k.TABLE_SCHEMA
WHERE k.TABLE_SCHEMA = ?
AND k.TABLE_NAME = ?
ORDER BY
k.REFERENCED_TABLE_NAME,
k.ORDINAL_POSITION;` //TODO(khajanchi): Add a UT for the change of removing column name from order by clause
rows, err := isi.Db.Query(q, table.Schema, table.Name)
if err != nil {
return nil, err
}
defer rows.Close()
	var col, refCol, refTable, fKeyName, onDelete, onUpdate string
fKeys := make(map[string]common.FkConstraint)
var keyNames []string
for rows.Next() {
err := rows.Scan(&refTable, &col, &refCol, &fKeyName, &OnDelete, &OnUpdate)
if err != nil {
conv.Unexpected(fmt.Sprintf("Can't scan: %v", err))
continue
}
		if fk, found := fKeys[fKeyName]; found {
			// Subsequent rows of a multi-column foreign key: accumulate the
			// column pairs. The ON DELETE/ON UPDATE rules are identical across
			// rows of the same constraint, so they are set only on the first row.
			fk.Cols = append(fk.Cols, col)
			fk.Refcols = append(fk.Refcols, refCol)
			fKeys[fKeyName] = fk
			continue
		}
		fKeys[fKeyName] = common.FkConstraint{Name: fKeyName, Table: refTable, Refcols: []string{refCol}, Cols: []string{col}, OnDelete: onDelete, OnUpdate: onUpdate}
		keyNames = append(keyNames, fKeyName)
}
sort.Strings(keyNames)
for _, k := range keyNames {
foreignKeys = append(foreignKeys,
schema.ForeignKey{
Id: internal.GenerateForeignkeyId(),
Name: fKeys[k].Name,
ColumnNames: fKeys[k].Cols,
ReferTableName: fKeys[k].Table,
ReferColumnNames: fKeys[k].Refcols,
OnDelete: fKeys[k].OnDelete,
OnUpdate: fKeys[k].OnUpdate,
})
}
return foreignKeys, nil
}
// GetIndexes returns a list of all indexes for the specified table.
func (isi InfoSchemaImpl) GetIndexes(conv *internal.Conv, table common.SchemaAndName, colNameIdMap map[string]string) ([]schema.Index, error) {
q := `SELECT DISTINCT INDEX_NAME,COLUMN_NAME,SEQ_IN_INDEX,COLLATION,NON_UNIQUE
FROM INFORMATION_SCHEMA.STATISTICS
WHERE TABLE_SCHEMA = ?
AND TABLE_NAME = ?
AND INDEX_NAME != 'PRIMARY'
ORDER BY INDEX_NAME, SEQ_IN_INDEX;`
rows, err := isi.Db.Query(q, table.Schema, table.Name)
if err != nil {
return nil, err
}
defer rows.Close()
var name, column, sequence, nonUnique string
var collation sql.NullString
indexMap := make(map[string]schema.Index)
var indexNames []string
var indexes []schema.Index
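	// In INFORMATION_SCHEMA.STATISTICS, COLLATION is 'A' (ascending),
	// 'D' (descending), or NULL, and NON_UNIQUE is '0' for unique indexes.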
for rows.Next() {
if err := rows.Scan(&name, &column, &sequence, &collation, &nonUnique); err != nil {
conv.Unexpected(fmt.Sprintf("Can't scan: %v", err))
continue
}
if _, found := indexMap[name]; !found {
indexNames = append(indexNames, name)
indexMap[name] = schema.Index{
Id: internal.GenerateIndexesId(),
Name: name,
Unique: (nonUnique == "0"),
}
}
index := indexMap[name]
index.Keys = append(index.Keys, schema.Key{
ColId: colNameIdMap[column],
Desc: (collation.Valid && collation.String == "D"),
})
indexMap[name] = index
}
for _, k := range indexNames {
indexes = append(indexes, indexMap[k])
}
return indexes, nil
}
// StartChangeDataCapture is used for automatic triggering of the Datastream
// job when performing a streaming migration.
func (isi InfoSchemaImpl) StartChangeDataCapture(ctx context.Context, conv *internal.Conv) (map[string]interface{}, error) {
mp := make(map[string]interface{})
var (
schemaDetails map[string]internal.SchemaDetails
err error
)
	commonInfoSchema := common.InfoSchemaImpl{}
	schemaDetails, err = commonInfoSchema.GetIncludedSrcTablesFromConv(conv)
	if err != nil {
		return nil, fmt.Errorf("error getting included source tables: %v", err)
	}
	streamingCfg, err := streaming.ReadStreamingConfig(isi.SourceProfile.Conn.Mysql.StreamingConfig, isi.TargetProfile.Conn.Sp.Dbname, schemaDetails)
if err != nil {
return nil, fmt.Errorf("error reading streaming config: %v", err)
}
pubsubCfg, err := streaming.CreatePubsubResources(ctx, isi.MigrationProjectId, streamingCfg.DatastreamCfg.DestinationConnectionConfig, isi.SourceProfile.Conn.Mysql.Db, constants.REGULAR_GCS)
if err != nil {
return nil, fmt.Errorf("error creating pubsub resources: %v", err)
}
streamingCfg.PubsubCfg = *pubsubCfg
dlqPubsubCfg, err := streaming.CreatePubsubResources(ctx, isi.MigrationProjectId, streamingCfg.DatastreamCfg.DestinationConnectionConfig, isi.SourceProfile.Conn.Mysql.Db, constants.DLQ_GCS)
if err != nil {
return nil, fmt.Errorf("error creating pubsub resources: %v", err)
}
streamingCfg.DlqPubsubCfg = *dlqPubsubCfg
streamingCfg, err = streaming.StartDatastream(ctx, isi.MigrationProjectId, streamingCfg, isi.SourceProfile, isi.TargetProfile, schemaDetails)
if err != nil {
err = fmt.Errorf("error starting datastream: %v", err)
return nil, err
}
mp["streamingCfg"] = streamingCfg
	return mp, nil
}
// StartStreamingMigration is used for automatic triggering of the Dataflow
// job when performing a streaming migration.
func (isi InfoSchemaImpl) StartStreamingMigration(ctx context.Context, migrationProjectId string, client *sp.Client, conv *internal.Conv, streamingInfo map[string]interface{}) (internal.DataflowOutput, error) {
	streamingCfg, ok := streamingInfo["streamingCfg"].(streaming.StreamingCfg)
	if !ok {
		return internal.DataflowOutput{}, fmt.Errorf("missing or invalid streamingCfg in streaming info")
	}
dfOutput, err := streaming.StartDataflow(ctx, migrationProjectId, isi.TargetProfile, streamingCfg, conv)
if err != nil {
err = fmt.Errorf("error starting dataflow: %v", err)
return internal.DataflowOutput{}, err
}
return dfOutput, nil
}
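// toType maps MySQL INFORMATION_SCHEMA column metadata to a schema.Type.
// For example (illustrative): varchar(50) yields {Name: "varchar", Mods: [50]},
// decimal(10,2) yields {Name: "decimal", Mods: [10, 2]}, and a set column is
// modeled as an array type via ArrayBounds.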
func toType(dataType string, columnType string, charLen sql.NullInt64, numericPrecision, numericScale sql.NullInt64) schema.Type {
switch {
case dataType == "set":
return schema.Type{Name: dataType, ArrayBounds: []int64{-1}}
case charLen.Valid:
return schema.Type{Name: dataType, Mods: []int64{charLen.Int64}}
// We only want to parse the length for tinyints when it is present, in the form tinyint(12). columnType can also be just 'tinyint',
// in which case we skip this parsing.
case dataType == "tinyint" && len(columnType) > len("tinyint"):
var length int64
_, err := fmt.Sscanf(columnType, "tinyint(%d)", &length)
if err != nil {
return schema.Type{Name: dataType}
}
return schema.Type{Name: dataType, Mods: []int64{length}}
case numericPrecision.Valid && numericScale.Valid && numericScale.Int64 != 0:
return schema.Type{Name: dataType, Mods: []int64{numericPrecision.Int64, numericScale.Int64}}
case numericPrecision.Valid:
return schema.Type{Name: dataType, Mods: []int64{numericPrecision.Int64}}
default:
return schema.Type{Name: dataType}
}
}
// buildVals constructs []sql.RawBytes value containers to scan row results
// into. Returns both the underlying containers (as a slice) and a
// []interface{} of pointers to those containers to pass to rows.Scan.
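// Example (illustrative):
//
//	v, scanArgs := buildVals(len(cols))
//	err := rows.Scan(scanArgs...) // populates v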
func buildVals(n int) (v []sql.RawBytes, iv []interface{}) {
v = make([]sql.RawBytes, n)
// rows.Scan wants '[]interface{}' as an argument, so we must copy the
// references into such a slice.
iv = make([]interface{}, len(v))
for i := range v {
iv[i] = &v[i]
}
return v, iv
}
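// valsToStrings converts raw row values to strings, mapping SQL NULL to the
// literal "NULL". For example, []sql.RawBytes{nil, []byte("42")} becomes
// []string{"NULL", "42"}.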
func valsToStrings(vals []sql.RawBytes) []string {
toString := func(val sql.RawBytes) string {
if val == nil {
return "NULL"
}
return string(val)
}
var s []string
for _, v := range vals {
s = append(s, toString(v))
}
return s
}
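// createSequence registers a new bit-reversed sequence in conv.SrcSequences
// (under conv.ConvLock) and returns it; the sequence name is derived from the
// generated sequence ID.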
func createSequence(conv *internal.Conv) ddl.Sequence {
id := internal.GenerateSequenceId()
sequenceName := "Sequence" + id[1:]
sequence := ddl.Sequence{
Id: id,
Name: sequenceName,
SequenceKind: "BIT REVERSED SEQUENCE",
}
conv.ConvLock.Lock()
defer conv.ConvLock.Unlock()
	conv.SrcSequences[id] = sequence
return sequence
}