sources/postgres/pgdump.go

// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package postgres

import (
    "fmt"
    "reflect"
    "strconv"
    "strings"
    "time"

    pg_query "github.com/pganalyze/pg_query_go/v6"

    "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants"
    "github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
    "github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
    "github.com/GoogleCloudPlatform/spanner-migration-tool/schema"
    "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common"
)

// DbDumpImpl is the Postgres-specific implementation of common.DbDump.
type DbDumpImpl struct {
}

type copyOrInsert struct {
    stmt  stmtType
    table string
    cols  []string
    rows  [][]string // Empty for COPY-FROM.
}

type stmtType int

const (
    copyFrom stmtType = iota
    insert
)
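
// A pg_dump file interleaves schema statements and data. For example
// (illustrative):
//
//	SET timezone = 'UTC';
//	CREATE TABLE cart (productid text PRIMARY KEY, quantity int8);
//	COPY cart (productid, quantity) FROM stdin;
//	p1	2
//	\.
//
// Schema mode consumes the SET/CREATE/ALTER statements; data mode replays
// the COPY blocks and any INSERT statements.
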
// GetToDdl functions below implement the common.DbDump interface.
func (ddi DbDumpImpl) GetToDdl() common.ToDdl {
    return ToDdlImpl{}
}

// ProcessDump calls processPgDump to read a Postgres dump file.
func (ddi DbDumpImpl) ProcessDump(conv *internal.Conv, r *internal.Reader) error {
    return processPgDump(conv, r)
}

// processPgDump reads pg_dump data from r and does schema or data conversion,
// depending on whether conv is configured for schema mode or data mode.
// In schema mode, processPgDump incrementally builds a schema (updating conv).
// In data mode, processPgDump uses this schema to convert PostgreSQL data
// and writes it to Spanner, using the data sink specified in conv.
func processPgDump(conv *internal.Conv, r *internal.Reader) error {
    for {
        startLine := r.LineNumber
        startOffset := r.Offset
        b, stmts, err := readAndParseChunk(conv, r)
        if err != nil {
            return err
        }
        ci := processStatements(conv, stmts)
        internal.VerbosePrintf("Parsed SQL command at line=%d/fpos=%d: %d stmts (%d lines, %d bytes) ci=%v\n", startLine, startOffset, len(stmts), r.LineNumber-startLine, len(b), ci != nil)
        logger.Log.Debug(fmt.Sprintf("Parsed SQL command at line=%d/fpos=%d: %d stmts (%d lines, %d bytes) ci=%v\n", startLine, startOffset, len(stmts), r.LineNumber-startLine, len(b), ci != nil))
        if ci != nil {
            switch ci.stmt {
            case copyFrom:
                commonColIds, err := common.PrepareColumns(conv, ci.table, ci.cols)
                if err != nil && !conv.SchemaMode() {
                    return err
                }
                processCopyBlock(conv, ci.table, commonColIds, ci.cols, r)
            case insert:
                if conv.SchemaMode() {
                    continue
                }
                // Handle INSERT statements where columns are not
                // specified i.e. an insert for all table columns.
                var colNames []string
                if len(ci.cols) == 0 {
                    for _, col := range conv.SrcSchema[ci.table].ColIds {
                        colNames = append(colNames, conv.SrcSchema[ci.table].ColDefs[col].Name)
                    }
                } else {
                    colNames = ci.cols
                }
                commonColIds, err := common.PrepareColumns(conv, ci.table, colNames)
                if err != nil {
                    return err
                }
                colNameIdMap := internal.GetSrcColNameIdMap(conv.SrcSchema[ci.table])
                for _, vals := range ci.rows {
                    newVals, err := common.PrepareValues(conv, ci.table, colNameIdMap, commonColIds, colNames, vals)
                    if err != nil {
                        srcTableName := conv.SrcSchema[ci.table].Name
                        conv.Unexpected(fmt.Sprintf("Error while converting data: %s\n", err))
                        conv.StatsAddBadRow(srcTableName, conv.DataMode())
                        conv.CollectBadRow(srcTableName, colNames, vals)
                        continue
                    }
                    ProcessDataRow(conv, ci.table, commonColIds, newVals)
                }
            }
        }
        if r.EOF {
            break
        }
    }
    internal.ResolveForeignKeyIds(conv.SrcSchema)
    return nil
}
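
// A minimal usage sketch (hypothetical driver code -- the real entry point is
// DbDumpImpl.ProcessDump, invoked by the common conversion framework; the
// internal.MakeConv/internal.NewReader calls here are assumptions):
//
//	f, _ := os.Open("dump.sql")
//	conv := internal.MakeConv()
//	conv.SetSchemaMode() // First pass: build conv.SrcSchema.
//	err := DbDumpImpl{}.ProcessDump(conv, internal.NewReader(bufio.NewReader(f), nil))
//
// A second pass with conv.SetDataMode() (and a data sink configured) replays
// the dump to convert and write data.
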
// readAndParseChunk parses a chunk of pg_dump data, returning the bytes read,
// the parsed AST (nil if nothing read), and an error if the remaining input
// cannot be parsed.
func readAndParseChunk(conv *internal.Conv, r *internal.Reader) ([]byte, []*pg_query.RawStmt, error) {
    var l [][]byte
    for {
        b := r.ReadLine()
        l = append(l, b)
        // If we see a semicolon or eof, we're likely to have a command, so try to parse it.
        // Note: we could just parse every iteration, but that would mean more attempts at parsing.
        if strings.Contains(string(b), ";") || r.EOF {
            n := 0
            for i := range l {
                n += len(l[i])
            }
            s := make([]byte, n)
            n = 0
            for i := range l {
                n += copy(s[n:], l[i])
            }
            tree, err := pg_query.Parse(string(s))
            if err == nil {
                return s, tree.Stmts, nil
            }
            // Likely causes of failing to parse:
            // a) complex statements with embedded semicolons e.g. 'CREATE FUNCTION'
            // b) a semicolon embedded in a multi-line comment, or
            // c) a semicolon embedded in a string constant or column/table name.
            // We deal with this case by reading another line and trying again.
            conv.Stats.Reparsed++
        }
        if r.EOF {
            return nil, nil, fmt.Errorf("error parsing last %d line(s) of input", len(l))
        }
    }
}
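
// For example (illustrative), this statement only parses once the whole
// function body has been read, since each earlier ';' yields a parse failure
// and another call to r.ReadLine():
//
//	CREATE FUNCTION add(a integer, b integer) RETURNS integer AS $$
//	BEGIN
//	    RETURN a + b;
//	END;
//	$$ LANGUAGE plpgsql;
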
func processCopyBlock(conv *internal.Conv, tableId string, commonColIds, srcCols []string, r *internal.Reader) {
    srcTableName := conv.SrcSchema[tableId].Name
    internal.VerbosePrintf("Parsing COPY-FROM stdin block starting at line=%d/fpos=%d\n", r.LineNumber, r.Offset)
    logger.Log.Debug(fmt.Sprintf("Parsing COPY-FROM stdin block starting at line=%d/fpos=%d\n", r.LineNumber, r.Offset))
    for {
        b := r.ReadLine()
        if string(b) == "\\.\n" || string(b) == "\\.\r\n" {
            internal.VerbosePrintf("Parsed COPY-FROM stdin block ending at line=%d/fpos=%d\n", r.LineNumber, r.Offset)
            logger.Log.Debug(fmt.Sprintf("Parsed COPY-FROM stdin block ending at line=%d/fpos=%d\n", r.LineNumber, r.Offset))
            return
        }
        if r.EOF {
            conv.Unexpected("Reached eof while parsing copy-block")
            return
        }
        conv.StatsAddRow(srcTableName, conv.SchemaMode())
        // We have to read the copy-block data so that we can process the remaining
        // pg_dump content. However, if we don't want the data, stop here.
        // In particular, avoid the strings.Split and ProcessDataRow calls below, which
        // will be expensive for huge datasets.
        if !conv.DataMode() {
            continue
        }
        // pg_dump escapes backslashes in copy-block statements. For example:
        // a) a\"b becomes a\\"b in COPY-BLOCK (but 'a\"b' in INSERT-INTO)
        // b) {"a\"b"} becomes {"a\\"b"} in COPY-BLOCK (but '{"a\"b"}' in INSERT-INTO)
        // Note: a'b and {a'b} are unchanged in COPY-BLOCK and INSERT-INTO.
        s := strings.ReplaceAll(string(b), `\\`, `\`)
        // COPY-FROM blocks use tabs to separate data items. Note that space
        // within data items is significant e.g. if a table row contains data
        // items "a ", " b " it will be shown in the COPY-FROM block as "a \t b ".
        values := strings.Split(strings.Trim(s, "\r\n"), "\t")
        colNameIdMap := internal.GetSrcColNameIdMap(conv.SrcSchema[tableId])
        newValues, err := common.PrepareValues(conv, tableId, colNameIdMap, commonColIds, srcCols, values)
        if err != nil {
            conv.Unexpected(fmt.Sprintf("Error while converting data: %s\n", err))
            conv.StatsAddBadRow(srcTableName, conv.DataMode())
            conv.CollectBadRow(srcTableName, srcCols, values)
            continue
        }
        ProcessDataRow(conv, tableId, commonColIds, newValues)
    }
}
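
// For example (illustrative), a COPY-FROM block in a pg_dump file looks like:
//
//	COPY products (id, name, price) FROM stdin;
//	1	widget	9.99
//	2	gadget	\N
//	\.
//
// Data items are tab-separated, \N denotes NULL, and a "\." line terminates
// the block.
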
// processStatements extracts schema information and data from PostgreSQL
// statements, updating Conv with new schema information, and returning
// copyOrInsert if a COPY-FROM or INSERT statement is encountered.
// Note that the actual parsing/processing of COPY-FROM data blocks is
// handled elsewhere (see process.go).
func processStatements(conv *internal.Conv, rawStmts []*pg_query.RawStmt) *copyOrInsert {
    // Typically we'll have only one statement, but we handle the general case.
    for i, rawStmt := range rawStmts {
        node := rawStmt.Stmt
        switch n := node.GetNode().(type) {
        case *pg_query.Node_AlterTableStmt:
            if conv.SchemaMode() {
                processAlterTableStmt(conv, n.AlterTableStmt)
            }
        case *pg_query.Node_CopyStmt:
            if i != len(rawStmts)-1 {
                conv.Unexpected("CopyFrom is not the last statement in batch: ignoring following statements")
                conv.ErrorInStatement(printNodeType(n.CopyStmt))
            }
            return processCopyStmt(conv, n.CopyStmt)
        case *pg_query.Node_CreateStmt:
            if conv.SchemaMode() {
                processCreateStmt(conv, n.CreateStmt)
            }
        case *pg_query.Node_InsertStmt:
            return processInsertStmt(conv, n.InsertStmt)
        case *pg_query.Node_VariableSetStmt:
            if conv.SchemaMode() {
                processVariableSetStmt(conv, n.VariableSetStmt)
            }
        case *pg_query.Node_IndexStmt:
            if conv.SchemaMode() {
                processIndexStmt(conv, n.IndexStmt)
            }
        default:
            conv.SkipStatement(printNodeType(n))
        }
    }
    return nil
}

func processIndexStmt(conv *internal.Conv, n *pg_query.IndexStmt) {
    if n.Relation == nil {
        logStmtError(conv, n, fmt.Errorf("cannot process index statement with nil relation"))
        return
    }
    tableName, err := getTableName(conv, n.Relation)
    if err != nil {
        logStmtError(conv, n, fmt.Errorf("can't get table name: %w", err))
        return
    }
    if tbl, ok := internal.GetSrcTableByName(conv.SrcSchema, tableName); ok {
        ctable := conv.SrcSchema[tbl.Id]
        ctable.Indexes = append(ctable.Indexes, schema.Index{
            Id:     internal.GenerateIndexesId(),
            Name:   n.Idxname,
            Unique: n.Unique,
            Keys:   toIndexKeys(conv, n.Idxname, n.IndexParams, ctable.ColNameIdMap),
        })
        conv.SrcSchema[tbl.Id] = ctable
    } else {
        conv.Unexpected(fmt.Sprintf("Table %s not found while processing index statement", tableName))
        conv.SkipStatement(printNodeType(n))
    }
}

func processAlterTableStmt(conv *internal.Conv, n *pg_query.AlterTableStmt) {
    if n.Relation == nil {
        logStmtError(conv, n, fmt.Errorf("relation is nil"))
        return
    }
    tableName, err := getTableName(conv, n.Relation)
    if err != nil {
        logStmtError(conv, n, fmt.Errorf("can't get table name: %w", err))
        return
    }
    if tbl, ok := internal.GetSrcTableByName(conv.SrcSchema, tableName); ok {
        for _, i := range n.Cmds {
            cmd := i.GetNode()
            switch t := cmd.(type) {
            case *pg_query.Node_AlterTableCmd:
                a := t.AlterTableCmd
                switch {
                case a.Subtype == pg_query.AlterTableType_AT_SetNotNull && a.Name != "":
                    c := constraint{ct: pg_query.ConstrType_CONSTR_NOTNULL, cols: []string{a.Name}}
                    updateSchema(conv, tbl.Id, []constraint{c}, "ALTER TABLE")
                    conv.SchemaStatement(strings.Join([]string{printNodeType(n), printNodeType(t)}, "."))
                case a.Subtype == pg_query.AlterTableType_AT_AddConstraint && a.Def != nil:
                    switch at := a.Def.GetNode().(type) {
                    case *pg_query.Node_Constraint:
                        updateSchema(conv, tbl.Id, extractConstraints(conv, printNodeType(n), tableName, []*pg_query.Node{a.Def}), "ALTER TABLE")
                        conv.SchemaStatement(strings.Join([]string{printNodeType(n), printNodeType(t), printNodeType(at)}, "."))
                    default:
                        conv.SkipStatement(strings.Join([]string{printNodeType(n), printNodeType(t), printNodeType(at)}, "."))
                    }
                default:
                    conv.SkipStatement(strings.Join([]string{printNodeType(n), printNodeType(t)}, "."))
                }
            default:
                conv.SkipStatement(strings.Join([]string{printNodeType(n), printNodeType(t)}, "."))
            }
        }
    } else {
        // In PostgreSQL, AlterTable statements can be applied to views,
        // sequences and indexes in addition to tables. Since we only
        // track tables created by "CREATE TABLE", this lookup can fail.
        // For debugging purposes we log the lookup failure if we're
        // in verbose mode, but otherwise we just skip these statements.
        conv.SkipStatement(printNodeType(n))
        internal.VerbosePrintf("Processing %v statement: table %s not found", printNodeType(n), tableName)
        logger.Log.Debug(fmt.Sprintf("Processing %v statement: table %s not found", printNodeType(n), tableName))
    }
}
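
// For example (illustrative), processAlterTableStmt handles statements such as:
//
//	ALTER TABLE cart ALTER COLUMN quantity SET NOT NULL;
//	ALTER TABLE cart ADD CONSTRAINT fk FOREIGN KEY (productid) REFERENCES product (id);
//
// Other subtypes (e.g. DROP COLUMN, SET DEFAULT) are recorded via SkipStatement.
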
func processCreateStmt(conv *internal.Conv, n *pg_query.CreateStmt) {
    colDef := make(map[string]schema.Column)
    if n.Relation == nil {
        logStmtError(conv, n, fmt.Errorf("relation is nil"))
        return
    }
    table, err := getTableName(conv, n.Relation)
    if err != nil {
        logStmtError(conv, n, fmt.Errorf("can't get table name: %w", err))
        return
    }
    if len(n.InhRelations) > 0 {
        // Skip inherited tables.
        conv.SkipStatement(printNodeType(n))
        conv.Unexpected(fmt.Sprintf("Found inherited table %s -- we do not currently handle inherited tables", table))
        internal.VerbosePrintf("Processing %v statement: table %s is inherited table", printNodeType(n), table)
        logger.Log.Debug(fmt.Sprintf("Processing %v statement: table %s is inherited table", printNodeType(n), table))
        return
    }
    var constraints []constraint
    var colIds []string
    colNameIdMap := make(map[string]string)
    for _, te := range n.TableElts {
        switch te.GetNode().(type) {
        case *pg_query.Node_ColumnDef:
            _, col, cdConstraints, err := processColumn(conv, te.GetColumnDef(), table)
            if err != nil {
                logStmtError(conv, n, err)
                return
            }
            col.Id = internal.GenerateColumnId()
            colDef[col.Id] = col
            colIds = append(colIds, col.Id)
            colNameIdMap[col.Name] = col.Id
            constraints = append(constraints, cdConstraints...)
        case *pg_query.Node_Constraint:
            // Note: there should be at most one Constraint node in
            // n.TableElts. We don't check this. We just keep collecting
            // constraints.
            constraints = append(constraints, extractConstraints(conv, printNodeType(n), table, []*pg_query.Node{te})...)
        default:
            conv.Unexpected(fmt.Sprintf("Found %s node while processing CreateStmt TableElts", printNodeType(te)))
        }
    }
    conv.SchemaStatement(printNodeType(n))
    tableId := internal.GenerateTableId()
    conv.SrcSchema[tableId] = schema.Table{
        Id:           tableId,
        Name:         table,
        ColIds:       colIds,
        ColNameIdMap: colNameIdMap,
        ColDefs:      colDef,
    }
    // Note: constraints contains all info about primary keys, not-null keys
    // and foreign keys.
    updateSchema(conv, tableId, constraints, "CREATE TABLE")
}

func processColumn(conv *internal.Conv, n *pg_query.ColumnDef, table string) (string, schema.Column, []constraint, error) {
    mods := getTypeMods(conv, n.TypeName.Typmods)
    if n.Colname == "" {
        return "", schema.Column{}, nil, fmt.Errorf("colname is empty string")
    }
    name := n.Colname
    tid, err := getTypeID(n.TypeName.Names)
    if err != nil {
        return "", schema.Column{}, nil, fmt.Errorf("can't get type id for %s: %w", name, err)
    }
    ty := schema.Type{
        Name:        tid,
        Mods:        mods,
        ArrayBounds: getArrayBounds(conv, n.TypeName.ArrayBounds)}
    return name, schema.Column{Name: name, Type: ty}, analyzeColDefConstraints(conv, printNodeType(n), table, n.Constraints, name), nil
}

func processInsertStmt(conv *internal.Conv, n *pg_query.InsertStmt) *copyOrInsert {
    if n.Relation == nil {
        logStmtError(conv, n, fmt.Errorf("relation is nil"))
        return nil
    }
    table, err := getTableName(conv, n.Relation)
    if err != nil {
        logStmtError(conv, n, fmt.Errorf("can't get table name: %w", err))
        return nil
    }
    tableId, _ := internal.GetTableIdFromSrcName(conv.SrcSchema, table)
    if _, ok := conv.SrcSchema[tableId]; !ok {
        // If we don't have schema information for a table, we drop all insert
        // statements for it. The most likely reason we don't have schema information
        // for a table is that it is an inherited table - we skip all inherited tables.
        conv.SkipStatement(printNodeType(n))
        internal.VerbosePrintf("Processing %v statement: table %s not found", printNodeType(n), table)
        logger.Log.Debug(fmt.Sprintf("Processing %v statement: table %s not found", printNodeType(n), table))
        return nil
    }
    conv.StatsAddRow(tableId, conv.SchemaMode())
    colNames, err := getCols(conv, table, n.Cols)
    if err != nil {
        logStmtError(conv, n, fmt.Errorf("can't get col name: %w", err))
        conv.StatsAddBadRow(table, conv.SchemaMode())
        return nil
    }
    switch sel := n.SelectStmt.GetNode().(type) {
    case *pg_query.Node_SelectStmt:
        rows := getRows(conv, sel.SelectStmt.ValuesLists, n)
        conv.DataStatement(printNodeType(sel))
        if conv.DataMode() {
            return &copyOrInsert{stmt: insert, table: tableId, cols: colNames, rows: rows}
        }
    default:
        conv.Unexpected(fmt.Sprintf("Found %s node while processing InsertStmt SelectStmt", printNodeType(sel)))
    }
    return nil
}

func processCopyStmt(conv *internal.Conv, n *pg_query.CopyStmt) *copyOrInsert {
    // Always return a copyOrInsert{stmt: copyFrom, ...} even if we
    // encounter errors. Otherwise we won't be able to parse
    // the data portion of the COPY-FROM statement, and we'll
    // likely get stuck at this point in the pg_dump file.
    table := "BOGUS_COPY_FROM_TABLE"
    var err error
    if n.Relation != nil {
        table, err = getTableName(conv, n.Relation)
        if err != nil {
            conv.Unexpected(fmt.Sprintf("Processing %v statement: %s", printNodeType(n), err))
        }
    } else {
        logStmtError(conv, n, fmt.Errorf("relation is nil"))
    }
    if !conv.SchemaMode() {
        table, _ = internal.GetTableIdFromSrcName(conv.SrcSchema, table)
    }
    if _, ok := conv.SrcSchema[table]; !ok {
        // If we don't have schema information for a table, we drop all copy
        // statements for it. The most likely reason we don't have schema information
        // for a table is that it is an inherited table - we skip all inherited tables.
        conv.SkipStatement(printNodeType(n))
        internal.VerbosePrintf("Processing %v statement: table %s not found", printNodeType(n), table)
        logger.Log.Debug(fmt.Sprintf("Processing %v statement: table %s not found", printNodeType(n), table))
        return &copyOrInsert{stmt: copyFrom, table: table, cols: []string{}}
    }
    var cols []string
    for _, a := range n.Attlist {
        s, err := getString(a)
        if err != nil {
            conv.Unexpected(fmt.Sprintf("Processing %v statement Attlist: %s", printNodeType(n), err))
            s = "BOGUS_COPY_FROM_COLUMN"
        }
        cols = append(cols, s)
    }
    conv.DataStatement(printNodeType(n))
    return &copyOrInsert{stmt: copyFrom, table: table, cols: cols}
}
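
// For example (illustrative), in data mode:
//
//	INSERT INTO cart (productid, quantity) VALUES ('p1', 2);
//
// yields copyOrInsert{stmt: insert, table: <cart's table id>,
// cols: ["productid", "quantity"], rows: [["p1", "2"]]}, while
//
//	COPY cart (productid, quantity) FROM stdin;
//
// yields copyOrInsert{stmt: copyFrom, ...} with rows left empty, since the
// data lines that follow are consumed separately by processCopyBlock.
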
func processVariableSetStmt(conv *internal.Conv, n *pg_query.VariableSetStmt) {
    if n.Name == "timezone" {
        if len(n.Args) == 1 {
            arg := n.Args[0]
            switch c := arg.GetNode().(type) {
            case *pg_query.Node_AConst:
                tz := c.AConst.GetSval()
                if tz == nil {
                    logStmtError(conv, c, fmt.Errorf("can't get timezone Arg"))
                    return
                }
                loc, err := time.LoadLocation(tz.Sval)
                if err != nil {
                    logStmtError(conv, c, err)
                    return
                }
                conv.SetLocation(loc)
            default:
                logStmtError(conv, arg, fmt.Errorf("found %s node in Arg", printNodeType(c)))
                return
            }
        }
    }
}

func getTypeMods(conv *internal.Conv, t []*pg_query.Node) (l []int64) {
    for _, x := range t {
        switch t1 := x.GetNode().(type) {
        case *pg_query.Node_AConst:
            switch t2 := t1.AConst.Val.(type) {
            case *pg_query.A_Const_Ival:
                l = append(l, int64(t2.Ival.Ival))
            default:
                conv.Unexpected(fmt.Sprintf("Found %s node while processing Typmods", printNodeType(t2)))
            }
        default:
            conv.Unexpected(fmt.Sprintf("Found %s node while processing Typmods", printNodeType(t1)))
        }
    }
    return l
}

func getArrayBounds(conv *internal.Conv, t []*pg_query.Node) (l []int64) {
    for _, x := range t {
        switch t := x.GetNode().(type) {
        case *pg_query.Node_Integer:
            // 'Ival' provides the array bound (-1 for an array where bound is not specified).
            l = append(l, int64(t.Integer.Ival))
        default:
            conv.Unexpected(fmt.Sprintf("Found %s node while processing ArrayBounds", printNodeType(x)))
        }
    }
    return l
}

func getTypeID(nodes []*pg_query.Node) (string, error) {
    // The pg_query library generates a pg_catalog schema prefix for most
    // types, but not for all. Typically "aliases" don't have the prefix.
    // For example, "boolean" is parsed to ["pg_catalog", "bool"], but "bool" is
    // parsed to ["bool"]. However the exact rules are unclear e.g. "date"
    // is parsed to just ["date"].
    // For simplicity, we strip off the pg_catalog prefix.
    var ids []string
    for _, node := range nodes {
        s, err := getString(node)
        if err != nil {
            return "", err
        }
        ids = append(ids, s)
    }
    if len(ids) > 1 && ids[0] == "pg_catalog" {
        ids = ids[1:]
    }
    return strings.Join(ids, "."), nil
}

// getTableName extracts the table name from RangeVar n, and returns
// the raw extracted name (the PostgreSQL table name).
func getTableName(conv *internal.Conv, n *pg_query.RangeVar) (string, error) {
    // RangeVar is used to represent table names. It consists of three components:
    //  Catalogname: database name; either not specified or the current database
    //  Schemaname: schemas are PostgreSQL namespaces; often unspecified; defaults to "public"
    //  Relname: name of the table
    // We build a table name from these three components as follows:
    // a) nil components are dropped.
    // b) if more than one component is specified, they are joined using "."
    //    (Note that Spanner doesn't allow "." in table names, so this
    //    will eventually get re-mapped when we construct the Spanner table name).
    // c) Schemaname is dropped if it is "public".
    // d) return error if Relname is nil or "".
    var l []string
    if n.Catalogname != "" {
        l = append(l, n.Catalogname)
    }
    if n.Schemaname != "" && n.Schemaname != "public" { // Don't include "public".
        l = append(l, n.Schemaname)
    }
    if n.Relname == "" {
        return "", fmt.Errorf("relname is empty: can't build table name")
    }
    l = append(l, n.Relname)
    return strings.Join(l, "."), nil
}
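
// For example (illustrative):
//
//	RangeVar{Schemaname: "public", Relname: "cart"}            -> "cart"
//	RangeVar{Schemaname: "myschema", Relname: "cart"}          -> "myschema.cart"
//	RangeVar{Catalogname: "db", Schemaname: "s", Relname: "t"} -> "db.s.t"
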
type constraint struct {
    ct   pg_query.ConstrType
    cols []string
    name string // Used for FOREIGN KEY or SECONDARY INDEX
    /* Fields used for FOREIGN KEY constraints: */
    referCols  []string
    referTable string
    onDelete   string
    onUpdate   string
}

// extractConstraints traverses a list of nodes (expecting them to be
// Constraint nodes), and collects the constraints they represent.
func extractConstraints(conv *internal.Conv, stmtType, table string, l []*pg_query.Node) (cs []constraint) {
    for _, i := range l {
        switch d := i.GetNode().(type) {
        case *pg_query.Node_Constraint:
            c := d.Constraint
            var cols, referCols []string
            var referTable, onDelete, onUpdate string
            var conName string
            switch c.Contype {
            case pg_query.ConstrType_CONSTR_FOREIGN:
                t, err := getTableName(conv, c.Pktable)
                if err != nil {
                    conv.Unexpected(fmt.Sprintf("Processing %v statement: error processing constraints: %s", printNodeType(d), err.Error()))
                    conv.ErrorInStatement(printNodeType(d))
                    continue
                }
                referTable = t
                if c.Conname != "" {
                    conName = c.Conname
                }
                for _, attr := range c.FkAttrs {
                    k, err := getString(attr)
                    if err != nil {
                        conv.Unexpected(fmt.Sprintf("Processing %v statement: error processing constraints: %s", printNodeType(d), err.Error()))
                        conv.ErrorInStatement(printNodeType(d))
                        continue
                    }
                    cols = append(cols, k)
                }
                for _, attr := range c.PkAttrs {
                    f, err := getString(attr)
                    if err != nil {
                        conv.Unexpected(fmt.Sprintf("Processing %v statement: error processing constraints: %s", printNodeType(d), err.Error()))
                        conv.ErrorInStatement(printNodeType(d))
                        continue
                    }
                    referCols = append(referCols, f)
                }
                onDelete = c.GetFkDelAction()
                switch onDelete {
                case "a":
                    onDelete = constants.FK_NO_ACTION
                case "r":
                    onDelete = constants.FK_RESTRICT
                case "c":
                    onDelete = constants.FK_CASCADE
                case "n":
                    onDelete = constants.FK_SET_NULL
                case "d":
                    onDelete = constants.FK_SET_DEFAULT
                case " ":
                    onDelete = constants.FK_NO_ACTION
                default:
                    onDelete = "UNKNOWN"
                }
                onUpdate = c.GetFkUpdAction()
                switch onUpdate {
                case "a":
                    onUpdate = constants.FK_NO_ACTION
                case "r":
                    onUpdate = constants.FK_RESTRICT
                case "c":
                    onUpdate = constants.FK_CASCADE
                case "n":
                    onUpdate = constants.FK_SET_NULL
                case "d":
                    onUpdate = constants.FK_SET_DEFAULT
                case " ":
                    onUpdate = constants.FK_NO_ACTION
                default:
                    onUpdate = "UNKNOWN"
                }
            default:
                if c.Conname != "" {
                    conName = c.Conname
                }
                for _, key := range c.Keys {
                    k, err := getString(key)
                    if err != nil {
                        conv.Unexpected(fmt.Sprintf("Processing %v statement: error processing constraints: %s", printNodeType(d), err.Error()))
                        conv.ErrorInStatement(fmt.Sprintf("%v.%v", printNodeType(i), printNodeType(d)))
                        continue
                    }
                    cols = append(cols, k)
                }
            }
            cs = append(cs, constraint{ct: c.Contype, cols: cols, name: conName, referCols: referCols, referTable: referTable, onDelete: onDelete, onUpdate: onUpdate})
        default:
            conv.Unexpected(fmt.Sprintf("Processing %v statement: found %s node while processing constraints\n", stmtType, printNodeType(d)))
        }
    }
    return cs
}

// analyzeColDefConstraints is like extractConstraints, but is specifically for
// ColDef constraints. These constraints don't specify a key since they
// are constraints for the column defined by ColDef.
func analyzeColDefConstraints(conv *internal.Conv, stmtType, table string, l []*pg_query.Node, pgCol string) (cs []constraint) {
    // Do generic constraint processing and then set the keys of each constraint
    // to {pgCol}.
    for _, c := range extractConstraints(conv, stmtType, table, l) {
        if len(c.cols) != 0 {
            conv.Unexpected("ColumnDef constraint has keys")
        }
        c.cols = []string{pgCol}
        cs = append(cs, c)
    }
    return cs
}
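
// For example (illustrative), extractConstraints maps the pg_query foreign
// key action codes as follows: the constraint
//
//	FOREIGN KEY (c) REFERENCES u (d) ON DELETE CASCADE ON UPDATE SET NULL
//
// arrives with FkDelAction "c" and FkUpdAction "n", producing
// constraint{ct: CONSTR_FOREIGN, cols: ["c"], referTable: "u",
// referCols: ["d"], onDelete: FK_CASCADE, onUpdate: FK_SET_NULL}.
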
// updateSchema updates the schema for table based on the given constraints.
// stmtType is the statement type being processed, and is used for debug messages.
func updateSchema(conv *internal.Conv, tableId string, cs []constraint, stmtType string) {
    colNameIdMap := conv.SrcSchema[tableId].ColNameIdMap
    for _, c := range cs {
        switch c.ct {
        case pg_query.ConstrType_CONSTR_PRIMARY:
            ct := conv.SrcSchema[tableId]
            checkEmpty(conv, ct.PrimaryKeys, stmtType)
            ct.PrimaryKeys = toSchemaKeys(conv, tableId, c.cols, colNameIdMap) // Drop any previous primary keys.
            // In Spanner, primary key columns are usually annotated with NOT NULL,
            // but this can be omitted to allow NULL values in key columns.
            // In PostgreSQL, the primary key constraint is a combination of
            // NOT NULL and UNIQUE i.e. primary keys must be NOT NULL.
            // We preserve PostgreSQL semantics and enforce NOT NULL.
            updateCols(pg_query.ConstrType_CONSTR_NOTNULL, c.cols, ct.ColDefs, colNameIdMap)
            conv.SrcSchema[tableId] = ct
        case pg_query.ConstrType_CONSTR_FOREIGN:
            ct := conv.SrcSchema[tableId]
            ct.ForeignKeys = append(ct.ForeignKeys, toForeignKeys(c)) // Append to previous foreign keys.
            conv.SrcSchema[tableId] = ct
        case pg_query.ConstrType_CONSTR_UNIQUE:
            // Convert a unique column constraint in PostgreSQL to a corresponding
            // unique index in Spanner, since Spanner doesn't support unique
            // constraints on columns.
            // TODO: Avoid Spanner-specific schema transformations in this file -- they should only
            // appear in toddl.go. This file should focus on generic transformation from source
            // database schemas into schema.go.
            ct := conv.SrcSchema[tableId]
            ct.Indexes = append(ct.Indexes, schema.Index{Name: c.name, Unique: true, Keys: toSchemaKeys(conv, tableId, c.cols, colNameIdMap)})
            conv.SrcSchema[tableId] = ct
        default:
            ct := conv.SrcSchema[tableId]
            updateCols(c.ct, c.cols, ct.ColDefs, colNameIdMap)
            conv.SrcSchema[tableId] = ct
        }
    }
}

// updateCols updates colDef with new constraints. Specifically, we apply
// 'ct' to each column in colNames.
func updateCols(ct pg_query.ConstrType, colNames []string, colDef map[string]schema.Column, colNameIdMap map[string]string) {
    // TODO: add cases for other constraints.
    for _, cn := range colNames {
        cid := colNameIdMap[cn]
        cd := colDef[cid]
        switch ct {
        case pg_query.ConstrType_CONSTR_NOTNULL:
            cd.NotNull = true
        case pg_query.ConstrType_CONSTR_DEFAULT:
            cd.Ignored.Default = true
        }
        colDef[cid] = cd
    }
}

// toSchemaKeys converts a string list of PostgreSQL primary keys to
// schema primary keys.
func toSchemaKeys(conv *internal.Conv, tableId string, colNames []string, colNameIdMap map[string]string) (l []schema.Key) {
    for _, cn := range colNames {
        // PostgreSQL primary keys have no notion of ascending/descending.
        // We map them all into ascending primary keys.
        l = append(l, schema.Key{ColId: colNameIdMap[cn]})
    }
    return l
}

// toIndexKeys converts a list of PostgreSQL index keys to schema index keys.
func toIndexKeys(conv *internal.Conv, idxName string, s []*pg_query.Node, colNameIdMap map[string]string) (l []schema.Key) {
    for _, k := range s {
        switch e := k.GetNode().(type) {
        case *pg_query.Node_IndexElem:
            if e.IndexElem.Name == "" {
                conv.Unexpected(fmt.Sprintf("Failed to process index %s: empty index column name", idxName))
                continue
            }
            desc := false
            if e.IndexElem.Ordering == pg_query.SortByDir_SORTBY_DESC {
                desc = true
            }
            l = append(l, schema.Key{ColId: colNameIdMap[e.IndexElem.Name], Desc: desc})
        }
    }
    return
}

// toForeignKeys converts a PostgreSQL foreign key constraint to a schema
// foreign key.
func toForeignKeys(fk constraint) (fkey schema.ForeignKey) {
    fkey = schema.ForeignKey{
        Id:               internal.GenerateForeignkeyId(),
        Name:             fk.name,
        ColumnNames:      fk.cols,
        ReferTableName:   fk.referTable,
        ReferColumnNames: fk.referCols,
        OnDelete:         fk.onDelete,
        OnUpdate:         fk.onUpdate,
    }
    return fkey
}

// getCols extracts and returns the column names for an InsertStatement.
func getCols(conv *internal.Conv, table string, nodes []*pg_query.Node) (cols []string, err error) {
    for _, n := range nodes {
        switch r := n.GetNode().(type) {
        case *pg_query.Node_ResTarget:
            if r.ResTarget.Name != "" {
                cols = append(cols, r.ResTarget.Name)
            }
        default:
            return nil, fmt.Errorf("expecting ResTarget node but got %v node while processing Cols", printNodeType(r))
        }
    }
    return cols, nil
}

// getRows extracts and returns the rows for an InsertStatement.
func getRows(conv *internal.Conv, vll []*pg_query.Node, n *pg_query.InsertStmt) (rows [][]string) {
    for _, vl := range vll {
        var values []string
        switch vals := vl.GetNode().(type) {
        case *pg_query.Node_List:
            for _, v := range vals.List.Items {
                switch val := v.GetNode().(type) {
                case *pg_query.Node_AConst:
                    if val.AConst.Isnull {
                        values = append(values, "NULL")
                    } else {
                        switch c := val.AConst.Val.(type) {
                        // Most data is dumped enclosed in quotes ('') like 'abc', '12:30:45' etc,
                        // which is classified as type Node_String_ by the parser. Some data might
                        // not be quoted, like (NULL, 14.67), and the types assigned to them are
                        // Node_Null and Node_Float respectively.
                        case *pg_query.A_Const_Sval:
                            values = append(values, trimString(c.Sval))
                        case *pg_query.A_Const_Ival:
                            // For uniformity, convert to string and handle everything in
                            // dataConversion(). If performance of insert statements becomes a
                            // high priority (it isn't right now), then consider preserving int64
                            // here to avoid the int64 -> string -> int64 conversions.
                            values = append(values, strconv.FormatInt(int64(c.Ival.Ival), 10))
                        case *pg_query.A_Const_Fval:
                            values = append(values, c.Fval.Fval)
                        // TODO: There might be other Node types like Node_IntList, Node_List,
                        // Node_BitString etc. that need to be checked if they are handled or not.
                        default:
                            conv.Unexpected(fmt.Sprintf("Processing %v statement: found %s node for A_Const Val", printNodeType(n), printNodeType(c)))
                        }
                    }
                default:
                    conv.Unexpected(fmt.Sprintf("Processing %v statement: found %s node for ValuesList.Val", printNodeType(n), printNodeType(val)))
                }
            }
        default:
            conv.Unexpected(fmt.Sprintf("Processing %v statement: found %s in ValuesList", printNodeType(n), printNodeType(vals)))
        }
        // If some or all of vals failed to parse, then the size of values will be
        // less than the number of columns, and this will be caught as a
        // BadRow in ProcessDataRow.
        rows = append(rows, values)
    }
    return rows
}
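
// For example (illustrative):
//
//	INSERT INTO t (a, b, c) VALUES ('abc', 14.67, NULL), (7, 'x', NULL);
//
// yields rows [["abc", "14.67", "NULL"], ["7", "x", "NULL"]]; all values are
// carried as strings and converted later during data conversion.
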
func logStmtError(conv *internal.Conv, node interface{}, err error) {
    conv.Unexpected(fmt.Sprintf("Processing %v statement: %s", printNodeType(node), err))
    conv.ErrorInStatement(printNodeType(node))
}

func getString(node *pg_query.Node) (string, error) {
    switch n := node.GetNode().(type) {
    case *pg_query.Node_String_:
        return trimString(n.String_), nil
    default:
        return "", fmt.Errorf("node %v is not a String node", printNodeType(n))
    }
}

// checkEmpty verifies that pkeys is empty and generates a warning if it isn't.
// PostgreSQL explicitly forbids multiple primary keys.
func checkEmpty(conv *internal.Conv, pkeys []schema.Key, stmtType string) {
    if len(pkeys) != 0 {
        conv.Unexpected(fmt.Sprintf("%s statement is adding a second primary key", stmtType))
    }
}

// printNodeType returns a string representation of the type of node. It trims
// the "pg_query." and "Node_" prefixes from pg_query.Node_* types.
func printNodeType(node interface{}) string {
    return strings.TrimPrefix(strings.TrimPrefix(reflect.TypeOf(node).String(), "*pg_query."), "Node_")
}

func trimString(s *pg_query.String) string {
    str := s.Sval
    str = trimEscapeChars(str)
    return trimQuote(str)
}

func trimEscapeChars(s string) string {
    return strings.ReplaceAll(s, "\\n", "\n")
}

func trimQuote(s string) string {
    if len(s) > 0 && s[0] == '"' {
        s = s[1:]
    }
    if len(s) > 0 && s[len(s)-1] == '"' {
        s = s[:len(s)-1]
    }
    return s
}
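
// For example (illustrative): printNodeType(&pg_query.Node_CreateStmt{})
// returns "CreateStmt", and trimQuote(`"MyTable"`) returns `MyTable`.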