go/adbc/driver/snowflake/record_reader.go

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package snowflake

import (
	"bytes"
	"context"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"math"
	"strconv"
	"strings"
	"sync/atomic"
	"time"

	"github.com/apache/arrow-adbc/go/adbc"
	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/compute"
	"github.com/apache/arrow-go/v18/arrow/ipc"
	"github.com/apache/arrow-go/v18/arrow/memory"
	"github.com/snowflakedb/gosnowflake"
	"golang.org/x/sync/errgroup"
)

const MetadataKeySnowflakeType = "SNOWFLAKE_TYPE"

func identCol(_ context.Context, a arrow.Array) (arrow.Array, error) {
	a.Retain()
	return a, nil
}

type recordTransformer = func(context.Context, arrow.Record) (arrow.Record, error)
type colTransformer = func(context.Context, arrow.Array) (arrow.Array, error)

func getRecTransformer(sc *arrow.Schema, tr []colTransformer) recordTransformer {
	return func(ctx context.Context, r arrow.Record) (arrow.Record, error) {
		if len(tr) != int(r.NumCols()) {
			return nil, adbc.Error{
				Msg:  "mismatch in record cols and transformers",
				Code: adbc.StatusInvalidState,
			}
		}

		var (
			err  error
			cols = make([]arrow.Array, r.NumCols())
		)
		for i, col := range r.Columns() {
			if cols[i], err = tr[i](ctx, col); err != nil {
				return nil, errToAdbcErr(adbc.StatusInternal, err)
			}
			defer cols[i].Release()
		}

		return array.NewRecord(sc, cols, r.NumRows()), nil
	}
}

func getTransformer(sc *arrow.Schema, ld gosnowflake.ArrowStreamLoader, useHighPrecision bool) (*arrow.Schema, recordTransformer) {
	loc, types := ld.Location(), ld.RowTypes()

	fields := make([]arrow.Field, len(sc.Fields()))
	transformers := make([]func(context.Context, arrow.Array) (arrow.Array, error), len(sc.Fields()))
	for i, f := range sc.Fields() {
		srcMeta := types[i]

		switch strings.ToUpper(srcMeta.Type) {
		case "FIXED":
			switch f.Type.ID() {
			case arrow.DECIMAL, arrow.DECIMAL256:
				if useHighPrecision {
					transformers[i] = identCol
				} else {
					if srcMeta.Scale == 0 {
						f.Type = arrow.PrimitiveTypes.Int64
					} else {
						f.Type = arrow.PrimitiveTypes.Float64
					}
					dt := f.Type
					transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
						return compute.CastArray(ctx, a, compute.UnsafeCastOptions(dt))
					}
				}
			default:
				if useHighPrecision {
					dt := &arrow.Decimal128Type{
						Precision: int32(srcMeta.Precision),
						Scale:     int32(srcMeta.Scale),
					}
					f.Type = dt
					transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
						return integerToDecimal128(ctx, a, dt)
					}
				} else {
					if srcMeta.Scale != 0 {
						f.Type = arrow.PrimitiveTypes.Float64
						// For precisions of 16, 17 and 18, a conversion from int64 to float64 fails with an error.
						// So for these precisions, we instead convert first to a decimal128 and then to a float64.
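						// (float64 carries a 53-bit significand, so only integers of up to
						// 15 full decimal digits are guaranteed exactly representable;
						// the checked conversion rejects wider values as lossy.)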
						if srcMeta.Precision > 15 && srcMeta.Precision < 19 {
							transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
								result, err := integerToDecimal128(ctx, a, &arrow.Decimal128Type{
									Precision: int32(srcMeta.Precision),
									Scale:     int32(srcMeta.Scale),
								})
								if err != nil {
									return nil, err
								}
								defer result.Release()
								return compute.CastArray(ctx, result, compute.UnsafeCastOptions(f.Type))
							}
						} else {
							// For precisions less than 16, we can simply scale the integer value appropriately
							transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
								result, err := compute.Divide(ctx, compute.ArithmeticOptions{NoCheckOverflow: true},
									&compute.ArrayDatum{Value: a.Data()},
									compute.NewDatum(math.Pow10(int(srcMeta.Scale))))
								if err != nil {
									return nil, err
								}
								defer result.Release()
								return result.(*compute.ArrayDatum).MakeArray(), nil
							}
						}
					} else {
						f.Type = arrow.PrimitiveTypes.Int64
						transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
							return compute.CastArray(ctx, a, compute.SafeCastOptions(arrow.PrimitiveTypes.Int64))
						}
					}
				}
			}
		case "TIME":
			var dt arrow.DataType
			if srcMeta.Scale < 6 {
				dt = &arrow.Time32Type{Unit: arrow.TimeUnit(srcMeta.Scale / 3)}
			} else {
				dt = &arrow.Time64Type{Unit: arrow.TimeUnit(srcMeta.Scale / 3)}
			}

			f.Type = dt
			transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
				return compute.CastArray(ctx, a, compute.SafeCastOptions(dt))
			}
		case "TIMESTAMP_NTZ":
			dt := &arrow.TimestampType{Unit: arrow.TimeUnit(srcMeta.Scale / 3)}
			f.Type = dt
			transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
				if a.DataType().ID() != arrow.STRUCT {
					return compute.CastArray(ctx, a, compute.SafeCastOptions(dt))
				}

				pool := compute.GetAllocator(ctx)
				tb := array.NewTimestampBuilder(pool, dt)
				defer tb.Release()

				structData := a.(*array.Struct)
				epoch := structData.Field(0).(*array.Int64).Int64Values()
				fraction := structData.Field(1).(*array.Int32).Int32Values()
				for i := 0; i < a.Len(); i++ {
					if a.IsNull(i) {
						tb.AppendNull()
						continue
					}

					v, err := arrow.TimestampFromTime(time.Unix(epoch[i], int64(fraction[i])), dt.TimeUnit())
					if err != nil {
						return nil, err
					}
					tb.Append(v)
				}
				return tb.NewArray(), nil
			}
		case "TIMESTAMP_LTZ":
			dt := &arrow.TimestampType{Unit: arrow.TimeUnit(srcMeta.Scale / 3), TimeZone: loc.String()}
			f.Type = dt
			transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
				pool := compute.GetAllocator(ctx)
				tb := array.NewTimestampBuilder(pool, dt)
				defer tb.Release()

				if a.DataType().ID() == arrow.STRUCT {
					structData := a.(*array.Struct)
					epoch := structData.Field(0).(*array.Int64).Int64Values()
					fraction := structData.Field(1).(*array.Int32).Int32Values()
					for i := 0; i < a.Len(); i++ {
						if a.IsNull(i) {
							tb.AppendNull()
							continue
						}

						v, err := arrow.TimestampFromTime(time.Unix(epoch[i], int64(fraction[i])), dt.TimeUnit())
						if err != nil {
							return nil, err
						}
						tb.Append(v)
					}
				} else {
					for i, t := range a.(*array.Int64).Int64Values() {
						if a.IsNull(i) {
							tb.AppendNull()
							continue
						}
						tb.Append(arrow.Timestamp(t))
					}
				}
				return tb.NewArray(), nil
			}
		case "TIMESTAMP_TZ":
			// we convert each value to UTC since we have timezone information
			// with the data that lets us do so.
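			// (On the wire this is a struct of epoch seconds, an optional
			// fractional-nanoseconds field, and a timezone offset encoded as
			// minutes from UTC biased by 1440, as unpacked below.)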
			dt := &arrow.TimestampType{TimeZone: "UTC", Unit: arrow.TimeUnit(srcMeta.Scale / 3)}
			f.Type = dt
			transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
				pool := compute.GetAllocator(ctx)
				tb := array.NewTimestampBuilder(pool, dt)
				defer tb.Release()

				structData := a.(*array.Struct)
				if structData.NumField() == 2 {
					epoch := structData.Field(0).(*array.Int64).Int64Values()
					tzoffset := structData.Field(1).(*array.Int32).Int32Values()
					for i := 0; i < a.Len(); i++ {
						if a.IsNull(i) {
							tb.AppendNull()
							continue
						}

						loc := gosnowflake.Location(int(tzoffset[i]) - 1440)
						v, err := arrow.TimestampFromTime(time.Unix(epoch[i], 0).In(loc), dt.Unit)
						if err != nil {
							return nil, err
						}
						tb.Append(v)
					}
				} else {
					epoch := structData.Field(0).(*array.Int64).Int64Values()
					fraction := structData.Field(1).(*array.Int32).Int32Values()
					tzoffset := structData.Field(2).(*array.Int32).Int32Values()
					for i := 0; i < a.Len(); i++ {
						if a.IsNull(i) {
							tb.AppendNull()
							continue
						}

						loc := gosnowflake.Location(int(tzoffset[i]) - 1440)
						v, err := arrow.TimestampFromTime(time.Unix(epoch[i], int64(fraction[i])).In(loc), dt.Unit)
						if err != nil {
							return nil, err
						}
						tb.Append(v)
					}
				}
				return tb.NewArray(), nil
			}
		default:
			transformers[i] = identCol
		}

		fields[i] = f
	}

	meta := sc.Metadata()
	out := arrow.NewSchema(fields, &meta)
	return out, getRecTransformer(out, transformers)
}

func integerToDecimal128(ctx context.Context, a arrow.Array, dt *arrow.Decimal128Type) (arrow.Array, error) {
	// We can't do a cast directly into the destination type because the numbers we get from
	// Snowflake are scaled integers. So not only would the cast produce the wrong value, it
	// also risks producing an error for precisions that e.g. can't hold every int64. To work
	// around these problems, we instead cast into a decimal type whose precision and scale we
	// know will hold all values and won't require scaling. We then substitute the type on this
	// array with the actual return type.
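	// Illustrative example: a NUMBER(18, 2) value 1234.56 arrives as the int64
	// 123456. Casting it to decimal128(20, 0) keeps the digits 123456 intact,
	// and relabeling that array as decimal128(18, 2) makes readers interpret
	// them as 1234.56 without rewriting any buffers.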
	dt0 := &arrow.Decimal128Type{
		Precision: int32(20),
		Scale:     int32(0),
	}
	result, err := compute.CastArray(ctx, a, compute.SafeCastOptions(dt0))
	if err != nil {
		return nil, err
	}

	data := result.Data()
	result.Data().Reset(dt, data.Len(), data.Buffers(), data.Children(), data.NullN(), data.Offset())
	return result, err
}

func rowTypesToArrowSchema(_ context.Context, ld gosnowflake.ArrowStreamLoader, useHighPrecision bool) (*arrow.Schema, error) {
	var loc *time.Location

	metadata := ld.RowTypes()
	fields := make([]arrow.Field, len(metadata))
	for i, srcMeta := range metadata {
		fields[i] = arrow.Field{
			Name:     srcMeta.Name,
			Nullable: srcMeta.Nullable,
			Metadata: arrow.MetadataFrom(map[string]string{
				MetadataKeySnowflakeType: srcMeta.Type,
			}),
		}

		switch srcMeta.Type {
		case "fixed":
			if useHighPrecision {
				fields[i].Type = &arrow.Decimal128Type{
					Precision: int32(srcMeta.Precision),
					Scale:     int32(srcMeta.Scale),
				}
			} else {
				fields[i].Type = arrow.PrimitiveTypes.Int64
			}
		case "real":
			fields[i].Type = arrow.PrimitiveTypes.Float64
		case "date":
			fields[i].Type = arrow.PrimitiveTypes.Date32
		case "time":
			fields[i].Type = arrow.FixedWidthTypes.Time64ns
		case "timestamp_ntz", "timestamp_tz":
			fields[i].Type = arrow.FixedWidthTypes.Timestamp_ns
		case "timestamp_ltz":
			if loc == nil {
				loc = ld.Location()
			}
			fields[i].Type = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: loc.String()}
		case "binary":
			fields[i].Type = arrow.BinaryTypes.Binary
		default:
			fields[i].Type = arrow.BinaryTypes.String
		}
	}

	return arrow.NewSchema(fields, nil), nil
}

// extractTimestamp splits a decimal timestamp string such as "12345.678" into
// whole seconds and nanoseconds (the fraction is right-padded to nine digits).
func extractTimestamp(src *string) (sec, nsec int64, err error) {
	s, ms, hasFraction := strings.Cut(*src, ".")
	sec, err = strconv.ParseInt(s, 10, 64)
	if err != nil {
		return
	}
	if !hasFraction {
		return
	}

	nsec, err = strconv.ParseInt(ms+strings.Repeat("0", 9-len(ms)), 10, 64)
	return
}
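// jsonDataToArrow appends rows of string-encoded column values (as delivered
// for Snowflake metadata queries) to the given record builder and returns the
// resulting record batch.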
func jsonDataToArrow(_ context.Context, bldr *array.RecordBuilder, rawData [][]*string) (arrow.Record, error) {
	fieldBuilders := bldr.Fields()
	for _, rec := range rawData {
		for i, col := range rec {
			field := fieldBuilders[i]
			if col == nil {
				field.AppendNull()
				continue
			}

			switch fb := field.(type) {
			case *array.Time64Builder:
				sec, nsec, err := extractTimestamp(col)
				if err != nil {
					return nil, err
				}

				fb.Append(arrow.Time64(sec*1e9 + nsec))
			case *array.TimestampBuilder:
				snowflakeType, ok := bldr.Schema().Field(i).Metadata.GetValue(MetadataKeySnowflakeType)
				if !ok {
					return nil, errToAdbcErr(
						adbc.StatusInvalidData,
						fmt.Errorf("key %s not found in metadata for field %s", MetadataKeySnowflakeType, bldr.Schema().Field(i).Name),
					)
				}
				if snowflakeType == "timestamp_tz" {
					// "timestamp_tz" should be value + offset separated by space
					tm := strings.Split(*col, " ")
					if len(tm) != 2 {
						return nil, adbc.Error{
							Msg:        "invalid TIMESTAMP_TZ data. value doesn't consist of two numeric values separated by a space: " + *col,
							SqlState:   [5]byte{'2', '2', '0', '0', '7'},
							VendorCode: 268000,
							Code:       adbc.StatusInvalidData,
						}
					}

					sec, nsec, err := extractTimestamp(&tm[0])
					if err != nil {
						return nil, err
					}
					offset, err := strconv.ParseInt(tm[1], 10, 64)
					if err != nil {
						return nil, adbc.Error{
							Msg:        "invalid TIMESTAMP_TZ data. offset value is not an integer: " + tm[1],
							SqlState:   [5]byte{'2', '2', '0', '0', '7'},
							VendorCode: 268000,
							Code:       adbc.StatusInvalidData,
						}
					}

					loc := gosnowflake.Location(int(offset) - 1440)
					tt := time.Unix(sec, nsec).In(loc)
					ts, err := arrow.TimestampFromTime(tt, arrow.Nanosecond)
					if err != nil {
						return nil, err
					}
					fb.Append(ts)
					break
				}

				// otherwise timestamp_ntz or timestamp_ltz, which have the same physical representation
				sec, nsec, err := extractTimestamp(col)
				if err != nil {
					return nil, err
				}

				fb.Append(arrow.Timestamp(sec*1e9 + nsec))
			case *array.BinaryBuilder:
				b, err := hex.DecodeString(*col)
				if err != nil {
					return nil, adbc.Error{
						Msg:        err.Error(),
						VendorCode: 268002,
						SqlState:   [5]byte{'2', '2', '0', '0', '3'},
						Code:       adbc.StatusInvalidData,
					}
				}
				fb.Append(b)
			default:
				if err := fb.AppendValueFromString(*col); err != nil {
					return nil, err
				}
			}
		}
	}

	return bldr.NewRecord(), nil
}

type reader struct {
	refCount   int64
	schema     *arrow.Schema
	chs        []chan arrow.Record
	curChIndex int
	rec        arrow.Record
	err        error

	cancelFn context.CancelFunc
}

func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake.ArrowStreamLoader, bufferSize, prefetchConcurrency int, useHighPrecision bool) (array.RecordReader, error) {
	batches, err := ld.GetBatches()
	if err != nil {
		return nil, errToAdbcErr(adbc.StatusInternal, err)
	}

	// if the first chunk was JSON, that means this was a metadata query which
	// is only returning JSON data rather than Arrow
	rawData := ld.JSONData()
	if len(rawData) > 0 {
		// construct an Arrow schema based on reading the JSON metadata description of the
		// result type schema
		schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision)
		if err != nil {
			return nil, adbc.Error{
				Msg:  err.Error(),
				Code: adbc.StatusInternal,
			}
		}

		if ld.TotalRows() == 0 {
			return array.NewRecordReader(schema, []arrow.Record{})
		}

		bldr := array.NewRecordBuilder(alloc, schema)
		defer bldr.Release()

		rec, err := jsonDataToArrow(ctx, bldr, rawData)
		if err != nil {
			return nil, err
		}
		defer rec.Release()

		results := []arrow.Record{rec}
		for _, b := range batches {
			rdr, err := b.GetStream(ctx)
			if err != nil {
				return nil, adbc.Error{
					Msg:  err.Error(),
					Code: adbc.StatusInternal,
				}
			}

			// the "JSON" data returned isn't valid JSON. Instead it is a list of
			// comma-delimited JSON lists containing every value as a string, except
			// for a JSON null to represent nulls. Thus we can't just use the existing
			// JSON parsing code in Arrow.
			data, err := io.ReadAll(rdr)
			rdrErr := rdr.Close()
			if err != nil {
				return nil, adbc.Error{
					Msg:  err.Error(),
					Code: adbc.StatusInternal,
				}
			} else if rdrErr != nil {
				return nil, rdrErr
			}

			if cap(rawData) >= int(b.NumRows()) {
				rawData = rawData[:b.NumRows()]
			} else {
				rawData = make([][]*string, b.NumRows())
			}
			bldr.Reserve(int(b.NumRows()))

			// we grab the entire JSON message and create a bytes reader
			offset, buf := int64(0), bytes.NewReader(data)
			for i := 0; i < int(b.NumRows()); i++ {
				// we construct a decoder from the bytes.Reader to read the next JSON list
				// of columns (one row) from the input
				dec := json.NewDecoder(buf)
				if err = dec.Decode(&rawData[i]); err != nil {
					return nil, adbc.Error{
						Msg:  err.Error(),
						Code: adbc.StatusInternal,
					}
				}

				// dec.InputOffset() now represents the index of the ',' so we skip the comma
				offset += dec.InputOffset() + 1
				// then seek the buffer to that spot. we have to seek based on the start
				// because json.Decoder can read from the buffer more than is necessary to
				// process the JSON data.
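				// e.g. for the stream `["1","a"],["2","b"]`, the first Decode
				// consumes `["1","a"]` and leaves InputOffset() at the comma,
				// so the next iteration resumes just past it.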
				if _, err = buf.Seek(offset, io.SeekStart); err != nil {
					return nil, adbc.Error{
						Msg:  err.Error(),
						Code: adbc.StatusInternal,
					}
				}
			}

			// now that we have our [][]*string of JSON data, we can pass it to get converted
			// to an Arrow record batch and appended to our slice of batches
			rec, err := jsonDataToArrow(ctx, bldr, rawData)
			if err != nil {
				return nil, err
			}
			defer rec.Release()

			results = append(results, rec)
		}

		return array.NewRecordReader(schema, results)
	}

	ch := make(chan arrow.Record, bufferSize)
	group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc))
	ctx, cancelFn := context.WithCancel(ctx)
	group.SetLimit(prefetchConcurrency)

	defer func() {
		if err != nil {
			close(ch)
			cancelFn()
		}
	}()

	chs := make([]chan arrow.Record, len(batches))
	rdr := &reader{
		refCount: 1,
		chs:      chs,
		err:      nil,
		cancelFn: cancelFn,
	}

	if len(batches) == 0 {
		schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision)
		if err != nil {
			return nil, err
		}
		rdr.schema, _ = getTransformer(schema, ld, useHighPrecision)
		return rdr, nil
	}

	r, err := batches[0].GetStream(ctx)
	if err != nil {
		return nil, errToAdbcErr(adbc.StatusIO, err)
	}

	rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc))
	if err != nil {
		return nil, adbc.Error{
			Msg:  err.Error(),
			Code: adbc.StatusInvalidState,
		}
	}

	var recTransform recordTransformer
	rdr.schema, recTransform = getTransformer(rr.Schema(), ld, useHighPrecision)

	group.Go(func() (err error) {
		defer rr.Release()
		defer func() {
			err = errors.Join(err, r.Close())
		}()
		if len(batches) > 1 {
			defer close(ch)
		}

		for rr.Next() && ctx.Err() == nil {
			rec := rr.Record()
			rec, err = recTransform(ctx, rec)
			if err != nil {
				return err
			}
			ch <- rec
		}
		return rr.Err()
	})

	chs[0] = ch

	lastChannelIndex := len(chs) - 1
	go func() {
		for i, b := range batches[1:] {
			batch, batchIdx := b, i+1
			chs[batchIdx] = make(chan arrow.Record, bufferSize)
			group.Go(func() (err error) {
				// close channels (except the last) so that Next can move on to the next channel properly
				if batchIdx != lastChannelIndex {
					defer close(chs[batchIdx])
				}

				rdr, err := batch.GetStream(ctx)
				if err != nil {
					return err
				}
				defer func() {
					err = errors.Join(err, rdr.Close())
				}()

				rr, err := ipc.NewReader(rdr, ipc.WithAllocator(alloc))
				if err != nil {
					return err
				}
				defer rr.Release()

				for rr.Next() && ctx.Err() == nil {
					rec := rr.Record()
					rec, err = recTransform(ctx, rec)
					if err != nil {
						return err
					}
					chs[batchIdx] <- rec
				}

				return rr.Err()
			})
		}

		// place this here so that we always clean up, but they can't be in a
		// separate goroutine. Otherwise we'll have a race condition between
		// the call to wait and the calls to group.Go to kick off the jobs
		// to perform the pre-fetching (GH-1283).
		rdr.err = group.Wait()
		// don't close the last channel until after the group is finished,
		// so that Next() can only return after reader.err may have been set
		close(chs[lastChannelIndex])
	}()

	return rdr, nil
}

func (r *reader) Schema() *arrow.Schema { return r.schema }
func (r *reader) Record() arrow.Record  { return r.rec }
func (r *reader) Err() error            { return r.err }

func (r *reader) Next() bool {
	if r.rec != nil {
		r.rec.Release()
		r.rec = nil
	}

	if r.curChIndex >= len(r.chs) {
		return false
	}

	var ok bool
	for r.curChIndex < len(r.chs) {
		if r.rec, ok = <-r.chs[r.curChIndex]; ok {
			break
		}
		r.curChIndex++
	}
	return r.rec != nil
}

func (r *reader) Retain() {
	atomic.AddInt64(&r.refCount, 1)
}

func (r *reader) Release() {
	if atomic.AddInt64(&r.refCount, -1) == 0 {
		if r.rec != nil {
			r.rec.Release()
		}
		r.cancelFn()
		for _, ch := range r.chs {
			for rec := range ch {
				rec.Release()
			}
		}
	}
}