schema/table_schema.go (89 lines of code) (raw):

// Copyright (c) 2017 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package schema import ( "database/sql" "fmt" "regexp" "strings" "github.com/pkg/errors" "github.com/uber/storagetapper/db" "github.com/uber/storagetapper/log" "github.com/uber/storagetapper/types" ) //HasPrimaryKey checks if given table has primary key func HasPrimaryKey(s *types.TableSchema) bool { for _, c := range s.Columns { if c.Key == "PRI" { return true } } return false } func getRawLow(db *sql.DB, fullTable string) (string, error) { var ct, unused string /*FIXME: Can I pass nil here? */ if err := db.QueryRow("SHOW CREATE TABLE "+fullTable).Scan(&unused, &ct); err != nil { return "", err } //Cut CONSTRAINTS var re = regexp.MustCompile(`(?im)^\s*CONSTRAINT.*\n`) ct = re.ReplaceAllString(ct, ``) var re1 = regexp.MustCompile(`,\n\)`) // Cut possible remaining comma at the end ct = re1.ReplaceAllString(ct, ` )`) //FIXME: Is there a better way to insert \n i := strings.Index(ct, "(") if i == -1 { return "", errors.New("Broken schema: " + ct) } return ct[i:], nil } // GetRaw returns output of SHOW CREATE TABLE after the "CREATE TABLE xyz (" // So called "raw" schema func GetRaw(dbl *db.Loc, fullTable string, inputType string) (string, error) { conn, err := db.OpenService(dbl, "", inputType) if err != nil { return "", err } defer func() { log.E(conn.Close()) }() return getRawLow(conn, fullTable) } // Get loads structured schema for "table", from master DB, identified by dbl func Get(dbl *db.Loc, table string, inputType string) (*types.TableSchema, error) { conn, err := db.OpenService(dbl, "information_schema", inputType) if err != nil { return nil, err } defer func() { log.E(conn.Close()) }() return GetColumns(conn, dbl.Name, table) } // ParseColumnInfo reads parses column information into table schema func ParseColumnInfo(rows *sql.Rows, dbName, table string) (*types.TableSchema, error) { tableSchema := types.TableSchema{DBName: dbName, TableName: table, Columns: []types.ColumnSchema{}} rowsCount := 0 for rows.Next() { cs := types.ColumnSchema{} err := rows.Scan(&cs.Name, &cs.OrdinalPosition, &cs.IsNullable, &cs.DataType, &cs.CharacterMaximumLength, &cs.NumericPrecision, &cs.NumericScale, &cs.Type, &cs.Key) if err != nil { log.E(errors.Wrap(err, fmt.Sprintf("Error scanning table schema query result for %s.%s", dbName, table))) return nil, err } tableSchema.Columns = append(tableSchema.Columns, cs) rowsCount++ } if err := rows.Err(); err != nil { log.F(err) } if rowsCount == 0 { return &tableSchema, fmt.Errorf("no schema columns for table %s.%s. check grants", dbName, table) } log.Debugf("Got schema from state for '%v.%v' = '%+v'", dbName, table, tableSchema) return &tableSchema, nil } // GetColumns reads structured schema from information_schema for table from given connection func GetColumns(conn *sql.DB, dbName string, tableName string) (*types.TableSchema, error) { query := "SELECT COLUMN_NAME, ORDINAL_POSITION, IS_NULLABLE, DATA_TYPE, " + "CHARACTER_MAXIMUM_LENGTH, NUMERIC_PRECISION, NUMERIC_SCALE, COLUMN_TYPE, " + "COLUMN_KEY FROM information_schema.columns WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? " + "ORDER BY ORDINAL_POSITION" log.Debugf("%v %v %v", query, dbName, tableName) rows, err := conn.Query(query, dbName, tableName) if log.E(errors.Wrap(err, fmt.Sprintf("Error fetching table schema for %s.%s", dbName, tableName))) { return nil, err } defer func() { log.E(rows.Close()) }() return ParseColumnInfo(rows, dbName, tableName) }