import_data/csv_schema.go (153 lines of code) (raw):

package import_data import ( "context" csv2 "encoding/csv" "fmt" "io" "os" "sort" "strconv" "strings" spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/parse" "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" "github.com/google/subcommands" adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" ) type CsvSchema interface { CreateSchema(ctx context.Context, dialect string, sp *spanneraccessor.SpannerAccessorImpl) subcommands.ExitStatus } type CsvSchemaImpl struct { ProjectId string InstanceId string DbName string TableName string SchemaUri string CsvFieldDelimiter string } // ColumnDefinition represents the definition of a Spanner table column. type ColumnDefinition struct { Name string Type string // e.g., "INT64", "STRING(MAX)", "TIMESTAMP", "DATE" NotNull bool PkOrder int // defines the order in the PK for the table, 0 means absence. } type PrimaryKey struct { Name string PkOrder int // defines the order in the PK for the table, 0 means absence. } func (source *CsvSchemaImpl) CreateSchema(ctx context.Context, dialect string, sp *spanneraccessor.SpannerAccessorImpl) error { dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", source.ProjectId, source.InstanceId, source.DbName) colDef, err := parseSchema(source.SchemaUri, rune(source.CsvFieldDelimiter[0])) if err != nil { logger.Log.Error(fmt.Sprintf("Unable to parse schema URI %v", err)) return err } dbExists, err := sp.TableExists(ctx, source.TableName) if err != nil { logger.Log.Error(fmt.Sprintf("Unable to check existing schema %v", err)) return err } if dbExists { logger.Log.Error(fmt.Sprintf("table %s exists ", source.TableName)) // if exists, verify table schema is same as passed // TODO: validate schema matches return nil } ddl := getCreateTableStmt(source.TableName, colDef, dialect) stmts := []string{ddl} req := &adminpb.UpdateDatabaseDdlRequest{ Database: dbURI, Statements: stmts, } op, err := sp.AdminClient.UpdateDatabaseDdl(ctx, req) if err != nil { return fmt.Errorf("can't build UpdateDatabaseDdlRequest: %w", parse.AnalyzeError(err, dbURI)) } if err := op.Wait(ctx); err != nil { return fmt.Errorf("UpdateDatabaseDdl call failed: %w", parse.AnalyzeError(err, dbURI)) } logger.Log.Info(fmt.Sprintf("Created table %v successfully\n", source.TableName)) return nil } func parseSchema(schemaUri string, delimiter rune) ([]ColumnDefinition, error) { schemaFile, err := os.Open(schemaUri) if err != nil { return nil, fmt.Errorf("error opening file: %v", err) } defer schemaFile.Close() reader := csv2.NewReader(schemaFile) reader.Comma = delimiter reader.TrimLeadingSpace = true var colDefs []ColumnDefinition for { record, err := reader.Read() if err == io.EOF { break } if err != nil { return nil, fmt.Errorf("error reading CSV record: %v", err) } if len(record) != 4 { return nil, fmt.Errorf("expected 4 columns, but got %d", len(record)) } pkOrder, err := strconv.Atoi(strings.TrimSpace(record[3])) if err != nil { fmt.Println("Error parsing schema file", err) return colDefs, err } colDef := ColumnDefinition{record[0], record[1], StringToBool(record[2]), pkOrder} colDefs = append(colDefs, colDef) } return colDefs, nil } func getCreateTableStmt(tableName string, colDef []ColumnDefinition, dialect string) string { var col, pk string pks := []PrimaryKey{} for _, cd := range colDef { s := printColumnDef(cd) if len(col) > 0 { s = "," + s } col = col + s if cd.PkOrder != 0 { pks = append(pks, PrimaryKey{cd.Name, cd.PkOrder}) } } sort.Slice(pks, func(i, j int) bool { return pks[i].PkOrder < pks[j].PkOrder }) for _, p := range pks { s := quote(p.Name) if len(pk) > 0 { s = "," + s } pk = pk + s } var stmt string if dialect == constants.DIALECT_POSTGRESQL { stmt = fmt.Sprintf("CREATE TABLE %s (\n%s PRIMARY KEY (%s)\n)", quote(tableName), col, pk) } stmt = fmt.Sprintf("CREATE TABLE %s (\n%s) PRIMARY KEY (%s)", quote(tableName), col, pk) logger.Log.Debug(fmt.Sprintf("create table cmd %s ==", stmt)) return stmt } func printColumnDef(c ColumnDefinition) string { s := fmt.Sprintf("%s %s", quote(c.Name), c.Type) if c.NotNull { s += " NOT NULL " } return s } func quote(s string) string { return "`" + s + "`" } func StringToBool(s string) bool { s = strings.ToLower(strings.TrimSpace(s)) // Normalize the string if s == "" { return false } boolVal, err := strconv.ParseBool(s) if err == nil { return boolVal } return false }