go/adbc/driver/flightsql/record_reader.go (165 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package flightsql
import (
"context"
"fmt"
"sync/atomic"
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/utils"
"github.com/apache/arrow/go/v13/arrow"
"github.com/apache/arrow/go/v13/arrow/array"
"github.com/apache/arrow/go/v13/arrow/flight"
"github.com/apache/arrow/go/v13/arrow/flight/flightsql"
"github.com/apache/arrow/go/v13/arrow/memory"
"github.com/bluele/gcache"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
)
// reader is the array.RecordReader returned by newRecordReader. It holds one
// channel per Flight endpoint and hands out the records of each channel in
// endpoint order.
type reader struct {
// refCount is the reference count, updated with sync/atomic in Retain and
// Release. NOTE: 64-bit atomics need 64-bit alignment on 32-bit platforms;
// being the first field of the struct guarantees that, so keep it first.
refCount int64
// schema is the value reported by Schema().
schema *arrow.Schema
// chs holds one record channel per endpoint; Next drains them in index order.
chs []chan arrow.Record
// curChIndex is the index into chs that Next is currently receiving from.
curChIndex int
// rec is the record most recently received by Next (nil if none); the reader
// releases it on the following Next call or on the final Release.
rec arrow.Record
// err is written by newRecordReader's background goroutine with the result
// of waiting for all endpoint goroutines.
err error
// cancelFn cancels the context used by the endpoint goroutines; it is invoked
// when the reference count drops to zero.
cancelFn context.CancelFunc
}
// newRecordReader kicks off a goroutine for each endpoint and returns a reader
// which gathers all of the records as they come in.
//
// Every endpoint gets its own buffered channel (capacity bufferSize) and the
// returned reader drains those channels in endpoint order. If info carries no
// schema, endpoint 0 is fetched synchronously so that its stream schema can be
// used; all other endpoints are fetched inside their goroutines and must
// return a schema equal (ignoring metadata) to the reference schema.
//
// Parameters:
//   - ctx: parent context; a derived, cancellable context is handed to the
//     endpoint goroutines and cancelled when the reader is fully released.
//   - alloc: allocator used to deserialize the schema found in info.
//   - cl, clCache, opts: forwarded unchanged to doGet for every endpoint.
//   - info: the FlightInfo describing the endpoints (and optionally the schema).
//   - bufferSize: capacity of each per-endpoint record channel.
//
// It returns an error if info has neither schema nor endpoints, if the schema
// in info cannot be deserialized, or if the synchronous DoGet for endpoint 0
// fails. Errors hit by the background goroutines are reported through the
// reader's Err method instead.
func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.Client, info *flight.FlightInfo, clCache gcache.Cache, bufferSize int, opts ...grpc.CallOption) (rdr array.RecordReader, err error) {
	endpoints := info.Endpoint
	var schema *arrow.Schema
	if len(endpoints) == 0 {
		if info.Schema == nil {
			return nil, adbc.Error{
				Msg:  "Server returned FlightInfo with no schema and no endpoints, cannot read stream",
				Code: adbc.StatusInternal,
			}
		}
		schema, err = flight.DeserializeSchema(info.Schema, alloc)
		if err != nil {
			return nil, adbc.Error{
				Msg:  "Server returned FlightInfo with invalid schema and no endpoints, cannot read stream",
				Code: adbc.StatusInternal,
			}
		}
		return array.NewRecordReader(schema, []arrow.Record{})
	}

	// ch is the channel for endpoint 0, whichever code path ends up feeding it.
	ch := make(chan arrow.Record, bufferSize)
	group, ctx := errgroup.WithContext(ctx)
	ctx, cancelFn := context.WithCancel(ctx)
	// We may mutate endpoints below
	numEndpoints := len(endpoints)

	// On an early error return nothing will ever consume ch or cancel the
	// context, so tidy both up here.
	defer func() {
		if err != nil {
			close(ch)
			cancelFn()
		}
	}()

	if info.Schema != nil {
		schema, err = flight.DeserializeSchema(info.Schema, alloc)
		if err != nil {
			return nil, adbc.Error{
				Msg:  err.Error(),
				Code: adbc.StatusInvalidState}
		}
	} else {
		// No schema in the FlightInfo: open endpoint 0 right away and borrow
		// the schema of its stream.
		firstStream, doGetErr := doGet(ctx, cl, endpoints[0], clCache, opts...)
		if doGetErr != nil {
			return nil, adbcFromFlightStatus(doGetErr, "DoGet: endpoint 0: remote: %s", endpoints[0].Location)
		}
		schema = firstStream.Schema()
		group.Go(func() error {
			defer firstStream.Release()
			// The last channel is closed by the waiter goroutine below, so
			// only close here when endpoint 0 is not also the last endpoint.
			if numEndpoints > 1 {
				defer close(ch)
			}

			for firstStream.Next() && ctx.Err() == nil {
				rec := firstStream.Record()
				rec.Retain()
				ch <- rec
			}
			return firstStream.Err()
		})

		endpoints = endpoints[1:]
	}

	chs := make([]chan arrow.Record, numEndpoints)
	chs[0] = ch
	result := &reader{
		refCount: 1,
		chs:      chs,
		err:      nil,
		cancelFn: cancelFn,
		schema:   schema,
	}

	lastChannelIndex := len(chs) - 1

	// Endpoints already started above (0 or 1 of them) own the leading slots
	// of chs. The remaining endpoints must be numbered after them; numbering
	// from zero again would overwrite endpoint 0's channel and leave the last
	// slot nil, which makes the close below panic.
	alreadyStarted := numEndpoints - len(endpoints)

	referenceSchema := utils.RemoveSchemaMetadata(schema)
	for i, ep := range endpoints {
		endpoint := ep
		endpointIndex := alreadyStarted + i
		// Slot 0 is pre-populated with ch; every other slot is allocated here.
		if chs[endpointIndex] == nil {
			chs[endpointIndex] = make(chan arrow.Record, bufferSize)
		}
		out := chs[endpointIndex]
		group.Go(func() error {
			// Close channels (except the last) so that Next can move on to the next channel properly
			if endpointIndex != lastChannelIndex {
				defer close(out)
			}

			rdr, err := doGet(ctx, cl, endpoint, clCache, opts...)
			if err != nil {
				return adbcFromFlightStatus(err, "DoGet: endpoint %d: %s", endpointIndex, endpoint.Location)
			}
			defer rdr.Release()

			streamSchema := utils.RemoveSchemaMetadata(rdr.Schema())
			if !streamSchema.Equal(referenceSchema) {
				return fmt.Errorf("endpoint %d returned inconsistent schema: expected %s but got %s", endpointIndex, referenceSchema.String(), streamSchema.String())
			}

			for rdr.Next() && ctx.Err() == nil {
				rec := rdr.Record()
				// The stream reader drops its reference on the next call to
				// Next, so take our own before handing the record over.
				rec.Retain()
				out <- rec
			}
			return rdr.Err()
		})
	}

	go func() {
		result.err = group.Wait()
		// Don't close the last channel until after the group is finished, so that
		// Next() can only return after reader.err may have been set
		close(chs[lastChannelIndex])
	}()

	return result, nil
}
// Retain atomically increments the reader's reference count by one.
func (r *reader) Retain() {
atomic.AddInt64(&r.refCount, 1)
}
// Release atomically decrements the reference count. When the count reaches
// zero it releases the current record (if any), cancels the context used by
// the endpoint goroutines, and then receives from every channel until it is
// closed, releasing each record that was still queued.
func (r *reader) Release() {
	if atomic.AddInt64(&r.refCount, -1) != 0 {
		return
	}

	if current := r.rec; current != nil {
		current.Release()
	}

	// Cancel first so producers stop fetching, then drain whatever they had
	// already queued.
	r.cancelFn()
	for i := range r.chs {
		for pending := range r.chs[i] {
			pending.Release()
		}
	}
}
// Err returns the error stored on the reader. newRecordReader's background
// goroutine stores the result of waiting for all endpoint goroutines before it
// closes the last channel, so the value is settled once Next has returned
// false after draining that channel.
// NOTE(review): a call made earlier than that is not synchronized with the
// write — confirm callers only check Err after Next reports false.
func (r *reader) Err() error {
return r.err
}
// Next releases the record returned by the previous call (if any) and blocks
// until another record is available. Channels are consumed in index order;
// when the current channel is closed the reader moves on to the following one.
// It returns false once every channel has been exhausted.
func (r *reader) Next() bool {
	if previous := r.rec; previous != nil {
		previous.Release()
		r.rec = nil
	}

	// The post statement only runs when the current channel turned out to be
	// closed; a successful receive breaks out and stays on the same channel.
	for ; r.curChIndex < len(r.chs); r.curChIndex++ {
		received, open := <-r.chs[r.curChIndex]
		r.rec = received
		if open {
			break
		}
	}

	return r.rec != nil
}
// Schema returns the schema stored on the reader. newRecordReader fills it
// from the FlightInfo's serialized schema or, when the FlightInfo has none,
// from the stream of the first endpoint.
func (r *reader) Schema() *arrow.Schema {
return r.schema
}
// Record returns the record most recently received by Next, or nil when there
// is none. The reader calls Release on that record at the start of the
// following Next call and when the reader's reference count drops to zero.
func (r *reader) Record() arrow.Record {
return r.rec
}