// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package csv
import (
csvReader "encoding/csv"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"math/big"
"os"
"strconv"
"strings"
"time"
"cloud.google.com/go/civil"
"cloud.google.com/go/spanner"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils"
"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
"github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
"github.com/GoogleCloudPlatform/spanner-migration-tool/profiles"
"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl"
)
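// CsvInterface defines the CSV-specific operations used by the migration
// flow: locating CSV files, computing per-table row statistics, and writing
// CSV data into Spanner.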
type CsvInterface interface {
GetCSVFiles(conv *internal.Conv, sourceProfile profiles.SourceProfile) (tables []utils.ManifestTable, err error)
SetRowStats(conv *internal.Conv, tables []utils.ManifestTable, delimiter rune) error
ProcessCSV(conv *internal.Conv, tables []utils.ManifestTable, nullStr string, delimiter rune) error
ProcessSingleCSV(conv *internal.Conv, tableName string, columnNames []string, colDefs map[string]ddl.ColumnDef, filePath string, nullStr string, delimiter rune) error
}
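// CsvImpl implements CsvInterface.
//
// A minimal usage sketch (hypothetical driver code; conv and sourceProfile
// are assumed to be prepared elsewhere in the tool):
//
//	var c CsvImpl
//	tables, err := c.GetCSVFiles(conv, sourceProfile)
//	if err != nil {
//		return err
//	}
//	if err := c.SetRowStats(conv, tables, ','); err != nil {
//		return err
//	}
//	if err := c.ProcessCSV(conv, tables, "", ','); err != nil {
//		return err
//	}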
type CsvImpl struct{}
// GetCSVFiles finds the appropriate file paths and downloads GCS files, if any.
func (c *CsvImpl) GetCSVFiles(conv *internal.Conv, sourceProfile profiles.SourceProfile) (tables []utils.ManifestTable, err error) {
// If a manifest file is not provided, we assume the CSVs exist in the current
// working directory as [table_name].csv.
if sourceProfile.Csv.Manifest == "" {
fmt.Println("Manifest file not provided, checking for files named `[table_name].csv` in current working directory...")
for _, schema := range conv.SpSchema {
tables = append(tables, utils.ManifestTable{Table_name: schema.Name, File_patterns: []string{fmt.Sprintf("%s.csv", schema.Name)}})
}
} else {
fmt.Println("Manifest file provided, reading csv file paths...")
// Read paths provided in manifest.
tables, err = loadManifest(conv, sourceProfile.Csv.Manifest)
if err != nil {
return nil, err
}
}
// Download gcs files if any.
tables, err = utils.PreloadGCSFiles(tables)
if err != nil {
return nil, fmt.Errorf("gcs file download error: %v", err)
}
return tables, nil
}
// loadManifest reads the manifest file and unmarshals it into a list of ManifestTable structs.
// It also performs basic checks on the manifest.
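//
// For illustration, a manifest is a JSON array of per-table entries along
// these lines (the exact JSON key names are assumed here and depend on the
// tags of utils.ManifestTable):
//
//	[
//		{"table_name": "Singers", "file_patterns": ["Singers.csv"]},
//		{"table_name": "Albums", "file_patterns": ["Albums_1.csv", "gs://bucket/Albums_2.csv"]}
//	]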
func loadManifest(conv *internal.Conv, manifestFile string) ([]utils.ManifestTable, error) {
manifest, err := os.ReadFile(manifestFile)
if err != nil {
return nil, fmt.Errorf("can't read manifest file due to: %v", err)
}
tables := []utils.ManifestTable{}
err = json.Unmarshal(manifest, &tables)
if err != nil {
return nil, fmt.Errorf("unable to unmarshall json due to: %v", err)
}
err = VerifyManifest(conv, tables)
if err != nil {
return nil, fmt.Errorf("manifest is incomplete: %v", err)
}
return tables, nil
}
// VerifyManifest performs prechecks on the structure of the manifest: every table entry must have a
// name that maps to a known table and at least one file pattern. Checks for valid file paths and
// empty CSVs are deferred to processing, where failures are reported via conv.Unexpected.
func VerifyManifest(conv *internal.Conv, tables []utils.ManifestTable) error {
if len(tables) == 0 {
return fmt.Errorf("no tables found")
}
missing := []string{}
for _, v := range conv.SrcSchema {
found := false
for _, table := range tables {
if v.Name == table.Table_name {
found = true
break
}
}
if !found {
missing = append(missing, v.Name)
}
}
if len(missing) > 0 {
fmt.Printf("WARNING: did not find manifest entries for tables [ %s ], ignoring and proceeding...\n", strings.Join(missing, ", "))
conv.Unexpected(fmt.Sprintf("did not find manifest entries for tables [ %s ]", strings.Join(missing, ", ")))
}
for i, table := range tables {
name := table.Table_name
if name == "" {
return fmt.Errorf("table number %d (0-indexed) does not have a name", i)
}
_, err := internal.GetTableIdFromSrcName(conv.SrcSchema, name)
if err != nil {
return fmt.Errorf("table %s provided in manifest does not exist in spanner", name)
}
if len(table.File_patterns) == 0 {
return fmt.Errorf("no file path provided for table %s", name)
}
}
return nil
}
// SetRowStats calculates the number of rows per table.
func (c *CsvImpl) SetRowStats(conv *internal.Conv, tables []utils.ManifestTable, delimiter rune) error {
for _, table := range tables {
for _, filePath := range table.File_patterns {
csvFile, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("can't read csv file: %s due to: %v", filePath, err)
}
r := csvReader.NewReader(csvFile)
r.Comma = delimiter
tableId, err := internal.GetTableIdFromSpName(conv.SpSchema, table.Table_name)
if err != nil {
csvFile.Close()
return fmt.Errorf("table Id not found for spanner table %v", table.Table_name)
}
colNames := []string{}
for _, colId := range conv.SpSchema[tableId].ColIds {
colNames = append(colNames, conv.SpSchema[tableId].ColDefs[colId].Name)
}
count, err := getCSVDataRowCount(r, colNames)
// Close the file as soon as the row count has been read.
csvFile.Close()
if err != nil {
return fmt.Errorf("error reading file %s for table %s: %v", filePath, table.Table_name, err)
}
if count == 0 {
conv.Unexpected(fmt.Sprintf("error processing table %s: file %s is empty.", table.Table_name, filePath))
continue
}
conv.Stats.Rows[table.Table_name] += count
}
}
return nil
}
// getCSVDataRowCount returns the number of data rows in the CSV file. This excludes the headers if present.
func getCSVDataRowCount(r *csvReader.Reader, colNames []string) (int64, error) {
count := int64(0)
srcCols, err := r.Read()
if err == io.EOF {
return count, nil
}
if err != nil {
return count, fmt.Errorf("can't read csv headers for col names due to: %v", err)
}
if len(srcCols) != len(colNames) {
return 0, fmt.Errorf("found %d columns in csv, expected %d as per Spanner schema", len(srcCols), len(colNames))
}
// If the row read was not a header, increase count.
if !utils.CheckEqualSets(srcCols, colNames) {
count += 1
}
for {
_, err := r.Read()
if err == io.EOF {
break
}
if err != nil {
return 0, fmt.Errorf("can't read row")
}
count++
}
return count, nil
}
// ProcessCSV writes data for all tables provided in the manifest file. Each table's data can be
// spread across multiple CSV files; hence, the manifest accepts a list of file paths per table.
func (c *CsvImpl) ProcessCSV(conv *internal.Conv, tables []utils.ManifestTable, nullStr string, delimiter rune) error {
tableIds := ddl.GetSortedTableIdsBySpName(conv.SpSchema)
nameToFiles := map[string][]string{}
for _, table := range tables {
nameToFiles[table.Table_name] = table.File_patterns
}
orderedTables := []utils.ManifestTable{}
for _, id := range tableIds {
orderedTables = append(orderedTables, utils.ManifestTable{Table_name: conv.SpSchema[id].Name, File_patterns: nameToFiles[conv.SpSchema[id].Name]})
}
for _, table := range orderedTables {
for _, filePath := range table.File_patterns {
// Default column order is the same as in the Spanner schema.
tableId, err := internal.GetTableIdFromSpName(conv.SpSchema, table.Table_name)
if err != nil {
return fmt.Errorf("table Id not found for spanner table %v", table.Table_name)
}
colNames := []string{}
for _, v := range conv.SpSchema[tableId].ColIds {
colNames = append(colNames, conv.SpSchema[tableId].ColDefs[v].Name)
}
colDefs := conv.SpSchema[tableId].ColDefs
err = c.ProcessSingleCSV(conv, table.Table_name, colNames, colDefs,
filePath, nullStr, delimiter)
if err != nil {
return err
}
}
if conv.DataFlush != nil {
conv.DataFlush()
}
}
return nil
}
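// ProcessSingleCSV writes the data from a single CSV file into the given
// Spanner table. If the first row is a permutation of the Spanner column
// names, it is treated as a header and defines the column order for the
// remaining rows; otherwise it is written as a data row.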
func (c *CsvImpl) ProcessSingleCSV(conv *internal.Conv, tableName string,
columnNames []string, colDefs map[string]ddl.ColumnDef, filePath string,
nullStr string, delimiter rune) error {
csvFile, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("can't read csv file: %s due to: %v", filePath, err)
}
defer csvFile.Close()
r := csvReader.NewReader(csvFile)
r.Comma = delimiter
srcCols, err := r.Read()
if err == io.EOF {
logger.Log.Error(fmt.Sprintf("error processing table %s: file %s is empty.", tableName, filePath))
return err
}
if err != nil {
return fmt.Errorf("can't read row for %s due to: %v", filePath, err)
}
// If the first row is some permutation of the Spanner schema columns, we assume it is a header row
// and use its column order for the data rows that follow.
if utils.CheckEqualSets(srcCols, columnNames) {
columnNames = srcCols
} else {
// Write the first row since it was not a column header.
processDataRow(conv, nullStr, tableName, columnNames, colDefs, srcCols)
}
for {
values, err := r.Read()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("can't read row for %s due to: %v", filePath, err)
}
processDataRow(conv, nullStr, tableName, columnNames, colDefs, values)
}
return nil
}
// processDataRow converts a row into Go data types compatible with the
// Spanner client library and writes it to Spanner via conv.WriteRow.
func processDataRow(conv *internal.Conv, nullStr, tableName string,
srcCols []string, colDefs map[string]ddl.ColumnDef, values []string) {
// Pass nullStr from source-profile.
cvtCols, cvtVals, err := convertData(conv.SpDialect, nullStr, srcCols, colDefs, values)
if err != nil {
logger.Log.Error(fmt.Sprintf("Error while converting data: %s\n", err))
} else {
conv.WriteRow(tableName, tableName, cvtCols, cvtVals)
}
}
// convertData maps the string values of a CSV row into Go types accepted by
// the Spanner client, based on the Spanner column definitions. Values equal
// to nullStr are skipped, i.e. the column is omitted from the converted row.
func convertData(dialect, nullStr string, srcCols []string,
colDefs map[string]ddl.ColumnDef, values []string) (
[]string, []interface{}, error) {
var v []interface{}
var cvtCols []string
for i, val := range values {
if val == nullStr {
continue
}
colName := srcCols[i]
colId, err := internal.GetColIdFromSpName(colDefs, colName)
if err != nil {
return cvtCols, v, fmt.Errorf("Unable to get colId from SpName for column %s ", colName)
}
spColDef := colDefs[colId]
var x interface{}
if spColDef.T.IsArray {
x, err = convArray(spColDef.T, val)
} else {
x, err = convScalar(dialect, spColDef.T, val)
}
if err != nil {
return nil, nil, err
}
v = append(v, x)
cvtCols = append(cvtCols, colName)
}
return cvtCols, v, nil
}
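// convArray converts the string representation of an array into a typed
// slice accepted by the Spanner client. The value must be of the form
// {v1,v2,...} or [v1,v2,...]; elements are split on bare commas (no
// whitespace is trimmed), elements equal to NULL become NULL array entries,
// and double-quoted elements are unquoted first.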
func convArray(spannerType ddl.Type, val string) (interface{}, error) {
val = strings.TrimSpace(val)
// Handle empty array. Note that we use an empty NullString array
// for all Spanner array types since this will be converted to the
// appropriate type by the Spanner client.
if val == "{}" || val == "[]" {
return []spanner.NullString{}, nil
}
if len(val) < 2 {
return []interface{}{}, fmt.Errorf("unrecognized data format for array: expected {v1, v2, ...} or [v1, v2, ...]")
}
braces := val[:1] + val[len(val)-1:]
if braces != "{}" && braces != "[]" {
return []interface{}{}, fmt.Errorf("unrecognized data format for array: expected {v1, v2, ...} or [v1, v2, ...]")
}
a := strings.Split(val[1:len(val)-1], ",")
// The Spanner client for go does not accept []interface{} for arrays.
// Instead it only accepts slices of a specific type e.g. []int64, []string.
// Hence we have to do the following case analysis.
switch spannerType.Name {
case ddl.Bool:
var r []spanner.NullBool
for _, s := range a {
if s == "NULL" {
r = append(r, spanner.NullBool{Valid: false})
continue
}
s, err := processQuote(s)
if err != nil {
return []spanner.NullBool{}, err
}
b, err := convBool(s)
if err != nil {
return []spanner.NullBool{}, err
}
r = append(r, spanner.NullBool{Bool: b, Valid: true})
}
return r, nil
case ddl.Bytes:
var r [][]byte
for _, s := range a {
if s == "NULL" {
r = append(r, nil)
continue
}
s, err := processQuote(s)
if err != nil {
return [][]byte{}, err
}
b, err := convBytes(s)
if err != nil {
return [][]byte{}, err
}
r = append(r, b)
}
return r, nil
case ddl.Date:
var r []spanner.NullDate
for _, s := range a {
if s == "NULL" {
r = append(r, spanner.NullDate{Valid: false})
continue
}
s, err := processQuote(s)
if err != nil {
return []spanner.NullDate{}, err
}
date, err := convDate(s)
if err != nil {
return []spanner.NullDate{}, err
}
r = append(r, spanner.NullDate{Date: date, Valid: true})
}
return r, nil
case ddl.Float32:
var r []spanner.NullFloat32
for _, s := range a {
if s == "NULL" {
r = append(r, spanner.NullFloat32{Valid: false})
continue
}
s, err := processQuote(s)
if err != nil {
return []spanner.NullFloat32{}, err
}
f, err := convFloat32(s)
if err != nil {
return []spanner.NullFloat32{}, err
}
r = append(r, spanner.NullFloat32{Float32: f, Valid: true})
}
return r, nil
case ddl.Float64:
var r []spanner.NullFloat64
for _, s := range a {
if s == "NULL" {
r = append(r, spanner.NullFloat64{Valid: false})
continue
}
s, err := processQuote(s)
if err != nil {
return []spanner.NullFloat64{}, err
}
f, err := convFloat64(s)
if err != nil {
return []spanner.NullFloat64{}, err
}
r = append(r, spanner.NullFloat64{Float64: f, Valid: true})
}
return r, nil
case ddl.Numeric:
var r []spanner.NullNumeric
for _, s := range a {
if s == "NULL" {
r = append(r, spanner.NullNumeric{Valid: false})
continue
}
s, err := processQuote(s)
if err != nil {
return []spanner.NullNumeric{}, err
}
n, err := convNumeric(s)
if err != nil {
return []spanner.NullNumeric{}, err
}
r = append(r, spanner.NullNumeric{Numeric: n, Valid: true})
}
return r, nil
case ddl.Int64:
var r []spanner.NullInt64
for _, s := range a {
if s == "NULL" {
r = append(r, spanner.NullInt64{Valid: false})
continue
}
s, err := processQuote(s)
if err != nil {
return []spanner.NullInt64{}, err
}
i, err := convInt64(s)
if err != nil {
return []spanner.NullInt64{}, err
}
r = append(r, spanner.NullInt64{Int64: i, Valid: true})
}
return r, nil
case ddl.String:
var r []spanner.NullString
for _, s := range a {
if s == "NULL" {
r = append(r, spanner.NullString{Valid: false})
continue
}
s, err := processQuote(s)
if err != nil {
return []spanner.NullString{}, err
}
r = append(r, spanner.NullString{StringVal: s, Valid: true})
}
return r, nil
case ddl.Timestamp:
var r []spanner.NullTime
for _, s := range a {
if s == "NULL" {
r = append(r, spanner.NullTime{Valid: false})
continue
}
s, err := processQuote(s)
if err != nil {
return []spanner.NullTime{}, err
}
t, err := convTimestamp(s)
if err != nil {
return []spanner.NullTime{}, err
}
r = append(r, spanner.NullTime{Time: t, Valid: true})
}
return r, nil
}
return []interface{}{}, fmt.Errorf("array type conversion not implemented for type []%v", spannerType.Name)
}
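// convScalar converts a single string value into the Go type the Spanner
// client expects for the given scalar Spanner type. For the PostgreSQL
// dialect, numeric values are passed through as spanner.PGNumeric rather
// than being parsed into a big.Rat.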
func convScalar(dialect string, spannerType ddl.Type, val string) (interface{}, error) {
switch spannerType.Name {
case ddl.Bool:
return convBool(val)
case ddl.Bytes:
return convBytes(val)
case ddl.Date:
return convDate(val)
case ddl.Float32:
return convFloat32(val)
case ddl.Float64:
return convFloat64(val)
case ddl.Int64:
return convInt64(val)
case ddl.Numeric:
if dialect == constants.DIALECT_POSTGRESQL {
return spanner.PGNumeric{Numeric: val, Valid: true}, nil
}
return convNumeric(val)
case ddl.String:
return val, nil
case ddl.Timestamp:
return convTimestamp(val)
case ddl.JSON:
return val, nil
default:
return val, fmt.Errorf("data conversion not implemented for type %v", spannerType)
}
}
func convBool(val string) (bool, error) {
b, err := strconv.ParseBool(val)
if err != nil {
return b, fmt.Errorf("can't convert to bool: %w", err)
}
return b, err
}
func convBytes(val string) ([]byte, error) {
// convert a string to a byte slice.
b := []byte(val)
return b, nil
}
func convDate(val string) (civil.Date, error) {
d, err := civil.ParseDate(val)
if err != nil {
return d, fmt.Errorf("can't convert to date: %w", err)
}
return d, err
}
func convFloat32(val string) (float32, error) {
f, err := strconv.ParseFloat(val, 32)
if err != nil {
return float32(f), fmt.Errorf("can't convert to float32: %w", err)
}
return float32(f), err
}
func convFloat64(val string) (float64, error) {
f, err := strconv.ParseFloat(val, 64)
if err != nil {
return f, fmt.Errorf("can't convert to float64: %w", err)
}
return f, err
}
func convInt64(val string) (int64, error) {
i, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return i, fmt.Errorf("can't convert to int64: %w", err)
}
return i, err
}
// convNumeric maps a source database string value (representing a numeric)
// into a big.Rat, which the Spanner client accepts for NUMERIC columns.
func convNumeric(val string) (big.Rat, error) {
r := new(big.Rat)
if _, ok := r.SetString(val); !ok {
return big.Rat{}, fmt.Errorf("can't convert %q to big.Rat", val)
}
return *r, nil
}
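// convTimestamp parses timestamps in the layout "2006-01-02 15:04:05"
// (date and time with no timezone offset, interpreted as UTC).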
func convTimestamp(val string) (t time.Time, err error) {
t, err = time.Parse("2006-01-02 15:04:05", val)
if err != nil {
return t, fmt.Errorf("can't convert to timestamp: %s", val)
}
return t, err
}
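// processQuote unquotes values wrapped in double quotes using
// strconv.Unquote; all other values are returned unchanged.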
func processQuote(s string) (string, error) {
if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' {
return strconv.Unquote(s)
}
return s, nil
}