func newRecordReader()

in go/adbc/driver/flightsql/record_reader.go [50:165]


// newRecordReader returns an array.RecordReader over the result set described
// by info.
//
// If info has no endpoints, the result is an empty reader built from the
// schema embedded in info; a missing or undecodable schema is reported as an
// adbc.StatusInternal error.
//
// Otherwise, one goroutine per endpoint issues a DoGet (via doGet) and forwards
// that endpoint's records into its own channel of capacity bufferSize. The
// channel for info.Endpoint[i] is always stored at chs[i], and every channel
// except the last is closed by its producer when done. This lets the reader
// step through the channels in endpoint order. If info carries no schema, the
// first endpoint's DoGet is issued synchronously so that its stream schema can
// serve as the reader's schema.
//
// Every endpoint stream other than the one that supplied the schema is checked
// against it (ignoring metadata). A mismatch, or any DoGet/stream error, is
// returned from the errgroup, which cancels the shared context. That error is
// stored in the reader once all producer goroutines have finished.
func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.Client, info *flight.FlightInfo, clCache gcache.Cache, bufferSize int, opts ...grpc.CallOption) (rdr array.RecordReader, err error) {
	endpoints := info.Endpoint
	var schema *arrow.Schema
	if len(endpoints) == 0 {
		if info.Schema == nil {
			return nil, adbc.Error{
				Msg:  "Server returned FlightInfo with no schema and no endpoints, cannot read stream",
				Code: adbc.StatusInternal,
			}
		}
		schema, err = flight.DeserializeSchema(info.Schema, alloc)
		if err != nil {
			return nil, adbc.Error{
				Msg:  "Server returned FlightInfo with invalid schema and no endpoints, cannot read stream",
				Code: adbc.StatusInternal,
			}
		}
		return array.NewRecordReader(schema, []arrow.Record{})
	}

	ch := make(chan arrow.Record, bufferSize)
	group, ctx := errgroup.WithContext(ctx)
	ctx, cancelFn := context.WithCancel(ctx)
	// endpoints may be re-sliced below, so remember the original count: it
	// determines how many channels the reader iterates over.
	numEndpoints := len(endpoints)
	// Position in info.Endpoint (and therefore in chs) of endpoints[0] once the
	// branch below has run. It becomes 1 when the first endpoint is consumed
	// up front to discover the schema.
	endpointOffset := 0

	// On an early error return nothing will consume ch or use the derived
	// context, so close/cancel them here.
	defer func() {
		if err != nil {
			close(ch)
			cancelFn()
		}
	}()

	if info.Schema != nil {
		schema, err = flight.DeserializeSchema(info.Schema, alloc)
		if err != nil {
			return nil, adbc.Error{
				Msg:  err.Error(),
				Code: adbc.StatusInvalidState}
		}
	} else {
		// No schema in the FlightInfo: open the first endpoint now and use its
		// stream schema as the reference schema.
		firstRdr, err := doGet(ctx, cl, endpoints[0], clCache, opts...)
		if err != nil {
			return nil, adbcFromFlightStatus(err, "DoGet: endpoint 0: remote: %s", endpoints[0].Location)
		}
		schema = firstRdr.Schema()
		group.Go(func() error {
			defer firstRdr.Release()
			// ch is chs[0]. It is closed here unless it is also the last
			// channel, which is closed only after group.Wait (see below).
			if numEndpoints > 1 {
				defer close(ch)
			}

			for firstRdr.Next() && ctx.Err() == nil {
				rec := firstRdr.Record()
				rec.Retain()
				ch <- rec
			}
			return firstRdr.Err()
		})

		endpoints = endpoints[1:]
		endpointOffset = 1
	}

	chs := make([]chan arrow.Record, numEndpoints)
	chs[0] = ch
	reader := &reader{
		refCount: 1,
		chs:      chs,
		err:      nil,
		cancelFn: cancelFn,
		schema:   schema,
	}

	lastChannelIndex := len(chs) - 1

	referenceSchema := utils.RemoveSchemaMetadata(schema)
	for i, ep := range endpoints {
		endpoint := ep
		// Index of this endpoint within info.Endpoint, and hence within chs.
		// Without the offset, the first endpoint's channel would be
		// overwritten and the last slot of chs would stay nil.
		endpointIndex := i + endpointOffset
		out := make(chan arrow.Record, bufferSize)
		chs[endpointIndex] = out
		group.Go(func() error {
			// Close channels (except the last) so that Next can move on to the next channel properly
			if endpointIndex != lastChannelIndex {
				defer close(out)
			}

			rdr, err := doGet(ctx, cl, endpoint, clCache, opts...)
			if err != nil {
				return adbcFromFlightStatus(err, "DoGet: endpoint %d: %s", endpointIndex, endpoint.Location)
			}
			defer rdr.Release()

			streamSchema := utils.RemoveSchemaMetadata(rdr.Schema())
			if !streamSchema.Equal(referenceSchema) {
				return fmt.Errorf("endpoint %d returned inconsistent schema: expected %s but got %s", endpointIndex, referenceSchema.String(), streamSchema.String())
			}

			for rdr.Next() && ctx.Err() == nil {
				rec := rdr.Record()
				rec.Retain()
				out <- rec
			}

			return rdr.Err()
		})
	}

	go func() {
		reader.err = group.Wait()
		// Don't close the last channel until after the group is finished, so that
		// Next() can only return after reader.err may have been set
		close(chs[lastChannelIndex])
	}()

	return reader, nil
}