func()

in backend/helpers/e2ehelper/data_flow_tester.go [298:401]


// CreateSnapshot dumps the rows of the table behind dst into the CSV file at opts.CSVRelPath.
//
// The selected columns are the table's primary key columns followed by the fields returned by
// resolveTargetFields (duplicates removed); rows are ordered by the primary key columns.
// Values are rendered as follows:
//   - time columns: UTC, formatted as "2006-01-02T15:04:05.000-07:00"
//   - bool columns: "1" / "0"
//   - NullInt64 columns: base-10 integer
//   - everything else: scanned as a string
//   - SQL NULL: "NULL" when opts.Nullable is set, otherwise an empty string
//
// Any failure (resolving primary keys, running the query, reading column metadata, creating the
// CSV writer, scanning a row) panics.
func (t *DataFlowTester) CreateSnapshot(dst schema.Tabler, opts TableOptions) {
	targetFields := t.resolveTargetFields(dst, opts)
	pkColumnNames, err := dal.GetPrimarykeyColumnNames(t.Dal, dst)
	if err != nil {
		panic(err)
	}
	// Strip any qualifier from the primary key names ("table.column" -> "column").
	for i, name := range pkColumnNames {
		if dot := strings.LastIndex(name, "."); dot >= 0 {
			pkColumnNames[i] = name[dot+1:]
		}
	}

	// Build the column list in a fresh slice so that appending never writes into the
	// backing array of pkColumnNames, which is reused below for the ORDER BY clause.
	allFields := make([]string, 0, len(pkColumnNames)+len(targetFields))
	allFields = append(allFields, pkColumnNames...)
	allFields = append(allFields, targetFields...)
	allFields = utils.StringsUniq(allFields)

	dbCursor, err := t.Dal.Cursor(
		dal.Select(strings.Join(allFields, `,`)),
		dal.From(dst.TableName()),
		dal.Orderby(strings.Join(pkColumnNames, `,`)),
	)
	if err != nil {
		panic(errors.Default.Wrap(err, fmt.Sprintf("unable to run select query on table %s", dst.TableName())))
	}
	// Release the cursor on every exit path, including panics raised below.
	defer dbCursor.Close()

	columns, err := errors.Convert01(dbCursor.Columns())
	if err != nil {
		panic(errors.Default.Wrap(err, fmt.Sprintf("unable to get columns from table %s", dst.TableName())))
	}
	columnTypes, err := errors.Convert01(dbCursor.ColumnTypes())
	if err != nil {
		panic(errors.Default.Wrap(err, fmt.Sprintf("unable to get column types from table %s", dst.TableName())))
	}

	csvWriter, csvErr := pluginhelper.NewCsvFileWriter(opts.CSVRelPath, columns)
	if csvErr != nil {
		panic(errors.Default.Wrap(csvErr, fmt.Sprintf("unable to create CSV file %s", opts.CSVRelPath)))
	}
	defer csvWriter.Close()

	// Pick a scan destination per result column. The slice is sized by the columns the
	// cursor actually reports, so it always matches what Scan expects.
	forScanValues := make([]interface{}, len(columnTypes))
	for i, columnType := range columnTypes {
		switch columnType.ScanType().Name() {
		case `Time`, `NullTime`:
			forScanValues[i] = new(sql.NullTime)
		case `bool`:
			// NullBool rather than bool so that a NULL value does not abort the scan.
			forScanValues[i] = new(sql.NullBool)
		case `NullInt64`:
			forScanValues[i] = new(sql.NullInt64)
		default:
			// RawBytes and every other type are read back as (nullable) text.
			forScanValues[i] = new(sql.NullString)
		}
	}

	// Text written to the CSV for SQL NULL values.
	nullText := ""
	if opts.Nullable {
		nullText = "NULL"
	}

	for dbCursor.Next() {
		err = errors.Convert(dbCursor.Scan(forScanValues...))
		if err != nil {
			panic(errors.Default.Wrap(err, fmt.Sprintf("unable to scan row on table %s: %v", dst.TableName(), err)))
		}
		values := make([]string, len(forScanValues))
		for i, dest := range forScanValues {
			switch v := dest.(type) {
			case *sql.NullTime:
				if v.Valid {
					values[i] = v.Time.In(time.UTC).Format("2006-01-02T15:04:05.000-07:00")
				} else {
					values[i] = nullText
				}
			case *sql.NullBool:
				switch {
				case !v.Valid:
					values[i] = nullText
				case v.Bool:
					values[i] = `1`
				default:
					values[i] = `0`
				}
			case *sql.NullString:
				if v.Valid {
					values[i] = v.String
				} else {
					values[i] = nullText
				}
			case *sql.NullInt64:
				if v.Valid {
					values[i] = strconv.FormatInt(v.Int64, 10)
				} else {
					values[i] = nullText
				}
			}
		}
		csvWriter.Write(values)
	}
	fmt.Printf("created CSV file: %s\n", opts.CSVRelPath)
}