arrow/internal/flight_integration/scenario.go (2,586 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package flight_integration
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math"
"net"
"os"
"reflect"
"sort"
"strconv"
"strings"
"time"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/flight"
"github.com/apache/arrow-go/v18/arrow/flight/flightsql"
"github.com/apache/arrow-go/v18/arrow/flight/flightsql/schema_ref"
"github.com/apache/arrow-go/v18/arrow/flight/session"
"github.com/apache/arrow-go/v18/arrow/internal/arrjson"
"github.com/apache/arrow-go/v18/arrow/ipc"
"github.com/apache/arrow-go/v18/arrow/memory"
"golang.org/x/xerrors"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
)
// Scenario is one Flight integration-test case. It supplies both sides of
// the exchange: a server that hosts the case and a client that drives such a
// server, reporting any deviation from the expected behavior as an error.
type Scenario interface {
	// MakeServer creates the scenario's server, registers its Flight service
	// and initializes it for the requested port.
	MakeServer(port int) flight.Server
	// RunClient connects to the server at addr using opts and exercises the
	// scenario, returning a non-nil error on any failure.
	RunClient(addr string, opts ...grpc.DialOption) error
}
// GetScenario returns a fresh Scenario registered under name.
//
// The empty name selects the default JSON round-trip scenario. For that
// scenario, the first element of args (if any) is the path of the JSON file
// to upload. Any other unknown name causes a panic.
func GetScenario(name string, args ...string) Scenario {
	if name == "" {
		tester := &defaultIntegrationTester{}
		if len(args) > 0 {
			tester.path = args[0]
		}
		return tester
	}

	// Each entry builds a new tester so callers never share state.
	registry := map[string]func() Scenario{
		"auth:basic_proto":                      func() Scenario { return &authBasicProtoTester{} },
		"middleware":                            func() Scenario { return &middlewareScenarioTester{} },
		"ordered":                               func() Scenario { return &orderedScenarioTester{} },
		"expiration_time:do_get":                func() Scenario { return &expirationTimeDoGetScenarioTester{} },
		"expiration_time:list_actions":          func() Scenario { return &expirationTimeListActionsScenarioTester{} },
		"expiration_time:cancel_flight_info":    func() Scenario { return &expirationTimeCancelFlightInfoScenarioTester{} },
		"expiration_time:renew_flight_endpoint": func() Scenario { return &expirationTimeRenewFlightEndpointScenarioTester{} },
		"location:reuse_connection":             func() Scenario { return &locationReuseConnectionScenarioTester{} },
		"poll_flight_info":                      func() Scenario { return &pollFlightInfoScenarioTester{} },
		"app_metadata_flight_info_endpoint":     func() Scenario { return &appMetadataFlightInfoEndpointScenarioTester{} },
		"flight_sql":                            func() Scenario { return &flightSqlScenarioTester{} },
		"flight_sql:extension":                  func() Scenario { return &flightSqlExtensionScenarioTester{} },
		"session_options":                       func() Scenario { return &sessionOptionsScenarioTester{} },
		"flight_sql:ingestion":                  func() Scenario { return &flightSqlIngestionScenarioTester{} },
	}
	if newScenario, ok := registry[name]; ok {
		return newScenario()
	}
	panic(fmt.Errorf("scenario not found: %s", name))
}
// initServer initializes srv to listen on 0.0.0.0 at the requested port and
// returns the port it actually bound to, parsed from srv.Addr().
//
// This is test-harness setup, so failures are programmer/environment errors
// and cause a panic. A failed Init previously went unnoticed until Addr()
// blew up with a far less helpful message.
func initServer(port int, srv flight.Server) int {
	if err := srv.Init(fmt.Sprintf("0.0.0.0:%d", port)); err != nil {
		panic(fmt.Errorf("flight integration: could not initialize server on port %d: %w", port, err))
	}
	_, p, err := net.SplitHostPort(srv.Addr().String())
	if err != nil {
		panic(fmt.Errorf("flight integration: could not parse server address %q: %w", srv.Addr().String(), err))
	}
	bound, err := strconv.Atoi(p)
	if err != nil {
		panic(fmt.Errorf("flight integration: invalid server port %q: %w", p, err))
	}
	return bound
}
// integrationDataSet is one uploaded flight: a schema plus the record
// batches that share it, kept in upload order.
type integrationDataSet struct {
	// schema describes every record in chunks.
	schema *arrow.Schema
	// chunks holds the record batches in the order they were written.
	chunks []arrow.Record
}
// consumeFlightLocation dials loc.Uri, performs DoGet with tkt, and verifies
// that the returned stream yields exactly the records in orig, in order. The
// app metadata of batch i must be the decimal string of i.
//
// Read errors from the stream are reported as such, instead of being
// mistaken for a short or exhausted stream.
func consumeFlightLocation(ctx context.Context, loc *flight.Location, tkt *flight.Ticket, orig []arrow.Record, opts ...grpc.DialOption) error {
	client, err := flight.NewClientWithMiddleware(loc.GetUri(), nil, nil, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	stream, err := client.DoGet(ctx, tkt)
	if err != nil {
		return err
	}

	rdr, err := flight.NewRecordReader(stream)
	if err != nil {
		return err
	}
	defer rdr.Release()

	for i, chunk := range orig {
		if !rdr.Next() {
			// Next also returns false on a read error; surface that first.
			if err := rdr.Err(); err != nil {
				return fmt.Errorf("reading batch %d: %w", i, err)
			}
			return fmt.Errorf("got fewer batches than expected, received so far: %d, expected: %d", i, len(orig))
		}
		if !array.RecordEqual(chunk, rdr.Record()) {
			return fmt.Errorf("batch %d doesn't match", i)
		}
		if string(rdr.LatestAppMetadata()) != strconv.Itoa(i) {
			return fmt.Errorf("expected metadata value: %s, but got: %s", strconv.Itoa(i), string(rdr.LatestAppMetadata()))
		}
	}

	if rdr.Next() {
		return fmt.Errorf("got more batches than the expected: %d", len(orig))
	}
	// A failure while looking for the end of stream must not be silently
	// treated as success.
	return rdr.Err()
}
// defaultIntegrationTester is the default JSON round-trip scenario. Its
// client uploads batches read from a JSON file and reads them back. Its
// server stores uploaded datasets in memory, keyed by descriptor path.
type defaultIntegrationTester struct {
	flight.BaseFlightServer

	// port is the port the server bound to, advertised in FlightInfo.
	port int
	// path is the JSON file the client uploads.
	path string
	// uploadedChunks maps descriptor path to the dataset uploaded under it.
	uploadedChunks map[string]integrationDataSet
}
// RunClient drives the default JSON round-trip scenario against the server
// at addr:
//
//  1. Read every record batch from the JSON file at s.path.
//  2. Upload the batches with DoPut under a PATH descriptor naming that file,
//     checking that the server echoes each batch's app metadata (its decimal
//     index).
//  3. Fetch the FlightInfo for the same descriptor. For every location of
//     every endpoint, DoGet the ticket and compare the batches with the
//     originals.
//
// Any mismatch or RPC failure, including during the DoGet verification, is
// returned as an error.
func (s *defaultIntegrationTester) RunClient(addr string, opts ...grpc.DialOption) error {
	client, err := flight.NewClientWithMiddleware(addr, nil, nil, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	ctx := context.Background()
	descr := &flight.FlightDescriptor{
		Type: flight.DescriptorPATH,
		Path: []string{s.path},
	}

	fmt.Printf("Opening JSON file '%s'\n", s.path)
	r, err := os.Open(s.path)
	if err != nil {
		return fmt.Errorf("could not open JSON file: %q: %w", s.path, err)
	}
	defer r.Close()

	rdr, err := arrjson.NewReader(r)
	if err != nil {
		return fmt.Errorf("could not create JSON file reader from file: %q: %w", s.path, err)
	}

	dataSet := integrationDataSet{
		chunks: make([]arrow.Record, 0),
		schema: rdr.Schema(),
	}
	for {
		rec, err := rdr.Read()
		if err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			return err
		}
		// The records are still needed for the DoGet comparison at the end,
		// so they are released only when RunClient returns.
		defer rec.Release()
		dataSet.chunks = append(dataSet.chunks, rec)
	}

	stream, err := client.DoPut(ctx)
	if err != nil {
		return err
	}
	wr := flight.NewRecordWriter(stream, ipc.WithSchema(dataSet.schema))
	wr.SetFlightDescriptor(descr)
	for i, rec := range dataSet.chunks {
		metadata := []byte(strconv.Itoa(i))
		if err := wr.WriteWithAppMetadata(rec, metadata); err != nil {
			return err
		}
		pr, err := stream.Recv()
		if err != nil {
			return err
		}
		acked := pr.GetAppMetadata()
		switch {
		case len(acked) == 0:
			return fmt.Errorf("expected metadata value: %s, but got nothing", string(metadata))
		case !bytes.Equal(metadata, acked):
			return fmt.Errorf("expected metadata value: %s, but got: %s", string(metadata), string(acked))
		}
	}
	if err := wr.Close(); err != nil {
		return err
	}
	if err := stream.CloseSend(); err != nil {
		return err
	}
	// Drain the put stream until the server finishes it.
	for {
		if _, err := stream.Recv(); err != nil {
			if err != io.EOF {
				return err
			}
			break
		}
	}

	info, err := client.GetFlightInfo(ctx, descr)
	if err != nil {
		return err
	}
	if len(info.Endpoint) == 0 {
		fmt.Fprintln(os.Stderr, "no endpoints returned from flight server.")
		return fmt.Errorf("no endpoints returned from flight server")
	}

	for _, ep := range info.Endpoint {
		if len(ep.Location) == 0 {
			return fmt.Errorf("no locations returned from flight server")
		}
		for _, loc := range ep.Location {
			target := &flight.Location{Uri: flightLocationDialTarget(loc.GetUri(), addr)}
			if err := consumeFlightLocation(ctx, target, ep.Ticket, dataSet.chunks, opts...); err != nil {
				return fmt.Errorf("consuming location %q: %w", loc.GetUri(), err)
			}
		}
	}
	return nil
}

// flightLocationDialTarget converts a Flight location URI into a target that
// gRPC can dial.
//
// gRPC's target parser does not recognize the Flight-specific "grpc",
// "grpc+tcp" and "grpc+tls" schemes, so those are reduced to their host:port
// authority. The reuse-connection sentinel maps to fallback, the address the
// client is already connected to. Anything else is returned unchanged.
func flightLocationDialTarget(uri, fallback string) string {
	if uri == flight.LocationReuseConnection {
		return fallback
	}
	scheme, rest, ok := strings.Cut(uri, "://")
	if !ok {
		return uri
	}
	switch scheme {
	case "grpc", "grpc+tcp", "grpc+tls":
		host, _, _ := strings.Cut(rest, "/")
		return host
	}
	return uri
}
// MakeServer clears the in-memory dataset store and builds a middleware-free
// server with s registered as its Flight service. It records the port the
// server was bound to so GetFlightInfo can advertise it.
func (s *defaultIntegrationTester) MakeServer(port int) flight.Server {
	s.uploadedChunks = make(map[string]integrationDataSet)

	server := flight.NewServerWithMiddleware(nil)
	server.RegisterFlightService(s)
	s.port = initServer(port, server)
	return server
}
// GetFlightInfo describes a previously uploaded dataset. Only PATH
// descriptors are supported, and the first path element is the dataset key.
//
// The returned FlightInfo has a single endpoint. Its ticket is the key and
// its location points back at this server on 127.0.0.1. TotalRecords is the
// row count across all chunks; TotalBytes is -1.
func (s *defaultIntegrationTester) GetFlightInfo(ctx context.Context, in *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	if in.Type != flight.DescriptorPATH {
		return nil, status.Error(codes.Unimplemented, in.Type.String())
	}
	if len(in.Path) == 0 {
		return nil, status.Error(codes.InvalidArgument, "invalid path")
	}

	key := in.Path[0]
	data, found := s.uploadedChunks[key]
	if !found {
		return nil, status.Errorf(codes.NotFound, "could not find flight: %s", key)
	}

	var totalRows int64
	for _, chunk := range data.chunks {
		totalRows += chunk.NumRows()
	}

	endpoint := &flight.FlightEndpoint{
		Ticket:   &flight.Ticket{Ticket: []byte(key)},
		Location: []*flight.Location{{Uri: fmt.Sprintf("grpc+tcp://127.0.0.1:%d", s.port)}},
	}
	return &flight.FlightInfo{
		Schema:           flight.SerializeSchema(data.schema, memory.DefaultAllocator),
		FlightDescriptor: in,
		Endpoint:         []*flight.FlightEndpoint{endpoint},
		TotalRecords:     totalRows,
		TotalBytes:       -1,
	}, nil
}
// DoGet streams the dataset whose key is the ticket bytes. Each batch
// carries its decimal index as app metadata. Write and close errors are
// returned to the caller instead of being dropped, so a broken stream fails
// the RPC.
func (s *defaultIntegrationTester) DoGet(tkt *flight.Ticket, stream flight.FlightService_DoGetServer) error {
	key := string(tkt.GetTicket())
	data, ok := s.uploadedChunks[key]
	if !ok {
		return status.Errorf(codes.NotFound, "could not find flight: %s", key)
	}

	wr := flight.NewRecordWriter(stream, ipc.WithSchema(data.schema))
	for i, rec := range data.chunks {
		if err := wr.WriteWithAppMetadata(rec, []byte(strconv.Itoa(i))); err != nil {
			wr.Close() // best effort; the write error is the one to report
			return err
		}
	}
	return wr.Close()
}
// DoPut receives a dataset and stores it under the first path element of the
// flight descriptor sent with the first message. Every batch that carries app
// metadata has that metadata echoed back as a PutResult.
//
// Errors (missing/non-PATH descriptor, failure to acknowledge, stream read
// errors) abort the upload without storing anything. Records retained so far
// are released on those paths.
func (s *defaultIntegrationTester) DoPut(stream flight.FlightService_DoPutServer) error {
	rdr, err := flight.NewRecordReader(stream)
	if err != nil {
		return status.Error(codes.Internal, err.Error())
	}
	defer rdr.Release()

	// Creating the reader consumed the first message, which carries the
	// schema and the flight descriptor. Getters are nil-safe, so a missing
	// descriptor is rejected instead of panicking.
	desc := rdr.LatestFlightDescriptor()
	if desc.GetType() != flight.DescriptorPATH || len(desc.GetPath()) < 1 {
		return status.Error(codes.InvalidArgument, "must specify a path")
	}
	key := desc.GetPath()[0]

	dataset := integrationDataSet{
		schema: rdr.Schema(),
		chunks: make([]arrow.Record, 0),
	}
	abort := func(err error) error {
		for _, rec := range dataset.chunks {
			rec.Release()
		}
		return err
	}

	for rdr.Next() {
		rec := rdr.Record()
		rec.Retain() // keep the batch alive past the next call to Next
		dataset.chunks = append(dataset.chunks, rec)
		if md := rdr.LatestAppMetadata(); len(md) > 0 {
			if err := stream.Send(&flight.PutResult{AppMetadata: md}); err != nil {
				return abort(err)
			}
		}
	}
	if err := rdr.Err(); err != nil {
		return abort(status.Error(codes.Internal, err.Error()))
	}

	s.uploadedChunks[key] = dataset
	return nil
}
// CheckActionResults runs action on client and verifies that the result
// stream yields exactly the bodies listed in results, in order, and then
// ends.
//
// An extra result is reported as "too many entries". A genuine stream error
// after the expected results is returned as-is instead of being disguised as
// an extra entry.
func CheckActionResults(ctx context.Context, client flight.Client, action *flight.Action, results []string) error {
	stream, err := client.DoAction(ctx, action)
	if err != nil {
		return err
	}
	defer stream.CloseSend()

	for _, expected := range results {
		res, err := stream.Recv()
		if err != nil {
			return err
		}
		if actual := string(res.Body); actual != expected {
			return fmt.Errorf("got wrong result: expected: %s, got: %s", expected, actual)
		}
	}

	_, err = stream.Recv()
	switch {
	case err == nil:
		return xerrors.New("action result stream had too many entries")
	case !errors.Is(err, io.EOF):
		return err
	}
	return nil
}
// Credentials shared by the basic-auth scenario's client and server.
const (
	authUsername = "arrow" // also returned to the client as its token
	authPassword = "flight"
)

// authBasicValidator is the server-side auth handler of the basic-auth
// scenario. It accepts exactly one username/password pair and, on success,
// sends the username back as the session token.
type authBasicValidator struct {
	// auth holds the only credentials the validator accepts.
	auth flight.BasicAuth
}
// Authenticate performs the server half of the handshake. It reads one
// message from conn and decodes it as a flight.BasicAuth. If the credentials
// match a.auth it replies with the username; otherwise it fails with
// codes.Unauthenticated.
func (a *authBasicValidator) Authenticate(conn flight.AuthConn) error {
	payload, err := conn.Read()
	if err != nil {
		return err
	}

	var creds flight.BasicAuth
	if err := proto.Unmarshal(payload, &creds); err != nil {
		return err
	}

	matches := creds.Username == a.auth.Username && creds.Password == a.auth.Password
	if !matches {
		return status.Error(codes.Unauthenticated, "invalid token")
	}
	return conn.Send([]byte(a.auth.Username))
}
// IsValid accepts a token only if it equals the configured username. In that
// case the token itself is returned as the caller's identity.
func (a *authBasicValidator) IsValid(token string) (interface{}, error) {
	if token == a.auth.Username {
		return token, nil
	}
	return nil, status.Error(codes.Unauthenticated, "invalid token")
}
// clientAuthBasic is the client-side auth handler of the basic-auth scenario.
// While auth is nil it performs no handshake, and the token stays whatever
// was stored before (initially empty).
type clientAuthBasic struct {
	// auth, when non-nil, is sent to the server during Authenticate.
	auth *flight.BasicAuth
	// token is the server's handshake reply.
	token string
}

// Authenticate sends the marshaled credentials and stores the server's reply
// as the token. An io.EOF from the read is deliberately treated as success.
// NOTE(review): presumably the server may end the handshake stream right
// after replying; confirm.
func (c *clientAuthBasic) Authenticate(_ context.Context, conn flight.AuthConn) error {
	if c.auth == nil {
		return nil
	}

	data, err := proto.Marshal(c.auth)
	if err != nil {
		return err
	}
	if err := conn.Send(data); err != nil {
		return err
	}

	reply, err := conn.Read()
	c.token = string(reply)
	if err == io.EOF {
		return nil
	}
	return err
}

// GetToken returns the token captured by the most recent handshake read. It
// never fails.
func (c *clientAuthBasic) GetToken(context.Context) (string, error) {
	tok := c.token
	return tok, nil
}
// authBasicProtoTester is the "auth:basic_proto" scenario. Its server
// requires basic username/password auth, and its DoAction echoes the
// authenticated identity.
type authBasicProtoTester struct {
	flight.BaseFlightServer
}

// RunClient checks that an unauthenticated DoAction fails with
// codes.Unauthenticated. It then authenticates with the scenario credentials
// and expects DoAction to return exactly one result containing the username.
// The client connection is closed on every return path.
func (s *authBasicProtoTester) RunClient(addr string, opts ...grpc.DialOption) error {
	auth := &clientAuthBasic{}
	client, err := flight.NewClientWithMiddleware(addr, auth, nil, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	ctx := context.Background()
	stream, err := client.DoAction(ctx, &flight.Action{})
	if err != nil {
		return err
	}

	// This call is expected to fail as unauthenticated.
	_, err = stream.Recv()
	st, ok := status.FromError(err)
	if !ok {
		return err
	}
	if st.Code() != codes.Unauthenticated {
		return fmt.Errorf("expected Unauthenticated, got %s", st.Code())
	}

	auth.auth = &flight.BasicAuth{Username: authUsername, Password: authPassword}
	if err := client.Authenticate(ctx); err != nil {
		return err
	}
	return CheckActionResults(ctx, client, &flight.Action{}, []string{authUsername})
}
// MakeServer installs a basic-auth validator for the scenario credentials and
// registers s on a middleware-free server initialized for port. The bound
// port returned by initServer is not kept.
func (s *authBasicProtoTester) MakeServer(port int) flight.Server {
	validator := &authBasicValidator{
		auth: flight.BasicAuth{Username: authUsername, Password: authPassword},
	}
	s.SetAuthHandler(validator)

	server := flight.NewServerWithMiddleware(nil)
	server.RegisterFlightService(s)
	initServer(port, server)
	return server
}
// DoAction replies with a single result whose body is the caller identity
// retrieved via flight.AuthFromContext. With authBasicValidator that identity
// is the username string returned by IsValid.
//
// A missing or non-string identity is reported as codes.Unauthenticated
// instead of panicking on the type assertion. Send errors are propagated.
func (authBasicProtoTester) DoAction(_ *flight.Action, stream flight.FlightService_DoActionServer) error {
	identity, ok := flight.AuthFromContext(stream.Context()).(string)
	if !ok {
		return status.Error(codes.Unauthenticated, "no authenticated identity on request")
	}
	return stream.Send(&flight.Result{Body: []byte(identity)})
}
// middlewareScenarioTester is the "middleware" scenario. It checks that
// headers set by server middleware reach client middleware on both failing
// and successful calls.
type middlewareScenarioTester struct {
	flight.BaseFlightServer
}

// RunClient issues one GetFlightInfo that must fail and one that must
// succeed. After each call it verifies that the client middleware saw the
// "x-middleware: expected value" header. The client connection is closed on
// every return path.
func (m *middlewareScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error {
	tm := &testClientMiddleware{}
	client, err := flight.NewClientWithMiddleware(addr, nil, []flight.ClientMiddleware{
		flight.CreateClientMiddleware(tm)}, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	ctx := context.Background()
	// This call is expected to fail.
	_, err = client.GetFlightInfo(ctx, &flight.FlightDescriptor{Type: flight.DescriptorCMD})
	if err == nil {
		return xerrors.New("expected call to fail")
	}
	if tm.received != "expected value" {
		return fmt.Errorf("expected to receive header 'x-middleware: expected value', but instead got %s", tm.received)
	}
	fmt.Fprintln(os.Stderr, "Headers received successfully on failing call.")

	tm.received = ""
	_, err = client.GetFlightInfo(ctx, &flight.FlightDescriptor{Type: flight.DescriptorCMD, Cmd: []byte("success")})
	if err != nil {
		return err
	}
	if tm.received != "expected value" {
		return fmt.Errorf("expected to receive header 'x-middleware: expected value', but instead got %s", tm.received)
	}
	fmt.Fprintln(os.Stderr, "Headers received successfully on passing call.")
	return nil
}
// MakeServer registers m on a server that runs testServerMiddleware and
// initializes it for port. The bound port is not kept.
func (m *middlewareScenarioTester) MakeServer(port int) flight.Server {
	middleware := []flight.ServerMiddleware{
		flight.CreateServerMiddleware(testServerMiddleware{}),
	}
	server := flight.NewServerWithMiddleware(middleware)
	server.RegisterFlightService(m)
	initServer(port, server)
	return server
}
// GetFlightInfo succeeds only for the CMD descriptor "success". It then
// returns an empty-schema FlightInfo with a single placeholder endpoint. Any
// other descriptor fails with codes.Unknown.
func (m *middlewareScenarioTester) GetFlightInfo(ctx context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	isSuccess := desc.Type == flight.DescriptorCMD && string(desc.Cmd) == "success"
	if !isSuccess {
		return nil, status.Error(codes.Unknown, "unknown")
	}

	emptySchema := arrow.NewSchema([]arrow.Field{}, nil)
	placeholder := &flight.FlightEndpoint{
		Ticket:   &flight.Ticket{Ticket: []byte("foo")},
		Location: []*flight.Location{{Uri: "grpc+tcp://localhost:10010"}},
	}
	return &flight.FlightInfo{
		Schema:           flight.SerializeSchema(emptySchema, memory.DefaultAllocator),
		FlightDescriptor: desc,
		Endpoint:         []*flight.FlightEndpoint{placeholder},
		TotalRecords:     -1,
		TotalBytes:       -1,
	}, nil
}
// orderedScenarioTester is the "ordered" scenario. Its server flags
// FlightInfo.ordered only for the CMD "ordered". Its client checks that flag
// and that reading the endpoints in order reproduces the expected table.
type orderedScenarioTester struct {
	flight.BaseFlightServer
}

// RunClient requests the "ordered" FlightInfo from addr and requires it to be
// flagged as ordered with location-less endpoints. It reads every endpoint,
// in the order returned, via DoGet on the same connection, then compares the
// concatenated records with the three expected batches.
func (o *orderedScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error {
	client, err := flight.NewClientWithMiddleware(addr, nil, nil, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	ctx := context.Background()
	info, err := client.GetFlightInfo(ctx, &flight.FlightDescriptor{Type: flight.DescriptorCMD, Cmd: []byte("ordered")})
	if err != nil {
		return err
	}
	if !info.GetOrdered() {
		return fmt.Errorf("expected to server return FlightInfo.ordered = true")
	}

	var recs []arrow.Record
	for _, ep := range info.Endpoint {
		if len(ep.Location) != 0 {
			return fmt.Errorf("expected to receive empty locations to use the original service: %s",
				ep.Location)
		}
		stream, err := client.DoGet(ctx, ep.Ticket)
		if err != nil {
			return err
		}
		rdr, err := flight.NewRecordReader(stream)
		if err != nil {
			return err
		}
		defer rdr.Release()
		for rdr.Next() {
			record := rdr.Record()
			record.Retain() // keep past the next call to Next; released on return
			defer record.Release()
			recs = append(recs, record)
		}
		if rdr.Err() != nil {
			return rdr.Err()
		}
	}

	mem := memory.DefaultAllocator
	schema := arrow.NewSchema(
		[]arrow.Field{
			{Name: "number", Type: arrow.PrimitiveTypes.Int32},
		},
		nil,
	)
	// A parse failure here used to be ignored, leaving a nil table that
	// crashed on Release.
	expectedTable, err := array.TableFromJSON(mem, schema, []string{
		`[
			{"number": 1},
			{"number": 2},
			{"number": 3}
		]`,
		`[
			{"number": 10},
			{"number": 20},
			{"number": 30}
		]`,
		`[
			{"number": 100},
			{"number": 200},
			{"number": 300}
		]`,
	})
	if err != nil {
		return fmt.Errorf("could not build expected table: %w", err)
	}
	defer expectedTable.Release()

	table := array.NewTableFromRecords(schema, recs)
	defer table.Release()
	if !array.TableEqual(table, expectedTable) {
		return fmt.Errorf("read data isn't expected\n"+
			"Expected:\n"+
			"%s\n"+
			"num-rows: %d\n"+
			"num-cols: %d\n"+
			"Actual:\n"+
			"%s\n"+
			"num-rows: %d\n"+
			"num-cols: %d",
			expectedTable.Schema(),
			expectedTable.NumRows(),
			expectedTable.NumCols(),
			table.Schema(),
			table.NumRows(),
			table.NumCols())
	}
	return nil
}
// MakeServer registers o on a middleware-free server initialized for port.
// The bound port is not kept.
func (o *orderedScenarioTester) MakeServer(port int) flight.Server {
	server := flight.NewServerWithMiddleware(nil)
	server.RegisterFlightService(o)
	initServer(port, server)
	return server
}
// GetFlightInfo returns three location-less endpoints with tickets "1", "2"
// and "3" over a single int32 "number" column. Ordered is true only when desc
// is the CMD "ordered".
func (o *orderedScenarioTester) GetFlightInfo(ctx context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	schema := arrow.NewSchema(
		[]arrow.Field{{Name: "number", Type: arrow.PrimitiveTypes.Int32}},
		nil,
	)

	tickets := []string{"1", "2", "3"}
	endpoints := make([]*flight.FlightEndpoint, 0, len(tickets))
	for _, t := range tickets {
		endpoints = append(endpoints, &flight.FlightEndpoint{
			Ticket:   &flight.Ticket{Ticket: []byte(t)},
			Location: []*flight.Location{},
		})
	}

	return &flight.FlightInfo{
		Schema:           flight.SerializeSchema(schema, memory.DefaultAllocator),
		FlightDescriptor: desc,
		Endpoint:         endpoints,
		TotalRecords:     -1,
		TotalBytes:       -1,
		Ordered:          desc.Type == flight.DescriptorCMD && string(desc.Cmd) == "ordered",
	}, nil
}
// DoGet streams one three-row int32 batch selected by the ticket:
// "1" gives 1,2,3; "2" gives 10,20,30; "3" gives 100,200,300.
//
// Unknown tickets are rejected with codes.NotFound instead of silently
// producing an empty batch. The writer is closed, and write/close errors are
// returned to the caller.
func (o *orderedScenarioTester) DoGet(tkt *flight.Ticket, fs flight.FlightService_DoGetServer) error {
	var values []int32
	switch string(tkt.GetTicket()) {
	case "1":
		values = []int32{1, 2, 3}
	case "2":
		values = []int32{10, 20, 30}
	case "3":
		values = []int32{100, 200, 300}
	default:
		return status.Errorf(codes.NotFound, "unknown ticket: %s", string(tkt.GetTicket()))
	}

	schema := arrow.NewSchema(
		[]arrow.Field{
			{Name: "number", Type: arrow.PrimitiveTypes.Int32},
		},
		nil,
	)
	b := array.NewRecordBuilder(memory.DefaultAllocator, schema)
	defer b.Release()
	b.Field(0).(*array.Int32Builder).AppendValues(values, nil)

	rec := b.NewRecord()
	defer rec.Release()

	w := flight.NewRecordWriter(fs, ipc.WithSchema(schema))
	if err := w.Write(rec); err != nil {
		w.Close() // best effort; the write error is the one to report
		return err
	}
	return w.Close()
}
// expirationTimeEndpointStatus tracks server-side state for one endpoint
// handed out by expirationTimeScenarioTester.GetFlightInfo.
type expirationTimeEndpointStatus struct {
	// expirationTime, when nil, limits DoGet to a single read. Otherwise
	// DoGet is refused once this time has passed.
	expirationTime *time.Time
	// numGets counts DoGet calls that were allowed to proceed.
	numGets uint32
	// cancelled is set by the CancelFlightInfo action.
	cancelled bool
}

// expirationTimeScenarioTester is the shared server for the
// expiration_time:* scenarios.
type expirationTimeScenarioTester struct {
	flight.BaseFlightServer

	// statuses is keyed by the index embedded at the start of each ticket.
	statuses map[int]expirationTimeEndpointStatus
}
// MakeServer registers tester on a middleware-free server initialized for
// port. The bound port is not kept.
func (tester *expirationTimeScenarioTester) MakeServer(port int) flight.Server {
	server := flight.NewServerWithMiddleware(nil)
	server.RegisterFlightService(tester)
	initServer(port, server)
	return server
}
// AppendGetFlightInfo registers a fresh, uncancelled endpoint status under
// the next free index and appends a matching location-less endpoint to
// endpoints.
//
// The ticket is "<index>: <ticket>", so later requests can recover the
// status. A nil expirationTime leaves the endpoint's ExpirationTime unset.
func (tester *expirationTimeScenarioTester) AppendGetFlightInfo(endpoints []*flight.FlightEndpoint, ticket string, expirationTime *time.Time) []*flight.FlightEndpoint {
	idx := len(tester.statuses)
	tester.statuses[idx] = expirationTimeEndpointStatus{expirationTime: expirationTime}

	ep := &flight.FlightEndpoint{
		Ticket:   &flight.Ticket{Ticket: []byte(strconv.Itoa(idx) + ": " + ticket)},
		Location: []*flight.Location{},
	}
	if expirationTime != nil {
		ep.ExpirationTime = timestamppb.New(*expirationTime)
	}
	return append(endpoints, ep)
}
// ExtractIndexFromTicket parses the "<index>:" prefix of a ticket produced by
// AppendGetFlightInfo. It returns an error if the prefix is not an integer or
// does not name a registered endpoint.
//
// Negative indexes are rejected too. Previously "-1: x" passed the bounds
// check and silently read a zero-value status from the map.
func (tester *expirationTimeScenarioTester) ExtractIndexFromTicket(ticket string) (int, error) {
	indexString, _, _ := strings.Cut(ticket, ":")
	index, err := strconv.Atoi(indexString)
	if err != nil {
		return 0, fmt.Errorf("invalid flight: no index: %s: %w", ticket, err)
	}
	if index < 0 || index >= len(tester.statuses) {
		return 0, fmt.Errorf("invalid flight: out of index: %s", ticket)
	}
	return index, nil
}
// GetFlightInfo discards all previously tracked endpoint state and returns
// three endpoints over a uint32 "number" column:
//   - one with no expiration time;
//   - one expiring 5 seconds from now;
//   - one expiring 6 seconds from now.
func (tester *expirationTimeScenarioTester) GetFlightInfo(ctx context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	tester.statuses = make(map[int]expirationTimeEndpointStatus)

	schema := arrow.NewSchema(
		[]arrow.Field{{Name: "number", Type: arrow.PrimitiveTypes.Uint32}},
		nil,
	)

	var endpoints []*flight.FlightEndpoint
	endpoints = tester.AppendGetFlightInfo(endpoints, "No expiration time", nil)
	for _, spec := range []struct {
		label string
		after time.Duration
	}{
		{"5 seconds", 5 * time.Second},
		{"6 seconds", 6 * time.Second},
	} {
		expiry := time.Now().Add(spec.after)
		endpoints = tester.AppendGetFlightInfo(endpoints, spec.label, &expiry)
	}

	return &flight.FlightInfo{
		Schema:           flight.SerializeSchema(schema, memory.DefaultAllocator),
		FlightDescriptor: desc,
		Endpoint:         endpoints,
		TotalRecords:     -1,
		TotalBytes:       -1,
	}, nil
}
// DoGet serves the endpoint identified by the ticket's index prefix. The
// request is refused if the endpoint was cancelled, if it has no expiration
// time and was already read, or if its expiration time has passed.
// Otherwise the read is counted and a single-row uint32 batch containing the
// index is streamed.
//
// The writer is closed, and write/close errors are returned to the caller.
func (tester *expirationTimeScenarioTester) DoGet(tkt *flight.Ticket, fs flight.FlightService_DoGetServer) error {
	ticket := string(tkt.GetTicket())
	index, err := tester.ExtractIndexFromTicket(ticket)
	if err != nil {
		return err
	}

	st := tester.statuses[index]
	if st.cancelled {
		return status.Errorf(codes.InvalidArgument,
			"Invalid flight: cancelled: %s", ticket)
	}
	if st.expirationTime == nil {
		if st.numGets > 0 {
			return status.Errorf(codes.InvalidArgument,
				"Invalid flight: "+
					"can't read multiple times: %s", ticket)
		}
	} else if time.Until(*st.expirationTime) < 0 {
		return status.Errorf(codes.InvalidArgument,
			"Invalid flight: expired: %s", ticket)
	}
	st.numGets++
	tester.statuses[index] = st

	schema := arrow.NewSchema(
		[]arrow.Field{
			{Name: "number", Type: arrow.PrimitiveTypes.Uint32},
		},
		nil,
	)
	b := array.NewRecordBuilder(memory.DefaultAllocator, schema)
	defer b.Release()
	b.Field(0).(*array.Uint32Builder).AppendValues([]uint32{uint32(index)}, nil)

	rec := b.NewRecord()
	defer rec.Release()

	w := flight.NewRecordWriter(fs, ipc.WithSchema(schema))
	if err := w.Write(rec); err != nil {
		w.Close() // best effort; the write error is the one to report
		return err
	}
	return w.Close()
}
// ListActions advertises the two actions DoAction understands:
// CancelFlightInfo and RenewFlightEndpoint, in that order. It stops at the
// first send error.
func (tester *expirationTimeScenarioTester) ListActions(_ *flight.Empty, stream flight.FlightService_ListActionsServer) error {
	supported := [...]string{
		flight.CancelFlightInfoActionType,
		flight.RenewFlightEndpointActionType,
	}
	for _, name := range supported {
		if err := stream.Send(&flight.ActionType{Type: name}); err != nil {
			return err
		}
	}
	return nil
}
// packActionResult marshals msg with proto.Marshal and returns it as the body
// of a new flight.Result.
func packActionResult(msg proto.Message) (*flight.Result, error) {
	body, err := proto.Marshal(msg)
	if err != nil {
		return nil, fmt.Errorf("%w: unable to marshal final response", err)
	}
	return &flight.Result{Body: body}, nil
}
// DoAction implements two actions.
//
// CancelFlightInfo marks every endpoint named in the request as cancelled.
// The result is CANCELLED if all endpoints were cancelled; it becomes
// NOT_CANCELLABLE if any ticket is unknown or any endpoint was already
// cancelled; it stays UNSPECIFIED when no endpoints are given.
//
// RenewFlightEndpoint extends the named endpoint's expiration to 10 seconds
// from now. It returns the endpoint with an amended ticket and the new
// expiration time.
//
// Requests missing their info/endpoint/ticket messages are handled via the
// nil-safe protobuf getters, or rejected with InvalidArgument, instead of
// panicking on a nil dereference.
func (tester *expirationTimeScenarioTester) DoAction(cmd *flight.Action, stream flight.FlightService_DoActionServer) error {
	switch cmd.Type {
	case flight.CancelFlightInfoActionType:
		var request flight.CancelFlightInfoRequest
		if err := proto.Unmarshal(cmd.Body, &request); err != nil {
			return status.Errorf(codes.InvalidArgument, "unable to parse command: %s", err.Error())
		}

		cancelStatus := flight.CancelStatusUnspecified
		for _, ep := range request.GetInfo().GetEndpoint() {
			ticket := string(ep.GetTicket().GetTicket())
			index, err := tester.ExtractIndexFromTicket(ticket)
			if err != nil {
				cancelStatus = flight.CancelStatusNotCancellable
				continue
			}
			st := tester.statuses[index]
			if st.cancelled {
				cancelStatus = flight.CancelStatusNotCancellable
				continue
			}
			st.cancelled = true
			tester.statuses[index] = st
			if cancelStatus == flight.CancelStatusUnspecified {
				cancelStatus = flight.CancelStatusCancelled
			}
		}

		out, err := packActionResult(&flight.CancelFlightInfoResult{Status: cancelStatus})
		if err != nil {
			return err
		}
		return stream.Send(out)

	case flight.RenewFlightEndpointActionType:
		var request flight.RenewFlightEndpointRequest
		if err := proto.Unmarshal(cmd.Body, &request); err != nil {
			return status.Errorf(codes.InvalidArgument, "unable to parse command: %s", err.Error())
		}

		endpoint := request.GetEndpoint()
		if endpoint.GetTicket() == nil {
			return status.Error(codes.InvalidArgument, "renew request must include an endpoint with a ticket")
		}
		ticket := string(endpoint.Ticket.Ticket)
		index, err := tester.ExtractIndexFromTicket(ticket)
		if err != nil {
			return err
		}

		endpoint.Ticket.Ticket = []byte(ticket + ": renewed (+ 10 seconds)")
		renewedExpirationTime := time.Now().Add(time.Second * 10)
		endpoint.ExpirationTime = timestamppb.New(renewedExpirationTime)

		st := tester.statuses[index]
		st.expirationTime = &renewedExpirationTime
		tester.statuses[index] = st

		out, err := packActionResult(endpoint)
		if err != nil {
			return err
		}
		return stream.Send(out)

	default:
		return status.Errorf(codes.InvalidArgument, "unsupported action: %s", cmd.Type)
	}
}
// expirationTimeDoGetScenarioTester is the "expiration_time:do_get" scenario.
// It uses the shared expiration-time server.
type expirationTimeDoGetScenarioTester struct {
	expirationTimeScenarioTester
}

// RunClient fetches the expiration-time FlightInfo and checks the
// expiration times: endpoints[0] must have none and every later endpoint
// must have one. It then reads each endpoint once via DoGet and compares the
// concatenated records with the expected indexes 0, 1 and 2.
//
// The first/later distinction is based on the endpoint's position, not on
// how many records have been read so far.
func (tester *expirationTimeDoGetScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error {
	client, err := flight.NewClientWithMiddleware(addr, nil, nil, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	ctx := context.Background()
	info, err := client.GetFlightInfo(ctx, &flight.FlightDescriptor{Type: flight.DescriptorCMD, Cmd: []byte("expiration_time")})
	if err != nil {
		return err
	}

	var recs []arrow.Record
	for i, ep := range info.Endpoint {
		if i == 0 {
			if ep.ExpirationTime != nil {
				return fmt.Errorf("endpoints[0] must not have " +
					"expiration time")
			}
		} else if ep.ExpirationTime == nil {
			return fmt.Errorf("endpoints[%d] must have "+
				"expiration time", i)
		}
		if len(ep.Location) != 0 {
			return fmt.Errorf("expected to receive empty locations to use the original service: %s",
				ep.Location)
		}

		stream, err := client.DoGet(ctx, ep.Ticket)
		if err != nil {
			return err
		}
		rdr, err := flight.NewRecordReader(stream)
		if err != nil {
			return err
		}
		defer rdr.Release()
		for rdr.Next() {
			record := rdr.Record()
			record.Retain() // keep past the next call to Next; released on return
			defer record.Release()
			recs = append(recs, record)
		}
		if rdr.Err() != nil {
			return rdr.Err()
		}
	}

	mem := memory.DefaultAllocator
	schema := arrow.NewSchema(
		[]arrow.Field{
			{Name: "number", Type: arrow.PrimitiveTypes.Uint32},
		},
		nil,
	)
	expectedTable, err := array.TableFromJSON(mem, schema, []string{
		`[{"number": 0}]`,
		`[{"number": 1}]`,
		`[{"number": 2}]`,
	})
	if err != nil {
		return fmt.Errorf("could not build expected table: %w", err)
	}
	defer expectedTable.Release()

	table := array.NewTableFromRecords(schema, recs)
	defer table.Release()
	if !array.TableEqual(table, expectedTable) {
		return fmt.Errorf("read data isn't expected\n"+
			"Expected:\n"+
			"%s\n"+
			"numRows: %d\n"+
			"numCols: %d\n"+
			"Actual:\n"+
			"%s\n"+
			"numRows: %d\n"+
			"numCols: %d",
			expectedTable.Schema(),
			expectedTable.NumRows(),
			expectedTable.NumCols(),
			table.Schema(),
			table.NumRows(),
			table.NumCols())
	}
	return nil
}
// expirationTimeListActionsScenarioTester is the
// "expiration_time:list_actions" scenario. It uses the shared
// expiration-time server.
type expirationTimeListActionsScenarioTester struct {
	expirationTimeScenarioTester
}

// RunClient lists the server's actions and requires the sorted names to be
// exactly CancelFlightInfo and RenewFlightEndpoint.
func (tester *expirationTimeListActionsScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error {
	client, err := flight.NewClientWithMiddleware(addr, nil, nil, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	stream, err := client.ListActions(context.Background(), &flight.Empty{})
	if err != nil {
		return err
	}

	var got []string
	for {
		at, recvErr := stream.Recv()
		if errors.Is(recvErr, io.EOF) {
			break
		}
		if recvErr != nil {
			return recvErr
		}
		got = append(got, at.Type)
	}
	sort.Strings(got)

	want := []string{
		"CancelFlightInfo",
		"RenewFlightEndpoint",
	}
	if reflect.DeepEqual(got, want) {
		return nil
	}
	return fmt.Errorf("action types aren't expected\n"+
		"Expected:\n"+
		"%s\n"+
		"Actual:\n"+
		"%s",
		want,
		got)
}
// expirationTimeCancelFlightInfoScenarioTester is the
// "expiration_time:cancel_flight_info" scenario. It uses the shared
// expiration-time server.
type expirationTimeCancelFlightInfoScenarioTester struct {
	expirationTimeScenarioTester
}

// RunClient cancels a freshly fetched FlightInfo and requires the status to
// be CANCELLED. It then checks that DoGet fails for every one of its
// endpoints.
//
// io.EOF from CancelFlightInfo is tolerated, as before. NOTE(review):
// presumably the client may report end-of-stream alongside a valid result;
// confirm. A nil result is now reported as an error instead of being
// dereferenced.
func (tester *expirationTimeCancelFlightInfoScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error {
	client, err := flight.NewClientWithMiddleware(addr, nil, nil, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	ctx := context.Background()
	info, err := client.GetFlightInfo(ctx, &flight.FlightDescriptor{Type: flight.DescriptorCMD, Cmd: []byte("expiration_time")})
	if err != nil {
		return err
	}

	request := flight.CancelFlightInfoRequest{Info: info}
	result, err := client.CancelFlightInfo(ctx, &request)
	if err != nil && !errors.Is(err, io.EOF) {
		return err
	}
	if result == nil {
		return fmt.Errorf("invalid: CancelFlightInfo returned no result")
	}
	if result.Status != flight.CancelStatusCancelled {
		return fmt.Errorf("invalid: CancelFlightInfo must return CANCEL_STATUS_CANCELLED: %s", result.Status)
	}

	for _, ep := range info.Endpoint {
		stream, err := client.DoGet(ctx, ep.Ticket)
		if err != nil {
			return err
		}
		// The server's rejection surfaces when the reader reads the first
		// message.
		rdr, err := flight.NewRecordReader(stream)
		if err == nil {
			rdr.Release()
			return fmt.Errorf("invalid: DoGet after CancelFlightInfo must be failed")
		}
	}
	return nil
}
// expirationTimeRenewFlightEndpointScenarioTester is the
// "expiration_time:renew_flight_endpoint" scenario. It uses the shared
// expiration-time server.
type expirationTimeRenewFlightEndpointScenarioTester struct {
	expirationTimeScenarioTester
}

// RunClient renews every endpoint that has an expiration time. Each renewed
// endpoint must carry an expiration time strictly later than the original.
// Endpoints without an expiration time are skipped.
func (tester *expirationTimeRenewFlightEndpointScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error {
	client, err := flight.NewClientWithMiddleware(addr, nil, nil, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	ctx := context.Background()
	info, err := client.GetFlightInfo(ctx, &flight.FlightDescriptor{Type: flight.DescriptorCMD, Cmd: []byte("expiration_time")})
	if err != nil {
		return err
	}

	for _, ep := range info.Endpoint {
		if ep.ExpirationTime == nil {
			continue
		}
		// Capture the original deadline before issuing the renewal.
		original := ep.ExpirationTime.AsTime()

		renewed, err := client.RenewFlightEndpoint(ctx, &flight.RenewFlightEndpointRequest{Endpoint: ep})
		if err != nil {
			return err
		}
		if renewed.ExpirationTime == nil {
			return fmt.Errorf("renewed endpoint must have expiration time: %s",
				renewed)
		}
		if !renewed.ExpirationTime.AsTime().After(original) {
			return fmt.Errorf("renewed endpoint must have newer expiration time\n"+
				"Original: %s\nRenewed: %s",
				ep, renewed)
		}
	}
	return nil
}
// locationReuseConnectionScenarioTester is the "location:reuse_connection"
// scenario. Its server advertises the special reuse-connection location URI.
type locationReuseConnectionScenarioTester struct {
	flight.BaseFlightServer
}

// GetFlightInfo returns an empty-schema FlightInfo with one endpoint. The
// endpoint's ticket is "reuse" and its only location is
// flight.LocationReuseConnection. The descriptor is echoed back unchanged.
func (tester *locationReuseConnectionScenarioTester) GetFlightInfo(ctx context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	emptySchema := arrow.NewSchema([]arrow.Field{}, nil)
	reuse := &flight.FlightEndpoint{
		Ticket:   &flight.Ticket{Ticket: []byte("reuse")},
		Location: []*flight.Location{{Uri: flight.LocationReuseConnection}},
	}
	return &flight.FlightInfo{
		Schema:           flight.SerializeSchema(emptySchema, memory.DefaultAllocator),
		FlightDescriptor: desc,
		Endpoint:         []*flight.FlightEndpoint{reuse},
		TotalRecords:     -1,
		TotalBytes:       -1,
	}, nil
}
// MakeServer registers tester on a middleware-free server initialized for
// port. The bound port is not kept.
func (tester *locationReuseConnectionScenarioTester) MakeServer(port int) flight.Server {
	server := flight.NewServerWithMiddleware(nil)
	server.RegisterFlightService(tester)
	initServer(port, server)
	return server
}
// RunClient requires the server's FlightInfo to contain exactly one endpoint.
// That endpoint must have exactly one location, equal to
// flight.LocationReuseConnection.
func (tester *locationReuseConnectionScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error {
	client, err := flight.NewClientWithMiddleware(addr, nil, nil, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	info, err := client.GetFlightInfo(context.Background(), &flight.FlightDescriptor{Type: flight.DescriptorCMD, Cmd: []byte("reuse")})
	if err != nil {
		return err
	}

	if n := len(info.Endpoint); n != 1 {
		return fmt.Errorf("expected 1 endpoint, got %d", n)
	}
	locations := info.Endpoint[0].Location
	switch {
	case len(locations) != 1:
		return fmt.Errorf("expected 1 location, got %d", len(locations))
	case locations[0].Uri != flight.LocationReuseConnection:
		return fmt.Errorf("expected %s, got %s", flight.LocationReuseConnection, locations[0].Uri)
	}
	return nil
}
// pollFlightInfoScenarioTester is the "poll_flight_info" scenario. Its server
// reports a query as in progress until it is polled with the CMD "poll".
type pollFlightInfoScenarioTester struct {
	flight.BaseFlightServer
}

// MakeServer registers tester on a middleware-free server initialized for
// port. The bound port is not kept.
func (tester *pollFlightInfoScenarioTester) MakeServer(port int) flight.Server {
	server := flight.NewServerWithMiddleware(nil)
	server.RegisterFlightService(tester)
	initServer(port, server)
	return server
}
// PollFlightInfo simulates a long-running query. Every response carries the
// same single-endpoint FlightInfo.
//
// When polled with the CMD "poll", the query is finished: progress is 1.0
// and there is no follow-up descriptor or expiration time.
//
// Any other descriptor reports progress 0.1, returns the "poll" descriptor to
// retry with, and expires 10 seconds from now.
func (tester *pollFlightInfoScenarioTester) PollFlightInfo(ctx context.Context, desc *flight.FlightDescriptor) (*flight.PollInfo, error) {
	schema := arrow.NewSchema(
		[]arrow.Field{{Name: "number", Type: arrow.PrimitiveTypes.Uint32}},
		nil,
	)
	info := &flight.FlightInfo{
		Schema:           flight.SerializeSchema(schema, memory.DefaultAllocator),
		FlightDescriptor: desc,
		Endpoint: []*flight.FlightEndpoint{{
			Ticket:   &flight.Ticket{Ticket: []byte("long-running query")},
			Location: []*flight.Location{},
		}},
		TotalRecords: -1,
		TotalBytes:   -1,
	}

	pollDesc := flight.FlightDescriptor{
		Type: flight.DescriptorCMD,
		Cmd:  []byte("poll"),
	}
	finished := desc.Type == pollDesc.Type && string(desc.Cmd) == string(pollDesc.Cmd)
	if finished {
		done := 1.0
		return &flight.PollInfo{Info: info, Progress: &done}, nil
	}

	partial := 0.1
	return &flight.PollInfo{
		Info:             info,
		FlightDescriptor: &pollDesc,
		Progress:         &partial,
		ExpirationTime:   timestamppb.New(time.Now().Add(10 * time.Second)),
	}, nil
}
// RunClient drives the PollFlightInfo scenario against addr.
//
// The first poll must describe an in-progress query: it carries a retry
// descriptor, a progress value in [0, 1] and an expiration time. Polling
// again with the returned descriptor must describe a finished query: no
// descriptor, progress 1.0 and no expiration time.
func (tester *pollFlightInfoScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error {
	client, err := flight.NewClientWithMiddleware(addr, nil, nil, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	ctx := context.Background()
	desc := flight.FlightDescriptor{
		Type: flight.DescriptorCMD,
		Cmd:  []byte("heavy query"),
	}
	info, err := client.PollFlightInfo(ctx, &desc)
	if err != nil {
		return err
	}
	switch {
	case info.FlightDescriptor == nil:
		return fmt.Errorf("descriptor is missing: %s", info.String())
	case info.Progress == nil:
		return fmt.Errorf("progress is missing: %s", info.String())
	case !(0.0 <= *info.Progress && *info.Progress <= 1.0):
		return fmt.Errorf("invalid progress: %s", info.String())
	case info.ExpirationTime == nil:
		return fmt.Errorf("expiration time is missing: %s", info.String())
	}

	// Poll again with the descriptor the server told us to retry with.
	info, err = client.PollFlightInfo(ctx, info.FlightDescriptor)
	if err != nil {
		return err
	}
	switch {
	case info.FlightDescriptor != nil:
		return fmt.Errorf("retried but not finished yet: %s", info.String())
	case info.Progress == nil:
		return fmt.Errorf("progress is missing in finished query: %s",
			info.String())
	case math.Abs(*info.Progress-1.0) > 1e-5:
		return fmt.Errorf("progress for finished query isn't 1.0: %s",
			info.String())
	case info.ExpirationTime != nil:
		return fmt.Errorf("expiration time must not be set for finished query: %s",
			info.String())
	}
	return nil
}
// appMetadataFlightInfoEndpointScenarioTester checks that app_metadata is
// carried on both FlightInfo and FlightEndpoint. Its server echoes the
// descriptor's command bytes into both fields.
type appMetadataFlightInfoEndpointScenarioTester struct {
	flight.BaseFlightServer
}
// MakeServer creates a Flight server with no middleware and registers this
// tester as its Flight service. It then passes the server and port to
// initServer and returns it.
func (tester *appMetadataFlightInfoEndpointScenarioTester) MakeServer(port int) flight.Server {
	server := flight.NewServerWithMiddleware(nil)
	server.RegisterFlightService(tester)
	initServer(port, server)
	return server
}
// GetFlightInfo accepts only CMD descriptors. It returns a FlightInfo whose
// own AppMetadata, and that of its single endpoint, are the descriptor's
// command bytes.
func (tester *appMetadataFlightInfoEndpointScenarioTester) GetFlightInfo(ctx context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	if desc.Type != flight.DescriptorCMD {
		return nil, fmt.Errorf("%w: should have received CMD descriptor", arrow.ErrInvalid)
	}

	fields := []arrow.Field{{Name: "number", Type: arrow.PrimitiveTypes.Uint32}}
	serialized := flight.SerializeSchema(arrow.NewSchema(fields, nil), memory.DefaultAllocator)

	return &flight.FlightInfo{
		Schema:           serialized,
		FlightDescriptor: desc,
		Endpoint:         []*flight.FlightEndpoint{{AppMetadata: desc.Cmd}},
		TotalRecords:     -1,
		TotalBytes:       -1,
		AppMetadata:      desc.Cmd,
	}, nil
}
// RunClient requests FlightInfo for the CMD "foobar". It then checks that
// the returned FlightInfo and its single endpoint both carry "foobar" as
// app_metadata.
func (tester *appMetadataFlightInfoEndpointScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error {
	client, err := flight.NewClientWithMiddleware(addr, nil, nil, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	want := []byte("foobar")
	desc := flight.FlightDescriptor{Type: flight.DescriptorCMD, Cmd: want}
	info, err := client.GetFlightInfo(context.Background(), &desc)
	if err != nil {
		return err
	}

	if !bytes.Equal(want, info.AppMetadata) {
		return fmt.Errorf("invalid flight info app_metadata: %s, expected: %s", info.AppMetadata, want)
	}
	if n := len(info.Endpoint); n != 1 {
		return fmt.Errorf("expected exactly 1 flight endpoint, got: %d", n)
	}
	if got := info.Endpoint[0].AppMetadata; !bytes.Equal(want, got) {
		return fmt.Errorf("invalid flight endpoint app_metadata: %s, expected: %s", got, want)
	}
	return nil
}
// Row counts returned by the FlightSQL scenario server for update-style
// commands. The client compares the count it receives against these values,
// so each server path is distinguishable by its count.
const (
	// ad-hoc update statements (and Substrait plans), outside / inside a transaction
	updateStatementExpectedRows                int64 = 10000
	updateStatementWithTransactionExpectedRows int64 = 15000

	// prepared update statements, outside / inside a transaction
	updatePreparedStatementExpectedRows                int64 = 20000
	updatePreparedStatementWithTransactionExpectedRows int64 = 25000

	// NOTE(review): not referenced in this chunk; presumably the row count
	// for the bulk-ingest scenario elsewhere in the file.
	ingestStatementExpectedRows int64 = 3
)
// flightSqlScenarioTester is both the FlightSQL scenario server and its
// client. It embeds flightsql.BaseServer, so any command it does not
// override falls back to the base implementation.
type flightSqlScenarioTester struct {
	flightsql.BaseServer
}
// flightInfoForCommand builds a FlightInfo advertising schema, with a single
// endpoint whose ticket is the raw command bytes from desc. Record and byte
// totals are set to -1 (unknown).
func (m *flightSqlScenarioTester) flightInfoForCommand(desc *flight.FlightDescriptor, schema *arrow.Schema) *flight.FlightInfo {
	ticket := &flight.Ticket{Ticket: desc.Cmd}
	info := &flight.FlightInfo{
		FlightDescriptor: desc,
		Schema:           flight.SerializeSchema(schema, memory.DefaultAllocator),
		TotalRecords:     -1,
		TotalBytes:       -1,
	}
	info.Endpoint = []*flight.FlightEndpoint{{Ticket: ticket}}
	return info
}
// MakeServer registers the SqlInfo values the FlightSQL scenario expects.
// It then wraps this tester in a FlightSQL Flight service and passes the
// server and port to initServer.
//
// RegisterSqlInfo rejects unsupported value types by returning an error.
// MakeServer cannot return one, so such a failure is a bug in the scenario
// definition and panics instead of being silently dropped.
func (m *flightSqlScenarioTester) MakeServer(port int) flight.Server {
	srv := flight.NewServerWithMiddleware(nil)

	sqlInfos := []struct {
		id  flightsql.SqlInfo
		val interface{}
	}{
		{flightsql.SqlInfoFlightSqlServerSql, false},
		{flightsql.SqlInfoFlightSqlServerSubstrait, true},
		{flightsql.SqlInfoFlightSqlServerSubstraitMinVersion, "min_version"},
		{flightsql.SqlInfoFlightSqlServerSubstraitMaxVersion, "max_version"},
		{flightsql.SqlInfoFlightSqlServerTransaction, int32(flightsql.SqlTransactionSavepoint)},
		{flightsql.SqlInfoFlightSqlServerCancel, true},
		{flightsql.SqlInfoFlightSqlServerStatementTimeout, int32(42)},
		{flightsql.SqlInfoFlightSqlServerTransactionTimeout, int32(7)},
	}
	for _, info := range sqlInfos {
		if err := m.RegisterSqlInfo(info.id, info.val); err != nil {
			panic(fmt.Errorf("registering sql info %v: %w", info.id, err))
		}
	}

	srv.RegisterFlightService(flightsql.NewFlightServer(m))
	initServer(port, srv)
	return srv
}
// assertEq reports an error unless actual equals expected (reflect.DeepEqual).
//
// If actual is a pointer it is dereferenced first, so a *string from a
// protobuf getter can be compared with a plain string. A nil actual, or a
// nil pointer, matches only an untyped nil expected; it never panics.
// Mismatch messages use %v so ints, bools and enums print readably, and
// []byte values are shown as text.
func assertEq(expected, actual interface{}) error {
	// show renders []byte as a string so byte payloads stay legible.
	show := func(x interface{}) interface{} {
		if b, ok := x.([]byte); ok {
			return string(b)
		}
		return x
	}

	v := reflect.Indirect(reflect.ValueOf(actual))
	if !v.IsValid() {
		// actual was nil or a nil pointer: there is nothing to dereference.
		if expected == nil {
			return nil
		}
		return fmt.Errorf("expected: '%v', got: <nil>", show(expected))
	}

	got := v.Interface()
	if !reflect.DeepEqual(expected, got) {
		return fmt.Errorf("expected: '%v', got: '%v'", show(expected), show(got))
	}
	return nil
}
// RunClient connects a FlightSQL client to addr. It runs the metadata,
// statement and prepared-statement validations in that order and stops at
// the first failure.
func (m *flightSqlScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error {
	client, err := flightsql.NewClient(addr, nil, nil, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	checks := []func(*flightsql.Client) error{
		m.ValidateMetadataRetrieval,
		m.ValidateStatementExecution,
		m.ValidatePreparedStatementExecution,
	}
	for _, check := range checks {
		if err := check(client); err != nil {
			return err
		}
	}
	return nil
}
// validate performs DoGet on the first endpoint of result and checks that
// the stream's schema equals expected. It then drains the stream so that
// errors raised mid-stream are also reported.
//
// The reader is released on every return path. An info with no endpoints is
// reported as arrow.ErrInvalid instead of panicking on the index.
func (m *flightSqlScenarioTester) validate(expected *arrow.Schema, result *flight.FlightInfo, client *flightsql.Client) error {
	if len(result.Endpoint) == 0 {
		return fmt.Errorf("%w: flight info has no endpoints", arrow.ErrInvalid)
	}

	rdr, err := client.DoGet(context.Background(), result.Endpoint[0].Ticket)
	if err != nil {
		return err
	}
	defer rdr.Release()

	if !expected.Equal(rdr.Schema()) {
		return fmt.Errorf("expected: %s, got: %s", expected, rdr.Schema())
	}

	for {
		if _, err := rdr.Read(); err != nil {
			if errors.Is(err, io.EOF) {
				return nil
			}
			return err
		}
	}
}
// validateSchema deserializes the schema bytes carried by result and returns
// an error unless the decoded schema equals expected.
func (m *flightSqlScenarioTester) validateSchema(expected *arrow.Schema, result *flight.SchemaResult) error {
	got, err := flight.DeserializeSchema(result.GetSchema(), memory.DefaultAllocator)
	switch {
	case err != nil:
		return err
	case !expected.Equal(got):
		return fmt.Errorf("expected: %s, got: %s", expected, got)
	default:
		return nil
	}
}
func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Client) error {
var (
catalog = "catalog"
dbSchemaFilterPattern = "db_schema_filter_pattern"
tableFilterPattern = "table_filter_pattern"
table = "table"
dbSchema = "db_schema"
tableTypes = []string{"table", "view"}
ref = flightsql.TableRef{Catalog: &catalog, DBSchema: &dbSchema, Table: table}
pkRef = flightsql.TableRef{Catalog: proto.String("pk_catalog"), DBSchema: proto.String("pk_db_schema"), Table: "pk_table"}
fkRef = flightsql.TableRef{Catalog: proto.String("fk_catalog"), DBSchema: proto.String("fk_db_schema"), Table: "fk_table"}
ctx = context.Background()
)
info, err := client.GetCatalogs(ctx)
if err != nil {
return err
}
if err := m.validate(schema_ref.Catalogs, info, client); err != nil {
return err
}
schema, err := client.GetCatalogsSchema(ctx)
if err != nil {
return err
}
if err := m.validateSchema(schema_ref.Catalogs, schema); err != nil {
return err
}
info, err = client.GetDBSchemas(ctx, &flightsql.GetDBSchemasOpts{Catalog: &catalog, DbSchemaFilterPattern: &dbSchemaFilterPattern})
if err != nil {
return err
}
if err = m.validate(schema_ref.DBSchemas, info, client); err != nil {
return err
}
schema, err = client.GetDBSchemasSchema(ctx)
if err != nil {
return err
}
if err = m.validateSchema(schema_ref.DBSchemas, schema); err != nil {
return err
}
info, err = client.GetTables(ctx, &flightsql.GetTablesOpts{Catalog: &catalog, DbSchemaFilterPattern: &dbSchemaFilterPattern, TableNameFilterPattern: &tableFilterPattern, IncludeSchema: true, TableTypes: tableTypes})
if err != nil {
return err
}
if err = m.validate(schema_ref.TablesWithIncludedSchema, info, client); err != nil {
return err
}
schema, err = client.GetTablesSchema(ctx, &flightsql.GetTablesOpts{IncludeSchema: true})
if err != nil {
return err
}
if err = m.validateSchema(schema_ref.TablesWithIncludedSchema, schema); err != nil {
return err
}
schema, err = client.GetTablesSchema(ctx, &flightsql.GetTablesOpts{IncludeSchema: false})
if err != nil {
return err
}
if err = m.validateSchema(schema_ref.Tables, schema); err != nil {
return err
}
info, err = client.GetTableTypes(ctx)
if err != nil {
return err
}
if err = m.validate(schema_ref.TableTypes, info, client); err != nil {
return err
}
schema, err = client.GetTableTypesSchema(ctx)
if err != nil {
return err
}
if err = m.validateSchema(schema_ref.TableTypes, schema); err != nil {
return err
}
info, err = client.GetPrimaryKeys(ctx, ref)
if err != nil {
return err
}
if err = m.validate(schema_ref.PrimaryKeys, info, client); err != nil {
return err
}
schema, err = client.GetPrimaryKeysSchema(ctx)
if err != nil {
return err
}
if err = m.validateSchema(schema_ref.PrimaryKeys, schema); err != nil {
return err
}
info, err = client.GetExportedKeys(ctx, ref)
if err != nil {
return err
}
if err = m.validate(schema_ref.ExportedKeys, info, client); err != nil {
return err
}
schema, err = client.GetExportedKeysSchema(ctx)
if err != nil {
return err
}
if err = m.validateSchema(schema_ref.ExportedKeys, schema); err != nil {
return err
}
info, err = client.GetImportedKeys(ctx, ref)
if err != nil {
return err
}
if err = m.validate(schema_ref.ImportedKeys, info, client); err != nil {
return err
}
schema, err = client.GetImportedKeysSchema(ctx)
if err != nil {
return err
}
if err = m.validateSchema(schema_ref.ImportedKeys, schema); err != nil {
return err
}
info, err = client.GetCrossReference(ctx, pkRef, fkRef)
if err != nil {
return err
}
if err = m.validate(schema_ref.CrossReference, info, client); err != nil {
return err
}
schema, err = client.GetCrossReferenceSchema(ctx)
if err != nil {
return err
}
if err = m.validateSchema(schema_ref.CrossReference, schema); err != nil {
return err
}
info, err = client.GetXdbcTypeInfo(ctx, nil)
if err != nil {
return err
}
if err = m.validate(schema_ref.XdbcTypeInfo, info, client); err != nil {
return err
}
schema, err = client.GetXdbcTypeInfoSchema(ctx)
if err != nil {
return err
}
if err = m.validateSchema(schema_ref.XdbcTypeInfo, schema); err != nil {
return err
}
info, err = client.GetSqlInfo(ctx, []flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerName, flightsql.SqlInfoFlightSqlServerReadOnly})
if err != nil {
return err
}
if err = m.validate(schema_ref.SqlInfo, info, client); err != nil {
return err
}
schema, err = client.GetSqlInfoSchema(ctx)
if err != nil {
return err
}
if err = m.validateSchema(schema_ref.SqlInfo, schema); err != nil {
return err
}
return nil
}
// ValidateStatementExecution runs the ad-hoc statement commands in order:
//   - a query, streamed and checked against getQuerySchema;
//   - the query's schema-only variant;
//   - an update, whose row count must equal updateStatementExpectedRows.
func (m *flightSqlScenarioTester) ValidateStatementExecution(client *flightsql.Client) error {
	ctx := context.Background()
	expected := getQuerySchema()

	info, err := client.Execute(ctx, "SELECT STATEMENT")
	if err != nil {
		return err
	}
	if err := m.validate(expected, info, client); err != nil {
		return err
	}

	schemaResult, err := client.GetExecuteSchema(ctx, "SELECT STATEMENT")
	if err != nil {
		return err
	}
	if err := m.validateSchema(expected, schemaResult); err != nil {
		return err
	}

	rows, err := client.ExecuteUpdate(ctx, "UPDATE STATEMENT")
	if err != nil {
		return err
	}
	if rows != updateStatementExpectedRows {
		return fmt.Errorf("expected 'UPDATE STATEMENT' return %d got %d", updateStatementExpectedRows, rows)
	}
	return nil
}
// ValidatePreparedStatementExecution runs the prepared-statement commands.
//
// It binds a one-row parameter batch to a prepared query, then streams the
// result and checks both the stream schema and the schema-only variant
// against getQuerySchema before closing the statement. It then prepares and
// runs an update, whose row count must equal
// updatePreparedStatementExpectedRows, and closes that statement too.
//
// The parameter batch is built before anything is prepared. A failure to
// build it therefore cannot leave a prepared statement open on the server.
func (m *flightSqlScenarioTester) ValidatePreparedStatementExecution(client *flightsql.Client) error {
	ctx := context.Background()

	arr, _, err := array.FromJSON(memory.DefaultAllocator, arrow.PrimitiveTypes.Int64, strings.NewReader("[1]"))
	if err != nil {
		return err
	}
	defer arr.Release()
	params := array.NewRecord(getQuerySchema(), []arrow.Array{arr}, 1)
	defer params.Release()

	prepared, err := client.Prepare(ctx, "SELECT PREPARED STATEMENT")
	if err != nil {
		return err
	}
	prepared.SetParameters(params)
	info, err := prepared.Execute(ctx)
	if err != nil {
		return err
	}
	if err = m.validate(getQuerySchema(), info, client); err != nil {
		return err
	}
	schema, err := prepared.GetSchema(ctx)
	if err != nil {
		return err
	}
	if err = m.validateSchema(getQuerySchema(), schema); err != nil {
		return err
	}
	if err = prepared.Close(ctx); err != nil {
		return err
	}

	updatePrepared, err := client.Prepare(ctx, "UPDATE PREPARED STATEMENT")
	if err != nil {
		return err
	}
	updateResult, err := updatePrepared.ExecuteUpdate(ctx)
	if err != nil {
		return err
	}
	if updateResult != updatePreparedStatementExpectedRows {
		return fmt.Errorf("expected 'UPDATE PREPARED STATEMENT' return %d got %d", updatePreparedStatementExpectedRows, updateResult)
	}
	return updatePrepared.Close(ctx)
}
// doGetForTestCase returns a channel that is already closed, i.e. a result
// stream with no record batches. The schema argument is unused because
// callers return the schema alongside the channel themselves.
func (m *flightSqlScenarioTester) doGetForTestCase(_ *arrow.Schema) chan flight.StreamChunk {
	empty := make(chan flight.StreamChunk)
	defer close(empty)
	return empty
}
// GetFlightInfoStatement answers an ad-hoc query. The query text must equal
// selectStatement. The reply has one endpoint whose statement-query ticket
// wraps a handle. The handle and the advertised schema depend on whether the
// command carries a transaction ID.
func (m *flightSqlScenarioTester) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	if err := assertEq(selectStatement, cmd.GetQuery()); err != nil {
		return nil, err
	}

	var (
		rawHandle []byte
		schema    *arrow.Schema
	)
	switch {
	case len(cmd.GetTransactionId()) > 0:
		rawHandle, schema = []byte("SELECT STATEMENT WITH TXN HANDLE"), getQueryWithTransactionSchema()
	default:
		rawHandle, schema = []byte("SELECT STATEMENT HANDLE"), getQuerySchema()
	}

	ticketBytes, err := flightsql.CreateStatementQueryTicket(rawHandle)
	if err != nil {
		return nil, err
	}

	info := &flight.FlightInfo{
		FlightDescriptor: desc,
		Schema:           flight.SerializeSchema(schema, memory.DefaultAllocator),
		TotalRecords:     -1,
		TotalBytes:       -1,
	}
	info.Endpoint = []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: ticketBytes}}}
	return info, nil
}
// GetFlightInfoSubstraitPlan answers a Substrait plan query. The plan bytes
// and version must match substraitPlanText and substraitPlanVersion. The
// reply has one endpoint whose statement-query ticket wraps "PLAN HANDLE",
// or "PLAN WITH TXN HANDLE" when a transaction ID is present; the schema
// varies the same way.
func (m *flightSqlScenarioTester) GetFlightInfoSubstraitPlan(ctx context.Context, cmd flightsql.StatementSubstraitPlan, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	plan := cmd.GetPlan()
	if err := assertEq([]byte(substraitPlanText), plan.Plan); err != nil {
		return nil, fmt.Errorf("%w: unexpected plan in GetFlightInfoSubstraitPlan", err)
	}
	if err := assertEq(substraitPlanVersion, plan.Version); err != nil {
		return nil, fmt.Errorf("%w: unexpected version in GetFlightInfoSubstraitPlan", err)
	}

	var (
		rawHandle []byte
		schema    *arrow.Schema
	)
	switch {
	case len(cmd.GetTransactionId()) > 0:
		rawHandle, schema = []byte("PLAN WITH TXN HANDLE"), getQueryWithTransactionSchema()
	default:
		rawHandle, schema = []byte("PLAN HANDLE"), getQuerySchema()
	}

	ticketBytes, err := flightsql.CreateStatementQueryTicket(rawHandle)
	if err != nil {
		return nil, err
	}

	info := &flight.FlightInfo{
		FlightDescriptor: desc,
		Schema:           flight.SerializeSchema(schema, memory.DefaultAllocator),
		TotalRecords:     -1,
		TotalBytes:       -1,
	}
	info.Endpoint = []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: ticketBytes}}}
	return info, nil
}
func (m *flightSqlScenarioTester) GetSchemaStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.SchemaResult, error) {
if err := assertEq(selectStatement, cmd.GetQuery()); err != nil {
return nil, fmt.Errorf("%w: unexpected statement in GetSchemaStatement", err)
}
if len(cmd.GetTransactionId()) == 0 {
return &flight.SchemaResult{Schema: flight.SerializeSchema(getQuerySchema(), memory.DefaultAllocator)}, nil
}
return &flight.SchemaResult{Schema: flight.SerializeSchema(getQueryWithTransactionSchema(), memory.DefaultAllocator)}, nil
}
// GetSchemaSubstraitPlan returns the serialized result schema for a
// Substrait plan. The plan bytes and version must match substraitPlanText
// and substraitPlanVersion. The transaction variant is returned when the
// command carries a transaction ID.
//
// Error messages name this method. They previously said
// GetFlightInfoSubstraitPlan, which pointed debugging at the wrong handler.
func (m *flightSqlScenarioTester) GetSchemaSubstraitPlan(ctx context.Context, cmd flightsql.StatementSubstraitPlan, desc *flight.FlightDescriptor) (*flight.SchemaResult, error) {
	if err := assertEq([]byte(substraitPlanText), cmd.GetPlan().Plan); err != nil {
		return nil, fmt.Errorf("%w: unexpected plan in GetSchemaSubstraitPlan", err)
	}
	if err := assertEq(substraitPlanVersion, cmd.GetPlan().Version); err != nil {
		return nil, fmt.Errorf("%w: unexpected version in GetSchemaSubstraitPlan", err)
	}
	if len(cmd.GetTransactionId()) == 0 {
		return &flight.SchemaResult{Schema: flight.SerializeSchema(getQuerySchema(), memory.DefaultAllocator)}, nil
	}
	return &flight.SchemaResult{Schema: flight.SerializeSchema(getQueryWithTransactionSchema(), memory.DefaultAllocator)}, nil
}
func (m *flightSqlScenarioTester) DoGetStatement(ctx context.Context, cmd flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
switch string(cmd.GetStatementHandle()) {
case "SELECT STATEMENT HANDLE", "PLAN HANDLE":
return getQuerySchema(), m.doGetForTestCase(getQuerySchema()), nil
case "SELECT STATEMENT WITH TXN HANDLE", "PLAN WITH TXN HANDLE":
return getQueryWithTransactionSchema(), m.doGetForTestCase(getQueryWithTransactionSchema()), nil
}
return nil, nil, fmt.Errorf("%w: unknown handle %s", arrow.ErrInvalid, string(cmd.GetStatementHandle()))
}
func (m *flightSqlScenarioTester) GetFlightInfoPreparedStatement(_ context.Context, cmd flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
switch string(cmd.GetPreparedStatementHandle()) {
case "SELECT PREPARED STATEMENT HANDLE", "PLAN HANDLE":
return m.flightInfoForCommand(desc, getQuerySchema()), nil
case "SELECT PREPARED STATEMENT WITH TXN HANDLE", "PLAN WITH TXN HANDLE":
return m.flightInfoForCommand(desc, getQueryWithTransactionSchema()), nil
}
return nil, fmt.Errorf("%w: invalid handle for GetFlightInfoPreparedStatement %s",
arrow.ErrInvalid, string(cmd.GetPreparedStatementHandle()))
}
func (m *flightSqlScenarioTester) GetSchemaPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.SchemaResult, error) {
switch string(cmd.GetPreparedStatementHandle()) {
case "SELECT PREPARED STATEMENT HANDLE", "PLAN HANDLE":
return &flight.SchemaResult{Schema: flight.SerializeSchema(getQuerySchema(), memory.DefaultAllocator)}, nil
case "SELECT PREPARED STATEMENT WITH TXN HANDLE", "PLAN WITH TXN HANDLE":
return &flight.SchemaResult{Schema: flight.SerializeSchema(getQueryWithTransactionSchema(), memory.DefaultAllocator)}, nil
}
return nil, fmt.Errorf("%w: invalid handle for GetSchemaPreparedStatement %s",
arrow.ErrInvalid, string(cmd.GetPreparedStatementHandle()))
}
func (m *flightSqlScenarioTester) DoGetPreparedStatement(_ context.Context, cmd flightsql.PreparedStatementQuery) (*arrow.Schema, <-chan flight.StreamChunk, error) {
switch string(cmd.GetPreparedStatementHandle()) {
case "SELECT PREPARED STATEMENT HANDLE", "PLAN HANDLE":
return getQuerySchema(), m.doGetForTestCase(getQuerySchema()), nil
case "SELECT PREPARED STATEMENT WITH TXN HANDLE", "PLAN WITH TXN HANDLE":
return getQueryWithTransactionSchema(), m.doGetForTestCase(getQueryWithTransactionSchema()), nil
}
return nil, nil, fmt.Errorf("%w: invalid handle: %s",
arrow.ErrInvalid, string(cmd.GetPreparedStatementHandle()))
}
// GetFlightInfoCatalogs describes the catalog listing using the reference
// Catalogs schema.
func (m *flightSqlScenarioTester) GetFlightInfoCatalogs(_ context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	info := m.flightInfoForCommand(desc, schema_ref.Catalogs)
	return info, nil
}
// DoGetCatalogs serves an empty stream with the reference Catalogs schema.
func (m *flightSqlScenarioTester) DoGetCatalogs(_ context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) {
	schema := schema_ref.Catalogs
	return schema, m.doGetForTestCase(schema), nil
}
// GetFlightInfoXdbcTypeInfo describes the XDBC type-info listing using the
// reference XdbcTypeInfo schema. The command is not inspected.
func (m *flightSqlScenarioTester) GetFlightInfoXdbcTypeInfo(_ context.Context, cmd flightsql.GetXdbcTypeInfo, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	info := m.flightInfoForCommand(desc, schema_ref.XdbcTypeInfo)
	return info, nil
}
// DoGetXdbcTypeInfo serves an empty stream with the reference XdbcTypeInfo
// schema.
func (m *flightSqlScenarioTester) DoGetXdbcTypeInfo(context.Context, flightsql.GetXdbcTypeInfo) (*arrow.Schema, <-chan flight.StreamChunk, error) {
	schema := schema_ref.XdbcTypeInfo
	return schema, m.doGetForTestCase(schema), nil
}
// GetFlightInfoSqlInfo handles two kinds of SqlInfo request.
//
// A request for exactly two IDs is the protocol-message check: it must ask
// for the server name followed by the read-only flag, and it gets a mock
// FlightInfo. Any other request checks the values themselves and is
// delegated to BaseServer. BaseServer presumably answers from the values
// registered in MakeServer.
func (m *flightSqlScenarioTester) GetFlightInfoSqlInfo(ctx context.Context, cmd flightsql.GetSqlInfo, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	ids := cmd.GetInfo()
	if len(ids) != 2 {
		return m.BaseServer.GetFlightInfoSqlInfo(ctx, cmd, desc)
	}

	if err := assertEq(flightsql.SqlInfoFlightSqlServerName, flightsql.SqlInfo(ids[0])); err != nil {
		return nil, err
	}
	if err := assertEq(flightsql.SqlInfoFlightSqlServerReadOnly, flightsql.SqlInfo(ids[1])); err != nil {
		return nil, err
	}
	return m.flightInfoForCommand(desc, schema_ref.SqlInfo), nil
}
// DoGetSqlInfo mirrors GetFlightInfoSqlInfo. A two-ID request gets an empty
// stream with the reference SqlInfo schema; anything else is delegated to
// BaseServer.
func (m *flightSqlScenarioTester) DoGetSqlInfo(ctx context.Context, cmd flightsql.GetSqlInfo) (*arrow.Schema, <-chan flight.StreamChunk, error) {
	if len(cmd.GetInfo()) != 2 {
		return m.BaseServer.DoGetSqlInfo(ctx, cmd)
	}
	return schema_ref.SqlInfo, m.doGetForTestCase(schema_ref.SqlInfo), nil
}
// GetFlightInfoSchemas checks that the request carries the catalog and
// db-schema filter the client sends. It then describes the listing using
// the reference DBSchemas schema.
func (m *flightSqlScenarioTester) GetFlightInfoSchemas(_ context.Context, cmd flightsql.GetDBSchemas, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	checks := []func() error{
		func() error { return assertEq("catalog", cmd.GetCatalog()) },
		func() error { return assertEq("db_schema_filter_pattern", cmd.GetDBSchemaFilterPattern()) },
	}
	for _, check := range checks {
		if err := check(); err != nil {
			return nil, err
		}
	}
	return m.flightInfoForCommand(desc, schema_ref.DBSchemas), nil
}
// DoGetDBSchemas serves an empty stream with the reference DBSchemas schema.
func (m *flightSqlScenarioTester) DoGetDBSchemas(context.Context, flightsql.GetDBSchemas) (*arrow.Schema, <-chan flight.StreamChunk, error) {
	schema := schema_ref.DBSchemas
	return schema, m.doGetForTestCase(schema), nil
}
// GetFlightInfoTables checks every filter the client sends: catalog,
// db-schema and table-name patterns, the table types ["table", "view"] and
// IncludeSchema=true. It then describes the listing using the reference
// TablesWithIncludedSchema schema. The checks run lazily and in order, so
// the table-types length is verified before either element is indexed.
func (m *flightSqlScenarioTester) GetFlightInfoTables(_ context.Context, cmd flightsql.GetTables, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	checks := []func() error{
		func() error { return assertEq("catalog", cmd.GetCatalog()) },
		func() error { return assertEq("db_schema_filter_pattern", cmd.GetDBSchemaFilterPattern()) },
		func() error { return assertEq("table_filter_pattern", cmd.GetTableNameFilterPattern()) },
		func() error { return assertEq(int(2), len(cmd.GetTableTypes())) },
		func() error { return assertEq("table", cmd.GetTableTypes()[0]) },
		func() error { return assertEq("view", cmd.GetTableTypes()[1]) },
		func() error { return assertEq(true, cmd.GetIncludeSchema()) },
	}
	for _, check := range checks {
		if err := check(); err != nil {
			return nil, err
		}
	}
	return m.flightInfoForCommand(desc, schema_ref.TablesWithIncludedSchema), nil
}
// DoGetTables serves an empty stream with the reference
// TablesWithIncludedSchema schema.
func (m *flightSqlScenarioTester) DoGetTables(context.Context, flightsql.GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) {
	schema := schema_ref.TablesWithIncludedSchema
	return schema, m.doGetForTestCase(schema), nil
}
// GetFlightInfoTableTypes describes the table-type listing using the
// reference TableTypes schema.
func (m *flightSqlScenarioTester) GetFlightInfoTableTypes(_ context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	info := m.flightInfoForCommand(desc, schema_ref.TableTypes)
	return info, nil
}
// DoGetTableTypes serves an empty stream with the reference TableTypes schema.
func (m *flightSqlScenarioTester) DoGetTableTypes(context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) {
	schema := schema_ref.TableTypes
	return schema, m.doGetForTestCase(schema), nil
}
// GetFlightInfoPrimaryKeys checks that the table reference is
// catalog/db_schema/table. It then describes the result using the reference
// PrimaryKeys schema.
func (m *flightSqlScenarioTester) GetFlightInfoPrimaryKeys(_ context.Context, cmd flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	for _, check := range []func() error{
		func() error { return assertEq("catalog", cmd.Catalog) },
		func() error { return assertEq("db_schema", cmd.DBSchema) },
		func() error { return assertEq("table", cmd.Table) },
	} {
		if err := check(); err != nil {
			return nil, err
		}
	}
	return m.flightInfoForCommand(desc, schema_ref.PrimaryKeys), nil
}
// DoGetPrimaryKeys serves an empty stream with the reference PrimaryKeys schema.
func (m *flightSqlScenarioTester) DoGetPrimaryKeys(context.Context, flightsql.TableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) {
	schema := schema_ref.PrimaryKeys
	return schema, m.doGetForTestCase(schema), nil
}
// GetFlightInfoExportedKeys checks that the table reference is
// catalog/db_schema/table. It then describes the result using the reference
// ExportedKeys schema.
func (m *flightSqlScenarioTester) GetFlightInfoExportedKeys(_ context.Context, cmd flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	for _, check := range []func() error{
		func() error { return assertEq("catalog", cmd.Catalog) },
		func() error { return assertEq("db_schema", cmd.DBSchema) },
		func() error { return assertEq("table", cmd.Table) },
	} {
		if err := check(); err != nil {
			return nil, err
		}
	}
	return m.flightInfoForCommand(desc, schema_ref.ExportedKeys), nil
}
// DoGetExportedKeys serves an empty stream with the reference ExportedKeys
// schema.
func (m *flightSqlScenarioTester) DoGetExportedKeys(context.Context, flightsql.TableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) {
	schema := schema_ref.ExportedKeys
	return schema, m.doGetForTestCase(schema), nil
}
// GetFlightInfoImportedKeys checks that the table reference is
// catalog/db_schema/table. It then describes the result using the reference
// ImportedKeys schema.
func (m *flightSqlScenarioTester) GetFlightInfoImportedKeys(_ context.Context, cmd flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	for _, check := range []func() error{
		func() error { return assertEq("catalog", cmd.Catalog) },
		func() error { return assertEq("db_schema", cmd.DBSchema) },
		func() error { return assertEq("table", cmd.Table) },
	} {
		if err := check(); err != nil {
			return nil, err
		}
	}
	return m.flightInfoForCommand(desc, schema_ref.ImportedKeys), nil
}
// DoGetImportedKeys serves an empty stream with the reference ImportedKeys
// schema.
func (m *flightSqlScenarioTester) DoGetImportedKeys(context.Context, flightsql.TableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) {
	schema := schema_ref.ImportedKeys
	return schema, m.doGetForTestCase(schema), nil
}
// GetFlightInfoCrossReference checks that the primary-key side is
// pk_catalog/pk_db_schema/pk_table and the foreign-key side is
// fk_catalog/fk_db_schema/fk_table. It then describes the result.
//
// The advertised schema is schema_ref.CrossReference, matching what
// DoGetCrossReference actually streams. It previously advertised
// schema_ref.TableTypes.
func (m *flightSqlScenarioTester) GetFlightInfoCrossReference(_ context.Context, cmd flightsql.CrossTableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	if err := assertEq("pk_catalog", cmd.PKRef.Catalog); err != nil {
		return nil, err
	}
	if err := assertEq("pk_db_schema", cmd.PKRef.DBSchema); err != nil {
		return nil, err
	}
	if err := assertEq("pk_table", cmd.PKRef.Table); err != nil {
		return nil, err
	}
	if err := assertEq("fk_catalog", cmd.FKRef.Catalog); err != nil {
		return nil, err
	}
	if err := assertEq("fk_db_schema", cmd.FKRef.DBSchema); err != nil {
		return nil, err
	}
	if err := assertEq("fk_table", cmd.FKRef.Table); err != nil {
		return nil, err
	}
	return m.flightInfoForCommand(desc, schema_ref.CrossReference), nil
}
// DoGetCrossReference serves an empty stream with the reference
// CrossReference schema.
func (m *flightSqlScenarioTester) DoGetCrossReference(context.Context, flightsql.CrossTableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) {
	schema := schema_ref.CrossReference
	return schema, m.doGetForTestCase(schema), nil
}
// DoPutCommandStatementUpdate accepts only the "UPDATE STATEMENT" query.
// It reports updateStatementWithTransactionExpectedRows when the command
// carries a transaction ID and updateStatementExpectedRows otherwise.
func (m *flightSqlScenarioTester) DoPutCommandStatementUpdate(_ context.Context, cmd flightsql.StatementUpdate) (int64, error) {
	if err := assertEq("UPDATE STATEMENT", cmd.GetQuery()); err != nil {
		return 0, err
	}
	if len(cmd.GetTransactionId()) > 0 {
		return updateStatementWithTransactionExpectedRows, nil
	}
	return updateStatementExpectedRows, nil
}
// DoPutCommandSubstraitPlan executes a Substrait update plan. The plan
// bytes and version must match substraitPlanText and substraitPlanVersion.
// It reports the same row counts as DoPutCommandStatementUpdate, chosen by
// whether a transaction ID is present.
func (m *flightSqlScenarioTester) DoPutCommandSubstraitPlan(_ context.Context, cmd flightsql.StatementSubstraitPlan) (int64, error) {
	plan := cmd.GetPlan()
	if err := assertEq([]byte(substraitPlanText), plan.Plan); err != nil {
		return 0, fmt.Errorf("%w: wrong plan for DoPutCommandSubstraitPlan", err)
	}
	if err := assertEq(substraitPlanVersion, plan.Version); err != nil {
		return 0, fmt.Errorf("%w: unexpected version in DoPutCommandSubstraitPlan", err)
	}
	if len(cmd.GetTransactionId()) > 0 {
		return updateStatementWithTransactionExpectedRows, nil
	}
	return updateStatementExpectedRows, nil
}
// CreatePreparedStatement accepts only "SELECT PREPARED STATEMENT" and
// "UPDATE PREPARED STATEMENT". It issues the handle "<query> HANDLE", or
// "<query> WITH TXN HANDLE" when the request carries a transaction ID.
func (m *flightSqlScenarioTester) CreatePreparedStatement(_ context.Context, request flightsql.ActionCreatePreparedStatementRequest) (res flightsql.ActionCreatePreparedStatementResult, err error) {
	query := request.GetQuery()
	if query != "SELECT PREPARED STATEMENT" && query != "UPDATE PREPARED STATEMENT" {
		return res, fmt.Errorf("%w: unexpected query %s", arrow.ErrInvalid, query)
	}

	suffix := " HANDLE"
	if len(request.GetTransactionId()) != 0 {
		suffix = " WITH TXN HANDLE"
	}
	res.Handle = []byte(query + suffix)
	return res, nil
}
// CreatePreparedSubstraitPlan prepares a Substrait plan. The plan bytes and
// version must match substraitPlanText and substraitPlanVersion. It issues
// "PLAN HANDLE", or "PLAN WITH TXN HANDLE" when the request carries a
// transaction ID.
//
// The version error now names this method instead of
// DoPutCommandSubstraitPlan. Errors are assigned to the named result rather
// than shadowing it.
func (m *flightSqlScenarioTester) CreatePreparedSubstraitPlan(_ context.Context, request flightsql.ActionCreatePreparedSubstraitPlanRequest) (res flightsql.ActionCreatePreparedStatementResult, err error) {
	if err = assertEq([]byte(substraitPlanText), request.GetPlan().Plan); err != nil {
		return res, fmt.Errorf("%w: wrong plan for CreatePreparedSubstraitPlan", err)
	}
	if err = assertEq(substraitPlanVersion, request.GetPlan().Version); err != nil {
		return res, fmt.Errorf("%w: unexpected version in CreatePreparedSubstraitPlan", err)
	}
	if len(request.GetTransactionId()) == 0 {
		res.Handle = []byte("PLAN HANDLE")
	} else {
		res.Handle = []byte("PLAN WITH TXN HANDLE")
	}
	return res, nil
}
// ClosePreparedStatement accepts any handle issued by
// CreatePreparedStatement or CreatePreparedSubstraitPlan, with or without a
// transaction. Any other handle is rejected with arrow.ErrInvalid.
func (m *flightSqlScenarioTester) ClosePreparedStatement(_ context.Context, request flightsql.ActionClosePreparedStatementRequest) error {
	switch handle := string(request.GetPreparedStatementHandle()); handle {
	case "SELECT PREPARED STATEMENT HANDLE", "SELECT PREPARED STATEMENT WITH TXN HANDLE",
		"UPDATE PREPARED STATEMENT HANDLE", "UPDATE PREPARED STATEMENT WITH TXN HANDLE",
		"PLAN HANDLE", "PLAN WITH TXN HANDLE":
		return nil
	default:
		return fmt.Errorf("%w: invalid handle for ClosePreparedStatement: %s",
			arrow.ErrInvalid, handle)
	}
}
// DoPutPreparedStatementQuery accepts parameter bindings for the query-style
// prepared-statement handles. The bound parameter stream must use
// getQuerySchema. The handle is returned unchanged in every case.
//
// On a schema mismatch the error now names both schemas. Previously it
// compared a boolean, which only reported "expected true, got false".
func (m *flightSqlScenarioTester) DoPutPreparedStatementQuery(_ context.Context, cmd flightsql.PreparedStatementQuery, rdr flight.MessageReader, _ flight.MetadataWriter) ([]byte, error) {
	handle := cmd.GetPreparedStatementHandle()
	switch string(handle) {
	case "SELECT PREPARED STATEMENT HANDLE",
		"SELECT PREPARED STATEMENT WITH TXN HANDLE",
		"PLAN HANDLE", "PLAN WITH TXN HANDLE":
	default:
		return handle, fmt.Errorf("%w: handle for DoPutPreparedStatementQuery '%s'",
			arrow.ErrInvalid, string(handle))
	}

	expected := getQuerySchema()
	if actual := rdr.Schema(); !actual.Equal(expected) {
		return handle, fmt.Errorf("%w: unexpected parameter schema in DoPutPreparedStatementQuery: expected %s, got %s",
			arrow.ErrInvalid, expected, actual)
	}
	return handle, nil
}
// DoPutPreparedStatementUpdate executes an update-style prepared statement.
// It reports the transaction or non-transaction prepared-update row count
// according to the handle, and rejects unknown handles with
// arrow.ErrInvalid. Bound parameters are not read.
func (m *flightSqlScenarioTester) DoPutPreparedStatementUpdate(_ context.Context, cmd flightsql.PreparedStatementUpdate, _ flight.MessageReader) (int64, error) {
	handle := string(cmd.GetPreparedStatementHandle())
	switch handle {
	case "UPDATE PREPARED STATEMENT WITH TXN HANDLE", "PLAN WITH TXN HANDLE":
		return updatePreparedStatementWithTransactionExpectedRows, nil
	case "UPDATE PREPARED STATEMENT HANDLE", "PLAN HANDLE":
		return updatePreparedStatementExpectedRows, nil
	default:
		return 0, fmt.Errorf("%w: handle for DoPutPreparedStatementUpdate '%s'",
			arrow.ErrInvalid, handle)
	}
}
// BeginSavepoint checks the savepoint name first and then the transaction
// ID against the scenario constants. It then returns savepointID as the new
// savepoint's identifier.
func (m *flightSqlScenarioTester) BeginSavepoint(_ context.Context, request flightsql.ActionBeginSavepointRequest) ([]byte, error) {
	if err := assertEq(savepointName, request.GetName()); err != nil {
		return nil, fmt.Errorf("%w: unexpected savepoint name in BeginSavepoint", err)
	}
	if err := assertEq([]byte(transactionID), request.GetTransactionId()); err != nil {
		return nil, fmt.Errorf("%w: unexpected transaction ID in BeginSavepoint", err)
	}
	id := []byte(savepointID)
	return id, nil
}
// BeginTransaction ignores the request and always returns transactionID as
// the identifier of the new transaction.
func (m *flightSqlScenarioTester) BeginTransaction(context.Context, flightsql.ActionBeginTransactionRequest) ([]byte, error) {
	id := []byte(transactionID)
	return id, nil
}
// CancelFlightInfo accepts a cancellation only for a FlightInfo with exactly
// one endpoint whose statement-query ticket carries "PLAN HANDLE". It then
// reports CancelStatusCancelled. Otherwise it returns an error together with
// CancelStatusUnspecified.
//
// The endpoints are read through the nil-safe protobuf getters, so a
// request without a FlightInfo fails the endpoint-count check instead of
// panicking the handler.
func (m *flightSqlScenarioTester) CancelFlightInfo(_ context.Context, request *flight.CancelFlightInfoRequest) (flight.CancelFlightInfoResult, error) {
	result := flight.CancelFlightInfoResult{Status: flight.CancelStatusUnspecified}

	endpoints := request.GetInfo().GetEndpoint()
	if err := assertEq(1, len(endpoints)); err != nil {
		return result, fmt.Errorf("%w: expected 1 endpoint for CancelQuery", err)
	}

	tkt, err := flightsql.GetStatementQueryTicket(endpoints[0].Ticket)
	if err != nil {
		return result, err
	}
	if err := assertEq([]byte("PLAN HANDLE"), tkt.GetStatementHandle()); err != nil {
		return result, fmt.Errorf("%w: unexpected ticket in CancelQuery", err)
	}

	result.Status = flight.CancelStatusCancelled
	return result, nil
}
// EndSavepoint accepts the release and rollback actions, provided the
// savepoint ID matches savepointID. Any other action is rejected with
// arrow.ErrInvalid.
func (m *flightSqlScenarioTester) EndSavepoint(_ context.Context, request flightsql.ActionEndSavepointRequest) error {
	action := request.GetAction()
	if action != flightsql.EndSavepointRelease && action != flightsql.EndSavepointRollback {
		return fmt.Errorf("%w: unknown action %v", arrow.ErrInvalid, action)
	}
	if err := assertEq([]byte(savepointID), request.GetSavepointId()); err != nil {
		return fmt.Errorf("%w: unexpected savepoint ID in EndSavepoint", err)
	}
	return nil
}
// EndTransaction accepts the commit and rollback actions, provided the
// transaction ID matches transactionID. Any other action is rejected with
// arrow.ErrInvalid.
func (m *flightSqlScenarioTester) EndTransaction(_ context.Context, request flightsql.ActionEndTransactionRequest) error {
	action := request.GetAction()
	if action != flightsql.EndTransactionCommit && action != flightsql.EndTransactionRollback {
		return fmt.Errorf("%w: unknown action %v", arrow.ErrInvalid, action)
	}
	if err := assertEq([]byte(transactionID), request.GetTransactionId()); err != nil {
		return fmt.Errorf("%w: unexpected transaction ID in EndTransaction", err)
	}
	return nil
}
// getQuerySchema builds the mock result schema for statements and prepared
// statements executed outside a transaction. It has a single nullable int64
// column "id" decorated with FlightSQL column metadata. A fresh schema is
// built on every call.
func getQuerySchema() *arrow.Schema {
	md := flightsql.NewColumnMetadataBuilder().
		TableName("test").
		IsAutoIncrement(true).
		IsCaseSensitive(false).
		TypeName("type_test").
		SchemaName("schema_test").
		IsSearchable(true).
		CatalogName("catalog_test").
		Precision(100).
		Remarks("test column").
		Build()

	idField := arrow.Field{
		Name:     "id",
		Type:     arrow.PrimitiveTypes.Int64,
		Nullable: true,
		Metadata: *md.Data,
	}
	return arrow.NewSchema([]arrow.Field{idField}, nil)
}
// getQueryWithTransactionSchema builds the schema expected for results produced
// inside a transaction. It has a single nullable int32 column "pkey" annotated
// with Flight SQL column metadata.
func getQueryWithTransactionSchema() *arrow.Schema {
	// The builder calls keep their original order: Remarks comes before
	// Precision here, unlike in getQuerySchema.
	colMeta := flightsql.NewColumnMetadataBuilder().
		TableName("test").
		IsAutoIncrement(true).
		IsCaseSensitive(false).
		TypeName("type_test").
		SchemaName("schema_test").
		IsSearchable(true).
		CatalogName("catalog_test").
		Remarks("test column").
		Precision(100).
		Build()

	pkeyField := arrow.Field{
		Name:     "pkey",
		Type:     arrow.PrimitiveTypes.Int32,
		Nullable: true,
		Metadata: *colMeta.Data,
	}
	return arrow.NewSchema([]arrow.Field{pkeyField}, nil)
}
// Fixed values used by the mock Flight SQL extension scenario. The client sends
// them, and the mock server checks for and returns them.
const (
	// Payload and version of the mocked Substrait plan.
	substraitPlanText    = "plan"
	substraitPlanVersion = "version"
	// Ad-hoc SQL query text executed inside a transaction.
	selectStatement = "SELECT STATEMENT"
	// Savepoint identifier returned by the server and the name the client requests.
	savepointID   = "savepoint_id"
	savepointName = "savepoint_name"
	// Transaction identifier returned by BeginTransaction.
	transactionID = "transaction_id"
)
// substraitPlan is the mock Substrait plan the extension scenario client
// executes, prepares and inspects.
var substraitPlan = flightsql.SubstraitPlan{
	Plan:    []byte(substraitPlanText),
	Version: substraitPlanVersion,
}
// flightSqlExtensionScenarioTester reuses the mock server from
// flightSqlScenarioTester (embedded) and overrides the client side to exercise
// the Substrait, cancellation and transaction extensions.
type flightSqlExtensionScenarioTester struct {
	flightSqlScenarioTester
}
// RunClient connects a Flight SQL client to addr and runs the extension
// validations in order. It stops at and returns the first failure.
func (m *flightSqlExtensionScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error {
	client, err := flightsql.NewClient(addr, nil, nil, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	validations := []func(*flightsql.Client) error{
		m.ValidateMetadataRetrieval,
		m.ValidateStatementExecution,
		m.ValidatePreparedStatementExecution,
		m.ValidateTransactions,
	}
	for _, validate := range validations {
		if err := validate(client); err != nil {
			return err
		}
	}
	return nil
}
// ValidateMetadataRetrieval requests the Substrait, transaction, cancellation
// and timeout SqlInfo values from the server. It checks that the result stream
// has schema_ref.SqlInfo, that no code is repeated, and that every requested
// code is present with its expected value. Codes that were not requested are
// ignored.
func (m *flightSqlExtensionScenarioTester) ValidateMetadataRetrieval(client *flightsql.Client) error {
	// Each requested SqlInfo paired with the value the server must report. A
	// slice (not a map) keeps the request order deterministic.
	expected := []struct {
		info  flightsql.SqlInfo
		value any
	}{
		{flightsql.SqlInfoFlightSqlServerSql, false},
		{flightsql.SqlInfoFlightSqlServerSubstrait, true},
		{flightsql.SqlInfoFlightSqlServerSubstraitMinVersion, "min_version"},
		{flightsql.SqlInfoFlightSqlServerSubstraitMaxVersion, "max_version"},
		{flightsql.SqlInfoFlightSqlServerTransaction, int32(flightsql.SqlTransactionSavepoint)},
		{flightsql.SqlInfoFlightSqlServerCancel, true},
		{flightsql.SqlInfoFlightSqlServerStatementTimeout, int32(42)},
		{flightsql.SqlInfoFlightSqlServerTransactionTimeout, int32(7)},
	}
	sqlInfo := make([]flightsql.SqlInfo, len(expected))
	for i, e := range expected {
		sqlInfo[i] = e.info
	}

	ctx := context.Background()
	info, err := client.GetSqlInfo(ctx, sqlInfo)
	if err != nil {
		return err
	}
	if len(info.Endpoint) == 0 {
		return fmt.Errorf("%w: GetSqlInfo returned no endpoints", arrow.ErrInvalid)
	}
	rdr, err := client.DoGet(ctx, info.Endpoint[0].Ticket)
	if err != nil {
		return err
	}
	defer rdr.Release()

	actualSchema := rdr.Schema()
	if !schema_ref.SqlInfo.Equal(actualSchema) {
		return fmt.Errorf("%w: schemas did not match. expected: %s\n got: %s",
			arrow.ErrInvalid, schema_ref.SqlInfo, actualSchema)
	}

	infoValues := make(flightsql.SqlInfoResultMap)
	for rdr.Next() {
		rec := rdr.Record()
		names, values := rec.Column(0).(*array.Uint32), rec.Column(1).(*array.DenseUnion)
		for i := 0; i < int(rec.NumRows()); i++ {
			code := names.Value(i)
			if _, ok := infoValues[code]; ok {
				return fmt.Errorf("%w: duplicate SqlInfo value %d", arrow.ErrInvalid, code)
			}
			// Type codes 0-3 are read from child fields 0-3. This assumes the
			// union's type codes equal its child indices in schema_ref.SqlInfo.
			offset := int(values.ValueOffset(i))
			switch values.TypeCode(i) {
			case 0: // string
				infoValues[code] = values.Field(0).(*array.String).Value(offset)
			case 1: // bool
				infoValues[code] = values.Field(1).(*array.Boolean).Value(offset)
			case 2: // int64
				infoValues[code] = values.Field(2).(*array.Int64).Value(offset)
			case 3: // int32
				infoValues[code] = values.Field(3).(*array.Int32).Value(offset)
			default:
				return fmt.Errorf("%w: decoding SqlInfoResult of type code %d",
					arrow.ErrNotImplemented, values.TypeCode(i))
			}
		}
	}
	if err := rdr.Err(); err != nil {
		return err
	}

	// Walk the expectations rather than the results, so a requested code that
	// the server silently omits is reported as a failure.
	for _, e := range expected {
		code := uint32(e.info)
		v, ok := infoValues[code]
		if !ok {
			return fmt.Errorf("%w: SqlInfo %d missing from GetSqlInfo result", arrow.ErrInvalid, code)
		}
		if err := assertEq(e.value, v); err != nil {
			return fmt.Errorf("%w: %v did not match", err, code)
		}
	}
	return nil
}
// ValidateStatementExecution exercises ad-hoc Substrait execution. It checks
// the query FlightInfo and schema, cancels a second execution through the
// deprecated CancelQuery API, and checks the update row count.
func (m *flightSqlExtensionScenarioTester) ValidateStatementExecution(client *flightsql.Client) error {
	ctx := context.Background()
	wantSchema := getQuerySchema

	// Query execution and its FlightInfo.
	flightInfo, err := client.ExecuteSubstrait(ctx, substraitPlan)
	if err != nil {
		return err
	}
	if err = m.validate(wantSchema(), flightInfo, client); err != nil {
		return err
	}

	// Schema-only request for the same plan.
	gotSchema, err := client.GetExecuteSubstraitSchema(ctx, substraitPlan)
	if err != nil {
		return err
	}
	if err = m.validateSchema(wantSchema(), gotSchema); err != nil {
		return err
	}

	// Fresh execution, then cancel it.
	flightInfo, err = client.ExecuteSubstrait(ctx, substraitPlan)
	if err != nil {
		return err
	}
	//nolint:staticcheck,SA1019 for backward compatibility
	cancelResult, err := client.CancelQuery(ctx, flightInfo)
	if err != nil {
		return err
	}
	if err = assertEq(flightsql.CancelResultCancelled, cancelResult); err != nil {
		return fmt.Errorf("%w: wrong cancel result", err)
	}

	// Update execution and its affected row count.
	rowCount, err := client.ExecuteSubstraitUpdate(ctx, substraitPlan)
	if err != nil {
		return err
	}
	if err = assertEq(updateStatementExpectedRows, rowCount); err != nil {
		return fmt.Errorf("%w: wrong number of updated rows for ExecuteSubstraitUpdate", err)
	}
	return nil
}
// ValidatePreparedStatementExecution prepares the Substrait plan twice. The
// first statement is bound to a one-row int64 parameter batch; its FlightInfo
// and schema are checked. The second statement is executed as an update and
// its row count is checked. Both statements are closed before returning
// successfully.
func (m *flightSqlExtensionScenarioTester) ValidatePreparedStatementExecution(client *flightsql.Client) error {
	// Check the FromJSON error: on failure arr would be nil and the deferred
	// Release would panic instead of reporting the problem.
	arr, _, err := array.FromJSON(memory.DefaultAllocator, arrow.PrimitiveTypes.Int64, strings.NewReader("[1]"))
	if err != nil {
		return err
	}
	defer arr.Release()
	params := array.NewRecord(getQuerySchema(), []arrow.Array{arr}, 1)
	defer params.Release()

	ctx := context.Background()
	stmt, err := client.PrepareSubstrait(ctx, substraitPlan)
	if err != nil {
		return err
	}
	stmt.SetParameters(params)
	info, err := stmt.Execute(ctx)
	if err != nil {
		return err
	}
	if err := m.validate(getQuerySchema(), info, client); err != nil {
		return err
	}
	schema, err := stmt.GetSchema(ctx)
	if err != nil {
		return err
	}
	if err := m.validateSchema(getQuerySchema(), schema); err != nil {
		return err
	}
	if err := stmt.Close(ctx); err != nil {
		return err
	}

	updateStmt, err := client.PrepareSubstrait(ctx, substraitPlan)
	if err != nil {
		return err
	}
	updatedRows, err := updateStmt.ExecuteUpdate(ctx)
	if err != nil {
		return err
	}
	if err := assertEq(updatePreparedStatementExpectedRows, updatedRows); err != nil {
		return fmt.Errorf("%w: wrong number of updated rows for prepared Substrait update", err)
	}
	return updateStmt.Close(ctx)
}
// ValidateTransactions drives a transaction lifecycle against the mock server:
//   - begin a transaction and a savepoint;
//   - run ad-hoc and prepared queries and updates (SQL and Substrait) inside the
//     transaction;
//   - roll back to the savepoint, then open and release a second savepoint;
//   - commit;
//   - begin a second transaction and roll it back.
//
// Every returned ID, schema and row count is checked against the scenario's
// fixed expectations. The first failure is returned.
func (m *flightSqlExtensionScenarioTester) ValidateTransactions(client *flightsql.Client) error {
	ctx := context.Background()

	txn, err := client.BeginTransaction(ctx)
	if err != nil {
		return err
	}
	if err := assertEq([]byte(transactionID), []byte(txn.ID())); err != nil {
		return err
	}
	sp, err := txn.BeginSavepoint(ctx, savepointName)
	if err != nil {
		return err
	}
	if err := assertEq([]byte(savepointID), []byte(sp)); err != nil {
		return err
	}

	// Ad-hoc queries (SQL and Substrait) inside the transaction.
	info, err := txn.Execute(ctx, selectStatement)
	if err != nil {
		return err
	}
	if err := m.validate(getQueryWithTransactionSchema(), info, client); err != nil {
		return err
	}
	info, err = txn.ExecuteSubstrait(ctx, substraitPlan)
	if err != nil {
		return err
	}
	if err := m.validate(getQueryWithTransactionSchema(), info, client); err != nil {
		return err
	}
	schema, err := txn.GetExecuteSchema(ctx, selectStatement)
	if err != nil {
		return err
	}
	if err := m.validateSchema(getQueryWithTransactionSchema(), schema); err != nil {
		return err
	}
	schema, err = txn.GetExecuteSubstraitSchema(ctx, substraitPlan)
	if err != nil {
		return err
	}
	if err := m.validateSchema(getQueryWithTransactionSchema(), schema); err != nil {
		return err
	}

	// Ad-hoc updates (SQL and Substrait) inside the transaction.
	updated, err := txn.ExecuteUpdate(ctx, "UPDATE STATEMENT")
	if err != nil {
		return err
	}
	if err := assertEq(updateStatementWithTransactionExpectedRows, updated); err != nil {
		return err
	}
	updated, err = txn.ExecuteSubstraitUpdate(ctx, substraitPlan)
	if err != nil {
		return err
	}
	if err := assertEq(updateStatementWithTransactionExpectedRows, updated); err != nil {
		return err
	}

	// Parameter batch bound to the prepared queries below. Check the FromJSON
	// error: on failure arr would be nil and the deferred Release would panic.
	arr, _, err := array.FromJSON(memory.DefaultAllocator, arrow.PrimitiveTypes.Int64, strings.NewReader("[1]"))
	if err != nil {
		return err
	}
	defer arr.Release()
	params := array.NewRecord(getQuerySchema(), []arrow.Array{arr}, 1)
	defer params.Release()

	// Prepared SQL query.
	prepared, err := txn.Prepare(ctx, "SELECT PREPARED STATEMENT")
	if err != nil {
		return err
	}
	prepared.SetParameters(params)
	info, err = prepared.Execute(ctx)
	if err != nil {
		return err
	}
	if err := m.validate(getQueryWithTransactionSchema(), info, client); err != nil {
		return err
	}
	schema, err = prepared.GetSchema(ctx)
	if err != nil {
		return err
	}
	if err := m.validateSchema(getQueryWithTransactionSchema(), schema); err != nil {
		return err
	}
	if err := prepared.Close(ctx); err != nil {
		return err
	}

	// Prepared Substrait query.
	prepared, err = txn.PrepareSubstrait(ctx, substraitPlan)
	if err != nil {
		return err
	}
	prepared.SetParameters(params)
	info, err = prepared.Execute(ctx)
	if err != nil {
		return err
	}
	if err := m.validate(getQueryWithTransactionSchema(), info, client); err != nil {
		return err
	}
	schema, err = prepared.GetSchema(ctx)
	if err != nil {
		return err
	}
	if err := m.validateSchema(getQueryWithTransactionSchema(), schema); err != nil {
		return err
	}
	if err := prepared.Close(ctx); err != nil {
		return err
	}

	// Prepared SQL update.
	prepared, err = txn.Prepare(ctx, "UPDATE PREPARED STATEMENT")
	if err != nil {
		return err
	}
	updated, err = prepared.ExecuteUpdate(ctx)
	if err != nil {
		return err
	}
	if err := assertEq(updatePreparedStatementWithTransactionExpectedRows, updated); err != nil {
		return err
	}
	if err := prepared.Close(ctx); err != nil {
		return err
	}

	// Prepared Substrait update.
	prepared, err = txn.PrepareSubstrait(ctx, substraitPlan)
	if err != nil {
		return err
	}
	updated, err = prepared.ExecuteUpdate(ctx)
	if err != nil {
		return err
	}
	if err := assertEq(updatePreparedStatementWithTransactionExpectedRows, updated); err != nil {
		return err
	}
	if err := prepared.Close(ctx); err != nil {
		return err
	}

	// Savepoint handling: roll back to sp, open a new savepoint, release it.
	if err := txn.RollbackSavepoint(ctx, sp); err != nil {
		return err
	}
	sp2, err := txn.BeginSavepoint(ctx, savepointName)
	if err != nil {
		return err
	}
	if err := assertEq([]byte(savepointID), []byte(sp2)); err != nil {
		return err
	}
	// sp has already been rolled back; the savepoint that is still open, and
	// therefore the one to release, is sp2.
	if err := txn.ReleaseSavepoint(ctx, sp2); err != nil {
		return err
	}
	if err := txn.Commit(ctx); err != nil {
		return err
	}

	// A second transaction that is rolled back instead of committed.
	txn, err = client.BeginTransaction(ctx)
	if err != nil {
		return err
	}
	if err := assertEq([]byte(transactionID), []byte(txn.ID())); err != nil {
		return err
	}
	return txn.Rollback(ctx)
}
// sessionOptionsScenarioTester implements the session-options integration
// scenario. Its server side builds on the embedded flightsql.BaseServer.
type sessionOptionsScenarioTester struct {
	flightsql.BaseServer
}
// MakeServer builds a Flight server that carries the server-session middleware,
// registers this tester as its Flight SQL service, and initializes it on port.
func (tester *sessionOptionsScenarioTester) MakeServer(port int) flight.Server {
	sessionMiddleware := flight.CreateServerMiddleware(session.NewServerSessionMiddleware(nil))
	srv := flight.NewServerWithMiddleware([]flight.ServerMiddleware{sessionMiddleware})
	srv.RegisterFlightService(flightsql.NewFlightServer(tester))
	initServer(port, srv)
	return srv
}
// SetSessionOptions stores each requested option in the caller's session, with
// two exceptions that are reported per key in the result and not stored:
//   - the option name "lol_invalid" is rejected as an invalid name;
//   - the string value "lol_invalid" is rejected as an invalid value.
func (tester *sessionOptionsScenarioTester) SetSessionOptions(ctx context.Context, req *flight.SetSessionOptionsRequest) (*flight.SetSessionOptionsResult, error) {
	// Local names are chosen so they do not shadow the imported session and
	// errors packages.
	sess, err := session.GetSessionFromContext(ctx)
	if err != nil {
		return nil, err
	}
	optErrs := make(map[string]*flight.SetSessionOptionsResultError)
	for key, val := range req.GetSessionOptions() {
		switch {
		case key == "lol_invalid":
			optErrs[key] = &flight.SetSessionOptionsResultError{Value: flight.SetSessionOptionsResultErrorInvalidName}
		case val.GetStringValue() == "lol_invalid":
			optErrs[key] = &flight.SetSessionOptionsResultError{Value: flight.SetSessionOptionsResultErrorInvalidValue}
		default:
			sess.SetSessionOption(key, val)
		}
	}
	return &flight.SetSessionOptionsResult{Errors: optErrs}, nil
}
// GetSessionOptions returns every option currently stored in the caller's
// session.
func (tester *sessionOptionsScenarioTester) GetSessionOptions(ctx context.Context, req *flight.GetSessionOptionsRequest) (*flight.GetSessionOptionsResult, error) {
	// Named sess so the imported session package is not shadowed.
	sess, err := session.GetSessionFromContext(ctx)
	if err != nil {
		return nil, err
	}
	return &flight.GetSessionOptionsResult{SessionOptions: sess.GetSessionOptions()}, nil
}
// CloseSession closes the caller's session and reports CloseSessionResultClosed
// on success. Errors from session lookup or Close are returned unchanged.
func (tester *sessionOptionsScenarioTester) CloseSession(ctx context.Context, req *flight.CloseSessionRequest) (*flight.CloseSessionResult, error) {
	// Named sess so the imported session package is not shadowed.
	sess, err := session.GetSessionFromContext(ctx)
	if err != nil {
		return nil, err
	}
	if err := sess.Close(); err != nil {
		return nil, err
	}
	return &flight.CloseSessionResult{Status: flight.CloseSessionResultClosed}, nil
}
// RunClient connects to addr with cookie middleware, so the session cookie is
// carried between calls. It then runs the session validations in a fixed
// order and stops at the first failure.
func (tester *sessionOptionsScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error {
	cookies := flight.NewClientCookieMiddleware()
	client, err := flight.NewClientWithMiddleware(addr, nil, []flight.ClientMiddleware{cookies}, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	// Each step changes session state that later steps observe, so the order
	// below is significant.
	ctx := context.Background()
	steps := []func(context.Context, flight.Client) error{
		tester.ValidateFirstGetSessionOptions,
		tester.ValidateSecondSetSessionOptions,
		tester.ValidateThirdGetSessionOptions,
		tester.ValidateFourthRemoveOption,
		tester.ValidateFifthGetSessionOptions,
		tester.ValidateSixthCloseSession,
		// ValidateSeventhGetSessionOptions is intentionally not run: the C++ impl
		// currently fails with "Invalid or expired arrow_flight_session_id cookie",
		// likely related to GH-39791.
	}
	for _, step := range steps {
		if err := step(ctx, client); err != nil {
			return err
		}
	}
	return nil
}
// ValidateFirstGetSessionOptions checks that a brand-new session reports no
// options.
func (tester *sessionOptionsScenarioTester) ValidateFirstGetSessionOptions(ctx context.Context, client flight.Client) error {
	res, err := client.GetSessionOptions(ctx, &flight.GetSessionOptionsRequest{})
	if err != nil {
		return err
	}
	if count := len(res.GetSessionOptions()); count != 0 {
		return fmt.Errorf("expected new session to be empty, but found %d options already set", count)
	}
	return nil
}
// ValidateSecondSetSessionOptions sets a mix of valid and deliberately invalid
// options. It checks that exactly the invalid name "lol_invalid" and the
// invalid value under "key_with_invalid_value" are reported back as errors.
func (tester *sessionOptionsScenarioTester) ValidateSecondSetSessionOptions(ctx context.Context, client flight.Client) error {
	opts, err := flight.NewSessionOptionValues(map[string]any{
		"foolong":                int64(123),
		"bardouble":              456.0,
		"lol_invalid":            "this won't get set",
		"key_with_invalid_value": "lol_invalid",
		"big_ol_string_list":     []string{"a", "b", "sea", "dee", " ", " ", "geee", "(づ。◕‿‿◕。)づ"},
	})
	if err != nil {
		return err
	}
	res, err := client.SetSessionOptions(ctx, &flight.SetSessionOptionsRequest{SessionOptions: opts})
	if err != nil {
		return err
	}
	expectedErrs := map[string]*flight.SetSessionOptionsResultError{
		"lol_invalid":            {Value: flight.SetSessionOptionsResultErrorInvalidName},
		"key_with_invalid_value": {Value: flight.SetSessionOptionsResultErrorInvalidValue},
	}
	errs := res.GetErrors()
	if len(errs) != len(expectedErrs) {
		return fmt.Errorf("errors expected: %d, got: %d", len(expectedErrs), len(errs))
	}
	for key, val := range errs {
		// Protobuf messages carry internal state, so reflect.DeepEqual is not a
		// reliable comparison for them; proto.Equal compares field values only.
		want, ok := expectedErrs[key]
		if !ok || !proto.Equal(val, want) {
			return fmt.Errorf("error mismatch for key %s. expected: %s, got: %s", key, want, val)
		}
	}
	return nil
}
// ValidateThirdGetSessionOptions checks that the session now holds exactly the
// three valid options set by ValidateSecondSetSessionOptions, with their
// original values.
func (tester *sessionOptionsScenarioTester) ValidateThirdGetSessionOptions(ctx context.Context, client flight.Client) error {
	res, err := client.GetSessionOptions(ctx, &flight.GetSessionOptionsRequest{})
	if err != nil {
		return err
	}
	expectedOpts, err := flight.NewSessionOptionValues(map[string]any{
		"foolong":            int64(123),
		"bardouble":          456.0,
		"big_ol_string_list": []string{"a", "b", "sea", "dee", " ", " ", "geee", "(づ。◕‿‿◕。)づ"},
	})
	if err != nil {
		return err
	}
	opts := res.GetSessionOptions()
	if len(opts) != len(expectedOpts) {
		return fmt.Errorf("options expected: %d, got: %d", len(expectedOpts), len(opts))
	}
	for key, val := range opts {
		// Compare protobuf messages with proto.Equal, not reflect.DeepEqual.
		want, ok := expectedOpts[key]
		if !ok || !proto.Equal(val, want) {
			return fmt.Errorf("session options mismatch for key %s. expected: %s, got: %s", key, want, val)
		}
	}
	return nil
}
// ValidateFourthRemoveOption sends "foolong" with a nil value and checks that
// the server reports no errors. ValidateFifthGetSessionOptions then expects
// "foolong" to be gone from the session.
func (tester *sessionOptionsScenarioTester) ValidateFourthRemoveOption(ctx context.Context, client flight.Client) error {
	removal, err := flight.NewSessionOptionValues(map[string]any{"foolong": nil})
	if err != nil {
		return err
	}
	req := &flight.SetSessionOptionsRequest{SessionOptions: removal}
	res, err := client.SetSessionOptions(ctx, req)
	if err != nil {
		return err
	}
	if got := len(res.GetErrors()); got != 0 {
		return fmt.Errorf("errors expected: %d, got: %d", 0, got)
	}
	return nil
}
// ValidateFifthGetSessionOptions checks that, after "foolong" was removed, the
// session holds exactly the two remaining options with their original values.
func (tester *sessionOptionsScenarioTester) ValidateFifthGetSessionOptions(ctx context.Context, client flight.Client) error {
	res, err := client.GetSessionOptions(ctx, &flight.GetSessionOptionsRequest{})
	if err != nil {
		return err
	}
	expectedOpts, err := flight.NewSessionOptionValues(map[string]any{
		"bardouble":          456.0,
		"big_ol_string_list": []string{"a", "b", "sea", "dee", " ", " ", "geee", "(づ。◕‿‿◕。)づ"},
	})
	if err != nil {
		return err
	}
	opts := res.GetSessionOptions()
	if len(opts) != len(expectedOpts) {
		return fmt.Errorf("options expected: %d, got: %d", len(expectedOpts), len(opts))
	}
	for key, val := range opts {
		// Compare protobuf messages with proto.Equal, not reflect.DeepEqual.
		want, ok := expectedOpts[key]
		if !ok || !proto.Equal(val, want) {
			return fmt.Errorf("session options mismatch for key %s. expected: %s, got: %s", key, want, val)
		}
	}
	return nil
}
// ValidateSixthCloseSession closes the session and requires the server to
// report CloseSessionResultClosed.
func (tester *sessionOptionsScenarioTester) ValidateSixthCloseSession(ctx context.Context, client flight.Client) error {
	res, err := client.CloseSession(ctx, &flight.CloseSessionRequest{})
	if err != nil {
		return err
	}
	if status := res.GetStatus(); status != flight.CloseSessionResultClosed {
		return fmt.Errorf("expected session to successfully close, but found status: %s", status)
	}
	return nil
}
// ValidateSeventhGetSessionOptions checks that the session in use after the
// close reports no options.
func (tester *sessionOptionsScenarioTester) ValidateSeventhGetSessionOptions(ctx context.Context, client flight.Client) error {
	res, err := client.GetSessionOptions(ctx, &flight.GetSessionOptionsRequest{})
	if err != nil {
		return err
	}
	if count := len(res.GetSessionOptions()); count != 0 {
		return fmt.Errorf("expected new session to be empty, but found %d options already set", count)
	}
	return nil
}
// flightSqlIngestionScenarioTester implements the Flight SQL bulk-ingestion
// scenario. Its server side builds on the embedded flightsql.BaseServer.
type flightSqlIngestionScenarioTester struct {
	flightsql.BaseServer
}
// MakeServer builds a Flight server with no middleware. It registers SqlInfo
// entries advertising bulk ingestion and transactional ingestion support,
// serves this tester as the Flight SQL service, and initializes it on port.
func (m *flightSqlIngestionScenarioTester) MakeServer(port int) flight.Server {
	srv := flight.NewServerWithMiddleware(nil)

	// Any value returned by RegisterSqlInfo is not checked here.
	m.RegisterSqlInfo(flightsql.SqlInfoFlightSqlServerBulkIngestion, true)
	m.RegisterSqlInfo(flightsql.SqlInfoFlightSqlServerIngestTransactionsSupported, true)

	service := flightsql.NewFlightServer(m)
	srv.RegisterFlightService(service)
	initServer(port, srv)
	return srv
}
// RunClient connects a Flight SQL client to addr and runs the ingestion
// validation, closing the client when done.
func (m *flightSqlIngestionScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error {
	client, err := flightsql.NewClient(addr, nil, nil, opts...)
	if err != nil {
		return err
	}
	defer client.Close()

	err = m.ValidateIngestion(client)
	return err
}
// ValidateIngestion sends the fixed ingest records with the fixed ingest options
// and checks that the server reports ingestStatementExpectedRows rows.
func (m *flightSqlIngestionScenarioTester) ValidateIngestion(client *flightsql.Client) error {
	ctx := context.Background()
	opts := getIngestOptions()
	// getIngestRecords hands ownership of a fresh reader to the caller. Release
	// it once ExecuteIngest has consumed it, so its records are not leaked.
	rdr := getIngestRecords()
	defer rdr.Release()

	ingestResult, err := client.ExecuteIngest(ctx, rdr, opts)
	if err != nil {
		return err
	}
	if ingestResult != ingestStatementExpectedRows {
		return fmt.Errorf("expected ingest return %d got %d", ingestStatementExpectedRows, ingestResult)
	}
	return nil
}
// DoPutCommandStatementIngest is the mock server side of the ingestion
// scenario. It checks that every field of the incoming ingest command matches
// getIngestOptions and that every streamed record has getIngestSchema. It then
// returns the total number of rows received. A stream error is returned as an
// error rather than as a partial row count.
func (m *flightSqlIngestionScenarioTester) DoPutCommandStatementIngest(ctx context.Context, cmd flightsql.StatementIngest, rdr flight.MessageReader) (int64, error) {
	expectedSchema := getIngestSchema()
	expectedOpts := getIngestOptions()

	// A command without table definition options must fail the assertion
	// cleanly instead of panicking on a nil dereference below.
	tableDefOpts := cmd.GetTableDefinitionOptions()
	if tableDefOpts == nil {
		return 0, fmt.Errorf("%w: missing table definition options in ingest command", arrow.ErrInvalid)
	}
	if err := assertEq(expectedOpts.TableDefinitionOptions.IfExists, tableDefOpts.IfExists); err != nil {
		return 0, err
	}
	if err := assertEq(expectedOpts.TableDefinitionOptions.IfNotExist, tableDefOpts.IfNotExist); err != nil {
		return 0, err
	}
	if err := assertEq(expectedOpts.Table, cmd.GetTable()); err != nil {
		return 0, err
	}
	if err := assertEq(*expectedOpts.Schema, cmd.GetSchema()); err != nil {
		return 0, err
	}
	if err := assertEq(*expectedOpts.Catalog, cmd.GetCatalog()); err != nil {
		return 0, err
	}
	if err := assertEq(expectedOpts.Temporary, cmd.GetTemporary()); err != nil {
		return 0, err
	}
	if err := assertEq(expectedOpts.TransactionId, cmd.GetTransactionId()); err != nil {
		return 0, err
	}
	if err := assertEq(expectedOpts.Options, cmd.GetOptions()); err != nil {
		return 0, err
	}

	var nRows int64
	for rdr.Next() {
		rec := rdr.Record()
		nRows += rec.NumRows()
		if err := assertEq(true, expectedSchema.Equal(rec.Schema())); err != nil {
			return 0, fmt.Errorf("%w: unexpected schema for ingested record: %s", err, rec.Schema())
		}
	}
	// Next() returning false may mean the stream failed. Do not report the rows
	// read so far as a successful ingest.
	if err := rdr.Err(); err != nil {
		return 0, err
	}
	return nRows, nil
}
// getIngestOptions returns a freshly allocated set of ingest options. The
// client sends them with the ingest call, and the mock server asserts the
// command it receives matches them field by field.
func getIngestOptions() *flightsql.ExecuteIngestOpts {
	table, schema, catalog := "test_table", "test_schema", "test_catalog"

	return &flightsql.ExecuteIngestOpts{
		TableDefinitionOptions: &flightsql.TableDefinitionOptions{
			IfNotExist: flightsql.TableDefinitionOptionsTableNotExistOptionCreate,
			IfExists:   flightsql.TableDefinitionOptionsTableExistsOptionReplace,
		},
		Table:         table,
		Schema:        &schema,
		Catalog:       &catalog,
		Temporary:     true,
		TransactionId: []byte("123"),
		Options: map[string]string{
			"key1": "val1",
			"key2": "val2",
		},
	}
}
// getIngestSchema is the schema of the ingest records: a single nullable int64
// column "test_field". The ingest handler checks every received record
// against it.
func getIngestSchema() *arrow.Schema {
	field := arrow.Field{
		Name:     "test_field",
		Type:     arrow.PrimitiveTypes.Int64,
		Nullable: true,
	}
	return arrow.NewSchema([]arrow.Field{field}, nil)
}
// getIngestRecords builds a reader over a single record with getIngestSchema
// and ingestStatementExpectedRows all-null int64 rows. The caller owns the
// returned reader and must Release it.
func getIngestRecords() array.RecordReader {
	schema := getIngestSchema()
	arr := array.MakeArrayOfNull(memory.DefaultAllocator, arrow.PrimitiveTypes.Int64, int(ingestStatementExpectedRows))
	defer arr.Release()
	rec := array.NewRecord(schema, []arrow.Array{arr}, ingestStatementExpectedRows)
	defer rec.Release()

	rdr, err := array.NewRecordReader(schema, []arrow.Record{rec})
	if err != nil {
		// The record is built from the same schema, so this can only happen if
		// the helpers above are changed inconsistently. Fail loudly here rather
		// than return a nil reader that panics at some unrelated call site.
		panic(fmt.Sprintf("flight_integration: building ingest record reader: %v", err))
	}
	return rdr
}