in go/adbc/driver/snowflake/record_reader.go [472:693]
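// newRecordReader returns an array.RecordReader over the results loaded by ld.
// Metadata queries arrive as JSON and are decoded row by row into Arrow records;
// regular queries arrive as Arrow IPC chunks that are prefetched concurrently
// (bounded by prefetchConcurrency) into per-chunk channels drained by the reader.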
func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake.ArrowStreamLoader, bufferSize, prefetchConcurrency int, useHighPrecision bool) (array.RecordReader, error) {
batches, err := ld.GetBatches()
if err != nil {
return nil, errToAdbcErr(adbc.StatusInternal, err)
}
// if the first chunk was JSON, this was a metadata query that returns
// only JSON data rather than Arrow
rawData := ld.JSONData()
if len(rawData) > 0 {
// construct an Arrow schema from the JSON metadata describing the
// result's column types
schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision)
if err != nil {
return nil, adbc.Error{
Msg: err.Error(),
Code: adbc.StatusInternal,
}
}
if ld.TotalRows() == 0 {
return array.NewRecordReader(schema, []arrow.Record{})
}
bldr := array.NewRecordBuilder(alloc, schema)
defer bldr.Release()
rec, err := jsonDataToArrow(ctx, bldr, rawData)
if err != nil {
return nil, err
}
defer rec.Release()
results := []arrow.Record{rec}
for _, b := range batches {
rdr, err := b.GetStream(ctx)
if err != nil {
return nil, adbc.Error{
Msg: err.Error(),
Code: adbc.StatusInternal,
}
}
// the "JSON" data returned isn't valid JSON. Instead it is a list of
// comma-delimited JSON lists containing every value as a string, except
// for a JSON null to represent nulls. Thus we can't just use the existing
// JSON parsing code in Arrow.
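// For example, two rows of three columns arrive roughly as (illustrative only):
//   ["1","foo",null],["2","bar",null]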
data, err := io.ReadAll(rdr)
rdrErr := rdr.Close()
if err != nil {
return nil, adbc.Error{
Msg: err.Error(),
Code: adbc.StatusInternal,
}
} else if rdrErr != nil {
return nil, errToAdbcErr(adbc.StatusIO, rdrErr)
}
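// reuse rawData's backing storage across chunks when its capacity can hold
// this chunk's rows; otherwise allocate a fresh slice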
if cap(rawData) >= int(b.NumRows()) {
rawData = rawData[:b.NumRows()]
} else {
rawData = make([][]*string, b.NumRows())
}
bldr.Reserve(int(b.NumRows()))
// wrap the chunk data we just read in a bytes.Reader so we can decode it one row at a time
offset, buf := int64(0), bytes.NewReader(data)
for i := 0; i < int(b.NumRows()); i++ {
// we construct a decoder from the bytes.Reader to read the next JSON list
// of columns (one row) from the input
dec := json.NewDecoder(buf)
if err = dec.Decode(&rawData[i]); err != nil {
return nil, adbc.Error{
Msg: err.Error(),
Code: adbc.StatusInternal,
}
}
// dec.InputOffset() is relative to where this decoder began reading (the
// current offset) and now points at the ',' after the row just decoded,
// so advance past the comma
offset += dec.InputOffset() + 1
// then seek the buffer to that spot. we have to seek based on the start
// because json.Decoder can read from the buffer more than is necessary to
// process the JSON data.
if _, err = buf.Seek(offset, io.SeekStart); err != nil {
return nil, adbc.Error{
Msg: err.Error(),
Code: adbc.StatusInternal,
}
}
}
// now that we have the [][]*string of JSON data, convert it to an Arrow
// record batch and append it to our slice of results
rec, err := jsonDataToArrow(ctx, bldr, rawData)
if err != nil {
return nil, err
}
defer rec.Release()
results = append(results, rec)
}
return array.NewRecordReader(schema, results)
}
ch := make(chan arrow.Record, bufferSize)
group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc))
ctx, cancelFn := context.WithCancel(ctx)
group.SetLimit(prefetchConcurrency)
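// if we fail before handing the reader back to the caller, close the channel
// and cancel the context so the prefetch goroutines don't leak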
defer func() {
if err != nil {
close(ch)
cancelFn()
}
}()
chs := make([]chan arrow.Record, len(batches))
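// one buffered channel of prefetched records per chunk; Next drains them in order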
rdr := &reader{
refCount: 1,
chs: chs,
err: nil,
cancelFn: cancelFn,
}
if len(batches) == 0 {
schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision)
if err != nil {
// err is shadowed here, so the deferred cleanup above won't run;
// cancel the context explicitly to avoid leaking it
cancelFn()
return nil, err
}
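// no chunks to stream, so we only need the schema the transformer would expose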
rdr.schema, _ = getTransformer(schema, ld, useHighPrecision)
return rdr, nil
}
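// fetch the first chunk synchronously so we can determine the schema (and the
// record transformer) before returning the reader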
r, err := batches[0].GetStream(ctx)
if err != nil {
return nil, errToAdbcErr(adbc.StatusIO, err)
}
rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc))
if err != nil {
return nil, adbc.Error{
Msg: err.Error(),
Code: adbc.StatusInvalidState,
}
}
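// getTransformer yields the schema exposed to the caller and a function that
// converts each fetched record to that schema (honoring useHighPrecision)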
var recTransform recordTransformer
rdr.schema, recTransform = getTransformer(rr.Schema(), ld, useHighPrecision)
group.Go(func() (err error) {
defer rr.Release()
defer func() {
err = errors.Join(err, r.Close())
}()
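// if this is the only batch then ch is also the last channel; it is closed
// below only after group.Wait returns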
if len(batches) > 1 {
defer close(ch)
}
for rr.Next() && ctx.Err() == nil {
rec := rr.Record()
rec, err = recTransform(ctx, rec)
if err != nil {
return err
}
ch <- rec
}
return rr.Err()
})
chs[0] = ch
lastChannelIndex := len(chs) - 1
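// launch the remaining prefetch jobs from a separate goroutine so this call
// returns immediately; the errgroup limit bounds how many run concurrently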
go func() {
for i, b := range batches[1:] {
batch, batchIdx := b, i+1
chs[batchIdx] = make(chan arrow.Record, bufferSize)
group.Go(func() (err error) {
// close channels (except the last) so that Next can move on to the next channel properly
if batchIdx != lastChannelIndex {
defer close(chs[batchIdx])
}
rdr, err := batch.GetStream(ctx)
if err != nil {
return err
}
defer func() {
err = errors.Join(err, rdr.Close())
}()
rr, err := ipc.NewReader(rdr, ipc.WithAllocator(alloc))
if err != nil {
return err
}
defer rr.Release()
for rr.Next() && ctx.Err() == nil {
rec := rr.Record()
rec, err = recTransform(ctx, rec)
if err != nil {
return err
}
chs[batchIdx] <- rec
}
return rr.Err()
})
}
// wait for the group here, in the same goroutine that issued the group.Go
// calls above, so that we always clean up. Waiting in a separate goroutine
// would race with the calls to group.Go that kick off the prefetching
// jobs (GH-1283).
rdr.err = group.Wait()
// don't close the last channel until the group has finished, so that by the
// time Next() sees it closed, rdr.err has already been set if there was one
close(chs[lastChannelIndex])
}()
return rdr, nil
}