arrow/flight/flightsql/server.go (1,019 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package flightsql import ( "context" "fmt" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/flight" "github.com/apache/arrow-go/v18/arrow/flight/flightsql/schema_ref" pb "github.com/apache/arrow-go/v18/arrow/flight/gen/flight" "github.com/apache/arrow-go/v18/arrow/internal/debug" "github.com/apache/arrow-go/v18/arrow/ipc" "github.com/apache/arrow-go/v18/arrow/memory" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" ) // the following interfaces wrap the Protobuf commands to avoid // exposing the Protobuf types themselves in the API. // StatementQuery represents a Sql Query type StatementQuery interface { GetQuery() string GetTransactionId() []byte } type statementSubstraitPlan struct { *pb.CommandStatementSubstraitPlan } func (s *statementSubstraitPlan) GetPlan() SubstraitPlan { var ( plan []byte version string ) if s.Plan != nil { plan = s.Plan.Plan version = s.Plan.Version } return SubstraitPlan{ Plan: plan, Version: version, } } type StatementSubstraitPlan interface { GetTransactionId() []byte GetPlan() SubstraitPlan } // StatementUpdate represents a SQL update query type StatementUpdate interface { GetQuery() string GetTransactionId() []byte } // StatementQueryTicket represents a request to execute a query type StatementQueryTicket interface { // GetStatementHandle returns the server-generated opaque // identifier for the query GetStatementHandle() []byte } func GetStatementQueryTicket(ticket *flight.Ticket) (result StatementQueryTicket, err error) { var anycmd anypb.Any if err = proto.Unmarshal(ticket.Ticket, &anycmd); err != nil { return } var out pb.TicketStatementQuery if err = anycmd.UnmarshalTo(&out); err != nil { return } result = &out return } // PreparedStatementQuery represents a prepared query statement type PreparedStatementQuery interface { // GetPreparedStatementHandle returns the server-generated opaque // identifier for the statement GetPreparedStatementHandle() []byte } // PreparedStatementUpdate represents a prepared update statement type PreparedStatementUpdate interface { // GetPreparedStatementHandle returns the server-generated opaque // identifier for the statement GetPreparedStatementHandle() []byte } // ActionClosePreparedStatementRequest represents a request to close // a prepared statement type ActionClosePreparedStatementRequest interface { // GetPreparedStatementHandle returns the server-generated opaque // identifier for the statement GetPreparedStatementHandle() []byte } // ActionCreatePreparedStatementRequest represents a request to construct // a new prepared statement type ActionCreatePreparedStatementRequest interface { GetQuery() string GetTransactionId() []byte } type ActionCreatePreparedSubstraitPlanRequest interface { GetPlan() SubstraitPlan GetTransactionId() []byte } type createPreparedSubstraitPlanReq struct { *pb.ActionCreatePreparedSubstraitPlanRequest } func (c *createPreparedSubstraitPlanReq) GetPlan() SubstraitPlan { var ( plan []byte version string ) if c.Plan != nil { plan = c.Plan.Plan version = c.Plan.Version } return SubstraitPlan{ Plan: plan, Version: version, } } // ActionCreatePreparedStatementResult is the result of creating a new // prepared statement, optionally including the dataset and parameter // schemas. type ActionCreatePreparedStatementResult struct { Handle []byte DatasetSchema *arrow.Schema ParameterSchema *arrow.Schema } type ActionBeginTransactionRequest interface{} type ActionBeginSavepointRequest interface { GetTransactionId() []byte GetName() string } type ActionBeginSavepointResult interface { GetSavepointId() []byte } type ActionBeginTransactionResult interface { GetTransactionId() []byte } type ActionCancelQueryRequest interface { GetInfo() *flight.FlightInfo } type cancelQueryRequest struct { info *flight.FlightInfo } func (c *cancelQueryRequest) GetInfo() *flight.FlightInfo { return c.info } type cancelQueryServer interface { CancelQuery(context.Context, ActionCancelQueryRequest) (CancelResult, error) } type ActionEndTransactionRequest interface { GetTransactionId() []byte GetAction() EndTransactionRequestType } type ActionEndSavepointRequest interface { GetSavepointId() []byte GetAction() EndSavepointRequestType } // StatementIngest represents a bulk ingestion request type StatementIngest interface { GetTableDefinitionOptions() *TableDefinitionOptions GetTable() string GetSchema() string GetCatalog() string GetTemporary() bool GetTransactionId() []byte GetOptions() map[string]string } type getXdbcTypeInfo struct { *pb.CommandGetXdbcTypeInfo } func (c *getXdbcTypeInfo) GetDataType() *int32 { return c.DataType } // GetXdbcTypeInfo represents a request for SQL Data Type information type GetXdbcTypeInfo interface { // GetDataType returns either nil (get for all types) // or a specific SQL type ID to fetch information about. GetDataType() *int32 } // GetSqlInfo represents a request for SQL Information type GetSqlInfo interface { // GetInfo returns a slice of SqlInfo ids to return information about GetInfo() []uint32 } type getDBSchemas struct { *pb.CommandGetDbSchemas } func (c *getDBSchemas) GetCatalog() *string { return c.Catalog } func (c *getDBSchemas) GetDBSchemaFilterPattern() *string { return c.DbSchemaFilterPattern } // GetDBSchemas represents a request for list of database schemas type GetDBSchemas interface { GetCatalog() *string GetDBSchemaFilterPattern() *string } type getTables struct { *pb.CommandGetTables } func (c *getTables) GetCatalog() *string { return c.Catalog } func (c *getTables) GetDBSchemaFilterPattern() *string { return c.DbSchemaFilterPattern } func (c *getTables) GetTableNameFilterPattern() *string { return c.TableNameFilterPattern } // GetTables represents a request to list the database's tables type GetTables interface { GetCatalog() *string GetDBSchemaFilterPattern() *string GetTableNameFilterPattern() *string GetTableTypes() []string GetIncludeSchema() bool } func packActionResult(msg proto.Message) (*pb.Result, error) { var ( anycmd anypb.Any err error ) if err = anycmd.MarshalFrom(msg); err != nil { return nil, fmt.Errorf("%w: unable to marshal final response", err) } ret := &pb.Result{} if ret.Body, err = proto.Marshal(&anycmd); err != nil { return nil, fmt.Errorf("%w: unable to marshal final response", err) } return ret, nil } // BaseServer must be embedded into any FlightSQL Server implementation // and provides default implementations of all methods returning an // unimplemented error if called. This allows consumers to gradually // implement methods as they want instead of requiring all consumers to // boilerplate the same "unimplemented" methods. // // The base implementation also contains handling for registering sql info // and serving it up in response to GetSqlInfo requests. type BaseServer struct { sqlInfoToResult SqlInfoResultMap // Alloc allows specifying a particular allocator to use for any // allocations done by the base implementation. // Will use memory.DefaultAllocator if nil Alloc memory.Allocator } func (BaseServer) mustEmbedBaseServer() {} // RegisterSqlInfo registers a specific result to return for a given sqlinfo // id. The result must be one of the following types: string, bool, int64, // int32, []string, or map[int32][]int32. // // Once registered, this value will be returned for any SqlInfo requests. func (b *BaseServer) RegisterSqlInfo(id SqlInfo, result interface{}) error { if b.sqlInfoToResult == nil { b.sqlInfoToResult = make(SqlInfoResultMap) } switch result.(type) { case string, bool, int64, int32, []string, map[int32][]int32: b.sqlInfoToResult[uint32(id)] = result default: return fmt.Errorf("invalid sql info type '%T' registered for id: %d", result, id) } return nil } func (BaseServer) GetFlightInfoStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.FlightInfo, error) { return nil, status.Errorf(codes.Unimplemented, "GetFlightInfoStatement not implemented") } func (BaseServer) GetFlightInfoSubstraitPlan(context.Context, StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.FlightInfo, error) { return nil, status.Errorf(codes.Unimplemented, "GetFlightInfoSubstraitPlan not implemented") } func (BaseServer) GetSchemaStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.SchemaResult, error) { return nil, status.Errorf(codes.Unimplemented, "GetSchemaStatement not implemented") } func (BaseServer) GetSchemaSubstraitPlan(context.Context, StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.SchemaResult, error) { return nil, status.Errorf(codes.Unimplemented, "GetSchemaSubstraitPlan not implemented") } func (BaseServer) DoGetStatement(context.Context, StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) { return nil, nil, status.Errorf(codes.Unimplemented, "DoGetStatement not implemented") } func (BaseServer) GetFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.FlightInfo, error) { return nil, status.Errorf(codes.Unimplemented, "GetFlightInfoPreparedStatement not implemented") } func (BaseServer) GetSchemaPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.SchemaResult, error) { return nil, status.Errorf(codes.Unimplemented, "GetSchemaPreparedStatement not implemented") } func (BaseServer) DoGetPreparedStatement(context.Context, PreparedStatementQuery) (*arrow.Schema, <-chan flight.StreamChunk, error) { return nil, nil, status.Errorf(codes.Unimplemented, "DoGetPreparedStatement not implemented") } func (BaseServer) GetFlightInfoCatalogs(context.Context, *flight.FlightDescriptor) (*flight.FlightInfo, error) { return nil, status.Errorf(codes.Unimplemented, "GetFlightInfoCatalogs not implemented") } func (BaseServer) DoGetCatalogs(context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) { return nil, nil, status.Errorf(codes.Unimplemented, "DoGetCatalogs not implemented") } func (BaseServer) GetFlightInfoXdbcTypeInfo(context.Context, GetXdbcTypeInfo, *flight.FlightDescriptor) (*flight.FlightInfo, error) { return nil, status.Errorf(codes.Unimplemented, "GetFlightInfoXdbcTypeInfo not implemented") } func (BaseServer) DoGetXdbcTypeInfo(context.Context, GetXdbcTypeInfo) (*arrow.Schema, <-chan flight.StreamChunk, error) { return nil, nil, status.Errorf(codes.Unimplemented, "DoGetXdbcTypeInfo not implemented") } // GetFlightInfoSqlInfo is a base implementation of GetSqlInfo by using any // registered sqlinfo (by calling RegisterSqlInfo). Will return an error // if there is no sql info registered, otherwise a FlightInfo for retrieving // the Sql info. func (b *BaseServer) GetFlightInfoSqlInfo(_ context.Context, _ GetSqlInfo, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { if len(b.sqlInfoToResult) == 0 { return nil, status.Error(codes.NotFound, "no sql information available") } if b.Alloc == nil { b.Alloc = memory.DefaultAllocator } return &flight.FlightInfo{ Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: desc.Cmd}}}, FlightDescriptor: desc, TotalRecords: -1, TotalBytes: -1, Schema: flight.SerializeSchema(schema_ref.SqlInfo, b.Alloc), }, nil } // DoGetSqlInfo returns a flight stream containing the list of sqlinfo results func (b *BaseServer) DoGetSqlInfo(_ context.Context, cmd GetSqlInfo) (*arrow.Schema, <-chan flight.StreamChunk, error) { if b.Alloc == nil { b.Alloc = memory.DefaultAllocator } bldr := array.NewRecordBuilder(b.Alloc, schema_ref.SqlInfo) defer bldr.Release() nameFieldBldr := bldr.Field(0).(*array.Uint32Builder) valFieldBldr := bldr.Field(1).(*array.DenseUnionBuilder) // doesn't take ownership, no calls to retain. so we don't need // extra releases. sqlInfoResultBldr := newSqlInfoResultBuilder(valFieldBldr) keys := cmd.GetInfo() // populate both the nameFieldBldr and the values for each // element on command.info. // valueFieldBldr is populated depending on the data type // since it's a dense union. The population for each // data type is handled by the sqlInfoResultBuilder. if len(keys) > 0 { for _, info := range keys { val, ok := b.sqlInfoToResult[info] if !ok { return nil, nil, status.Errorf(codes.NotFound, "no information for sql info number %d", info) } nameFieldBldr.Append(info) sqlInfoResultBldr.Append(val) } } else { for k, v := range b.sqlInfoToResult { nameFieldBldr.Append(k) sqlInfoResultBldr.Append(v) } } batch := bldr.NewRecord() defer batch.Release() debug.Assert(int(batch.NumRows()) == len(cmd.GetInfo()), "too many rows added to SqlInfo result") ch := make(chan flight.StreamChunk) rdr, err := array.NewRecordReader(schema_ref.SqlInfo, []arrow.Record{batch}) if err != nil { return nil, nil, status.Errorf(codes.Internal, "error producing record response: %s", err.Error()) } // StreamChunksFromReader will call release on the reader when done go flight.StreamChunksFromReader(rdr, ch) return schema_ref.SqlInfo, ch, nil } func (BaseServer) GetFlightInfoSchemas(context.Context, GetDBSchemas, *flight.FlightDescriptor) (*flight.FlightInfo, error) { return nil, status.Errorf(codes.Unimplemented, "GetFlightInfoSchemas not implemented") } func (BaseServer) DoGetDBSchemas(context.Context, GetDBSchemas) (*arrow.Schema, <-chan flight.StreamChunk, error) { return nil, nil, status.Errorf(codes.Unimplemented, "DoGetDBSchemas not implemented") } func (BaseServer) GetFlightInfoTables(context.Context, GetTables, *flight.FlightDescriptor) (*flight.FlightInfo, error) { return nil, status.Errorf(codes.Unimplemented, "GetFlightInfoTables not implemented") } func (BaseServer) DoGetTables(context.Context, GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) { return nil, nil, status.Errorf(codes.Unimplemented, "DoGetTables not implemented") } func (BaseServer) GetFlightInfoTableTypes(context.Context, *flight.FlightDescriptor) (*flight.FlightInfo, error) { return nil, status.Errorf(codes.Unimplemented, "GetFlightInfoTableTypes not implemented") } func (BaseServer) DoGetTableTypes(context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) { return nil, nil, status.Errorf(codes.Unimplemented, "DoGetTableTypes not implemented") } func (BaseServer) GetFlightInfoPrimaryKeys(context.Context, TableRef, *flight.FlightDescriptor) (*flight.FlightInfo, error) { return nil, status.Error(codes.Unimplemented, "GetFlightInfoPrimaryKeys not implemented") } func (BaseServer) DoGetPrimaryKeys(context.Context, TableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) { return nil, nil, status.Errorf(codes.Unimplemented, "DoGetPrimaryKeys not implemented") } func (BaseServer) GetFlightInfoExportedKeys(context.Context, TableRef, *flight.FlightDescriptor) (*flight.FlightInfo, error) { return nil, status.Error(codes.Unimplemented, "GetFlightInfoExportedKeys not implemented") } func (BaseServer) DoGetExportedKeys(context.Context, TableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) { return nil, nil, status.Errorf(codes.Unimplemented, "DoGetExportedKeys not implemented") } func (BaseServer) GetFlightInfoImportedKeys(context.Context, TableRef, *flight.FlightDescriptor) (*flight.FlightInfo, error) { return nil, status.Error(codes.Unimplemented, "GetFlightInfoImportedKeys not implemented") } func (BaseServer) DoGetImportedKeys(context.Context, TableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) { return nil, nil, status.Errorf(codes.Unimplemented, "DoGetImportedKeys not implemented") } func (BaseServer) GetFlightInfoCrossReference(context.Context, CrossTableRef, *flight.FlightDescriptor) (*flight.FlightInfo, error) { return nil, status.Error(codes.Unimplemented, "GetFlightInfoCrossReference not implemented") } func (BaseServer) DoGetCrossReference(context.Context, CrossTableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) { return nil, nil, status.Errorf(codes.Unimplemented, "DoGetCrossReference not implemented") } func (BaseServer) CreatePreparedStatement(context.Context, ActionCreatePreparedStatementRequest) (res ActionCreatePreparedStatementResult, err error) { return res, status.Error(codes.Unimplemented, "CreatePreparedStatement not implemented") } func (BaseServer) CreatePreparedSubstraitPlan(context.Context, ActionCreatePreparedSubstraitPlanRequest) (res ActionCreatePreparedStatementResult, err error) { return res, status.Error(codes.Unimplemented, "CreatePreparedSubstraitPlan not implemented") } func (BaseServer) ClosePreparedStatement(context.Context, ActionClosePreparedStatementRequest) error { return status.Error(codes.Unimplemented, "ClosePreparedStatement not implemented") } func (BaseServer) DoPutCommandStatementUpdate(context.Context, StatementUpdate) (int64, error) { return 0, status.Error(codes.Unimplemented, "DoPutCommandStatementUpdate not implemented") } func (BaseServer) DoPutCommandSubstraitPlan(context.Context, StatementSubstraitPlan) (int64, error) { return 0, status.Error(codes.Unimplemented, "DoPutCommandSubstraitPlan not implemented") } func (BaseServer) DoPutPreparedStatementQuery(context.Context, PreparedStatementQuery, flight.MessageReader, flight.MetadataWriter) ([]byte, error) { return nil, status.Error(codes.Unimplemented, "DoPutPreparedStatementQuery not implemented") } func (BaseServer) DoPutPreparedStatementUpdate(context.Context, PreparedStatementUpdate, flight.MessageReader) (int64, error) { return 0, status.Error(codes.Unimplemented, "DoPutPreparedStatementUpdate not implemented") } func (BaseServer) DoPutCommandStatementIngest(context.Context, StatementIngest, flight.MessageReader) (int64, error) { return 0, status.Error(codes.Unimplemented, "DoPutCommandStatementIngest not implemented") } func (BaseServer) BeginTransaction(context.Context, ActionBeginTransactionRequest) ([]byte, error) { return nil, status.Error(codes.Unimplemented, "BeginTransaction not implemented") } func (BaseServer) BeginSavepoint(context.Context, ActionBeginSavepointRequest) ([]byte, error) { return nil, status.Error(codes.Unimplemented, "BeginSavepoint not implemented") } func (BaseServer) CancelFlightInfo(context.Context, *flight.CancelFlightInfoRequest) (flight.CancelFlightInfoResult, error) { return flight.CancelFlightInfoResult{Status: flight.CancelStatusUnspecified}, status.Error(codes.Unimplemented, "CancelFlightInfo not implemented") } func (BaseServer) RenewFlightEndpoint(context.Context, *flight.RenewFlightEndpointRequest) (*flight.FlightEndpoint, error) { return nil, status.Error(codes.Unimplemented, "RenewFlightEndpoint not implemented") } func (BaseServer) PollFlightInfo(context.Context, *flight.FlightDescriptor) (*flight.PollInfo, error) { return nil, status.Error(codes.Unimplemented, "PollFlightInfo not implemented") } func (BaseServer) PollFlightInfoStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error) { return nil, status.Error(codes.Unimplemented, "PollFlightInfoStatement not implemented") } func (BaseServer) PollFlightInfoSubstraitPlan(context.Context, StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.PollInfo, error) { return nil, status.Error(codes.Unimplemented, "PollFlightInfoSubstraitPlan not implemented") } func (BaseServer) PollFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error) { return nil, status.Error(codes.Unimplemented, "PollFlightInfoPreparedStatement not implemented") } func (BaseServer) EndTransaction(context.Context, ActionEndTransactionRequest) error { return status.Error(codes.Unimplemented, "EndTransaction not implemented") } func (BaseServer) EndSavepoint(context.Context, ActionEndSavepointRequest) error { return status.Error(codes.Unimplemented, "EndSavepoint not implemented") } func (BaseServer) SetSessionOptions(context.Context, *flight.SetSessionOptionsRequest) (*flight.SetSessionOptionsResult, error) { return nil, status.Error(codes.Unimplemented, "SetSessionOptions not implemented") } func (BaseServer) GetSessionOptions(context.Context, *flight.GetSessionOptionsRequest) (*flight.GetSessionOptionsResult, error) { return nil, status.Error(codes.Unimplemented, "GetSessionOptions not implemented") } func (BaseServer) CloseSession(context.Context, *flight.CloseSessionRequest) (*flight.CloseSessionResult, error) { return nil, status.Error(codes.Unimplemented, "CloseSession not implemented") } // Server is the required interface for a FlightSQL server. It is implemented by // BaseServer which must be embedded in any implementation. The default // implementation by BaseServer for each of these (except GetSqlInfo) // // GetFlightInfo* methods should return the FlightInfo object representing where // to retrieve the results for a given request. // // DoGet* methods should return the Schema of the resulting stream along with // a channel to retrieve stream chunks (each chunk is a record batch and optionally // a descriptor and app metadata). The channel will be read from until it // closes, sending each chunk on the stream. Since the channel is returned // from the method, it should be populated within a goroutine to ensure // there are no deadlocks. type Server interface { // GetFlightInfoStatement returns a FlightInfo for executing the requested sql query GetFlightInfoStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.FlightInfo, error) // GetFlightInfoSubstraitPlan returns a FlightInfo for executing the requested substrait plan GetFlightInfoSubstraitPlan(context.Context, StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.FlightInfo, error) // GetSchemaStatement returns the schema of the result set of the requested sql query GetSchemaStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.SchemaResult, error) // GetSchemaSubstraitPlan returns the schema of the result set for the requested substrait plan GetSchemaSubstraitPlan(context.Context, StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.SchemaResult, error) // DoGetStatement returns a stream containing the query results for the // requested statement handle that was populated by GetFlightInfoStatement DoGetStatement(context.Context, StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) // GetFlightInfoPreparedStatement returns a FlightInfo for executing an already // prepared statement with the provided statement handle. GetFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.FlightInfo, error) // GetSchemaPreparedStatement returns the schema of the result set of executing an already // prepared statement with the provided statement handle. GetSchemaPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.SchemaResult, error) // DoGetPreparedStatement returns a stream containing the results from executing // a prepared statement query with the provided statement handle. DoGetPreparedStatement(context.Context, PreparedStatementQuery) (*arrow.Schema, <-chan flight.StreamChunk, error) // GetFlightInfoCatalogs returns a FlightInfo for the listing of all catalogs GetFlightInfoCatalogs(context.Context, *flight.FlightDescriptor) (*flight.FlightInfo, error) // DoGetCatalogs returns the stream containing the list of catalogs DoGetCatalogs(context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) // GetFlightInfoXdbcTypeInfo returns a FlightInfo for retrieving data type info GetFlightInfoXdbcTypeInfo(context.Context, GetXdbcTypeInfo, *flight.FlightDescriptor) (*flight.FlightInfo, error) // DoGetXdbcTypeInfo returns a stream containing the information about the // requested supported datatypes DoGetXdbcTypeInfo(context.Context, GetXdbcTypeInfo) (*arrow.Schema, <-chan flight.StreamChunk, error) // GetFlightInfoSqlInfo returns a FlightInfo for retrieving SqlInfo from the server GetFlightInfoSqlInfo(context.Context, GetSqlInfo, *flight.FlightDescriptor) (*flight.FlightInfo, error) // DoGetSqlInfo returns a stream containing the list of SqlInfo results DoGetSqlInfo(context.Context, GetSqlInfo) (*arrow.Schema, <-chan flight.StreamChunk, error) // GetFlightInfoSchemas returns a FlightInfo for requesting a list of schemas GetFlightInfoSchemas(context.Context, GetDBSchemas, *flight.FlightDescriptor) (*flight.FlightInfo, error) // DoGetDBSchemas returns a stream containing the list of schemas DoGetDBSchemas(context.Context, GetDBSchemas) (*arrow.Schema, <-chan flight.StreamChunk, error) // GetFlightInfoTables returns a FlightInfo for listing the tables available GetFlightInfoTables(context.Context, GetTables, *flight.FlightDescriptor) (*flight.FlightInfo, error) // DoGetTables returns a stream containing the list of tables DoGetTables(context.Context, GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) // GetFlightInfoTableTypes returns a FlightInfo for retrieving a list // of table types supported GetFlightInfoTableTypes(context.Context, *flight.FlightDescriptor) (*flight.FlightInfo, error) // DoGetTableTypes returns a stream containing the data related to the table types DoGetTableTypes(context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) // GetFlightInfoPrimaryKeys returns a FlightInfo for extracting information about primary keys GetFlightInfoPrimaryKeys(context.Context, TableRef, *flight.FlightDescriptor) (*flight.FlightInfo, error) // DoGetPrimaryKeys returns a stream containing the data related to primary keys DoGetPrimaryKeys(context.Context, TableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) // GetFlightInfoExportedKeys returns a FlightInfo for extracting information about foreign keys GetFlightInfoExportedKeys(context.Context, TableRef, *flight.FlightDescriptor) (*flight.FlightInfo, error) // DoGetExportedKeys returns a stream containing the data related to foreign keys DoGetExportedKeys(context.Context, TableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) // GetFlightInfoImportedKeys returns a FlightInfo for extracting information about imported keys GetFlightInfoImportedKeys(context.Context, TableRef, *flight.FlightDescriptor) (*flight.FlightInfo, error) // DoGetImportedKeys returns a stream containing the data related to imported keys DoGetImportedKeys(context.Context, TableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) // GetFlightInfoCrossReference returns a FlightInfo for extracting data related // to primary and foreign keys GetFlightInfoCrossReference(context.Context, CrossTableRef, *flight.FlightDescriptor) (*flight.FlightInfo, error) // DoGetCrossReference returns a stream of data related to foreign and primary keys DoGetCrossReference(context.Context, CrossTableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) // DoPutCommandStatementUpdate executes a sql update statement and returns // the number of affected rows DoPutCommandStatementUpdate(context.Context, StatementUpdate) (int64, error) // DoPutCommandSubstraitPlan executes a substrait plan and returns the number // of affected rows. DoPutCommandSubstraitPlan(context.Context, StatementSubstraitPlan) (int64, error) // CreatePreparedStatement constructs a prepared statement from a sql query // and returns an opaque statement handle for use. CreatePreparedStatement(context.Context, ActionCreatePreparedStatementRequest) (ActionCreatePreparedStatementResult, error) // CreatePreparedSubstraitPlan constructs a prepared statement from a substrait // plan, and returns an opaque statement handle for use. CreatePreparedSubstraitPlan(context.Context, ActionCreatePreparedSubstraitPlanRequest) (ActionCreatePreparedStatementResult, error) // ClosePreparedStatement closes the prepared statement identified by the requested // opaque statement handle. ClosePreparedStatement(context.Context, ActionClosePreparedStatementRequest) error // DoPutPreparedStatementQuery binds parameters to a given prepared statement // identified by the provided statement handle. // // The provided MessageReader is a stream of record batches with optional // app metadata and flight descriptors to represent the values to bind // to the parameters. // // Currently anything written to the writer will be ignored. It is in the // interface for potential future enhancements to avoid having to change // the interface in the future. DoPutPreparedStatementQuery(context.Context, PreparedStatementQuery, flight.MessageReader, flight.MetadataWriter) ([]byte, error) // DoPutPreparedStatementUpdate executes an update SQL Prepared statement // for the specified statement handle. The reader allows providing a sequence // of uploaded record batches to bind the parameters to. Returns the number // of affected records. DoPutPreparedStatementUpdate(context.Context, PreparedStatementUpdate, flight.MessageReader) (int64, error) // BeginTransaction starts a new transaction and returns the id BeginTransaction(context.Context, ActionBeginTransactionRequest) (id []byte, err error) // BeginSavepoint initializes a new savepoint and returns the id BeginSavepoint(context.Context, ActionBeginSavepointRequest) (id []byte, err error) // EndSavepoint releases or rolls back a savepoint EndSavepoint(context.Context, ActionEndSavepointRequest) error // EndTransaction commits or rolls back a transaction EndTransaction(context.Context, ActionEndTransactionRequest) error // CancelFlightInfo attempts to explicitly cancel a FlightInfo CancelFlightInfo(context.Context, *flight.CancelFlightInfoRequest) (flight.CancelFlightInfoResult, error) // RenewFlightEndpoint attempts to extend the expiration of a FlightEndpoint RenewFlightEndpoint(context.Context, *flight.RenewFlightEndpointRequest) (*flight.FlightEndpoint, error) // PollFlightInfo is a generic handler for PollFlightInfo requests. PollFlightInfo(context.Context, *flight.FlightDescriptor) (*flight.PollInfo, error) // PollFlightInfoStatement handles polling for query execution. PollFlightInfoStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error) // PollFlightInfoSubstraitPlan handles polling for query execution. PollFlightInfoSubstraitPlan(context.Context, StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.PollInfo, error) // PollFlightInfoPreparedStatement handles polling for query execution. PollFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error) // SetSessionOptions sets option(s) for the current server session. SetSessionOptions(context.Context, *flight.SetSessionOptionsRequest) (*flight.SetSessionOptionsResult, error) // GetSessionOptions gets option(s) for the current server session. GetSessionOptions(context.Context, *flight.GetSessionOptionsRequest) (*flight.GetSessionOptionsResult, error) // CloseSession closes/invalidates the current server session. CloseSession(context.Context, *flight.CloseSessionRequest) (*flight.CloseSessionResult, error) // DoPutCommandStatementIngest executes a bulk ingestion and returns // the number of affected rows DoPutCommandStatementIngest(context.Context, StatementIngest, flight.MessageReader) (int64, error) mustEmbedBaseServer() } // NewFlightServer constructs a FlightRPC server from the provided // FlightSQL Server so that it can be passed to RegisterFlightService. func NewFlightServer(srv Server) flight.FlightServer { return &flightSqlServer{srv: srv, mem: memory.DefaultAllocator} } // NewFlightServerWithAllocator constructs a FlightRPC server from // the provided FlightSQL Server so that it can be passed to // RegisterFlightService, setting the provided allocator into the server // for use with any allocations necessary by the routing. // // Will default to memory.DefaultAllocator if mem is nil func NewFlightServerWithAllocator(srv Server, mem memory.Allocator) flight.FlightServer { if mem == nil { mem = memory.DefaultAllocator } return &flightSqlServer{srv: srv, mem: mem} } // flightSqlServer is a wrapper around a FlightSQL server interface to // perform routing from FlightRPC to FlightSQL. type flightSqlServer struct { flight.BaseFlightServer mem memory.Allocator srv Server } func (f *flightSqlServer) GetFlightInfo(ctx context.Context, request *flight.FlightDescriptor) (*flight.FlightInfo, error) { var ( anycmd anypb.Any cmd proto.Message err error ) if err = proto.Unmarshal(request.Cmd, &anycmd); err != nil { return nil, status.Errorf(codes.InvalidArgument, "unable to parse command: %s", err.Error()) } if cmd, err = anycmd.UnmarshalNew(); err != nil { return nil, status.Errorf(codes.InvalidArgument, "could not unmarshal Any to a command type: %s", err.Error()) } switch cmd := cmd.(type) { case *pb.CommandStatementQuery: return f.srv.GetFlightInfoStatement(ctx, cmd, request) case *pb.CommandStatementSubstraitPlan: return f.srv.GetFlightInfoSubstraitPlan(ctx, &statementSubstraitPlan{cmd}, request) case *pb.CommandPreparedStatementQuery: return f.srv.GetFlightInfoPreparedStatement(ctx, cmd, request) case *pb.CommandGetCatalogs: return f.srv.GetFlightInfoCatalogs(ctx, request) case *pb.CommandGetDbSchemas: return f.srv.GetFlightInfoSchemas(ctx, &getDBSchemas{cmd}, request) case *pb.CommandGetTables: return f.srv.GetFlightInfoTables(ctx, &getTables{cmd}, request) case *pb.CommandGetTableTypes: return f.srv.GetFlightInfoTableTypes(ctx, request) case *pb.CommandGetXdbcTypeInfo: return f.srv.GetFlightInfoXdbcTypeInfo(ctx, &getXdbcTypeInfo{cmd}, request) case *pb.CommandGetSqlInfo: return f.srv.GetFlightInfoSqlInfo(ctx, cmd, request) case *pb.CommandGetPrimaryKeys: return f.srv.GetFlightInfoPrimaryKeys(ctx, pkToTableRef(cmd), request) case *pb.CommandGetExportedKeys: return f.srv.GetFlightInfoExportedKeys(ctx, exkToTableRef(cmd), request) case *pb.CommandGetImportedKeys: return f.srv.GetFlightInfoImportedKeys(ctx, impkToTableRef(cmd), request) case *pb.CommandGetCrossReference: return f.srv.GetFlightInfoCrossReference(ctx, toCrossTableRef(cmd), request) } return nil, status.Error(codes.InvalidArgument, "requested command is invalid") } func (f *flightSqlServer) PollFlightInfo(ctx context.Context, request *flight.FlightDescriptor) (*flight.PollInfo, error) { var ( anycmd anypb.Any cmd proto.Message err error ) // If we can't parse things, be friendly and defer to the server // implementation. This is especially important for this method since // the server returns a custom FlightDescriptor for future requests. if err = proto.Unmarshal(request.Cmd, &anycmd); err != nil { return f.srv.PollFlightInfo(ctx, request) } if cmd, err = anycmd.UnmarshalNew(); err != nil { return f.srv.PollFlightInfo(ctx, request) } switch cmd := cmd.(type) { case *pb.CommandStatementQuery: return f.srv.PollFlightInfoStatement(ctx, cmd, request) case *pb.CommandStatementSubstraitPlan: return f.srv.PollFlightInfoSubstraitPlan(ctx, &statementSubstraitPlan{cmd}, request) case *pb.CommandPreparedStatementQuery: return f.srv.PollFlightInfoPreparedStatement(ctx, cmd, request) } // XXX: for now we won't support the other methods return f.srv.PollFlightInfo(ctx, request) } func (f *flightSqlServer) GetSchema(ctx context.Context, request *flight.FlightDescriptor) (*flight.SchemaResult, error) { var ( anycmd anypb.Any cmd proto.Message err error ) if err = proto.Unmarshal(request.Cmd, &anycmd); err != nil { return nil, status.Errorf(codes.InvalidArgument, "unable to parse command: %s", err.Error()) } if cmd, err = anycmd.UnmarshalNew(); err != nil { return nil, status.Errorf(codes.InvalidArgument, "could not unmarshal Any to a command type: %s", err.Error()) } switch cmd := cmd.(type) { case *pb.CommandStatementQuery: return f.srv.GetSchemaStatement(ctx, cmd, request) case *pb.CommandStatementSubstraitPlan: return f.srv.GetSchemaSubstraitPlan(ctx, &statementSubstraitPlan{cmd}, request) case *pb.CommandPreparedStatementQuery: return f.srv.GetSchemaPreparedStatement(ctx, cmd, request) case *pb.CommandGetCatalogs: return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.Catalogs, f.mem)}, nil case *pb.CommandGetDbSchemas: return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.DBSchemas, f.mem)}, nil case *pb.CommandGetTables: if cmd.GetIncludeSchema() { return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.TablesWithIncludedSchema, f.mem)}, nil } return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.Tables, f.mem)}, nil case *pb.CommandGetTableTypes: return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.TableTypes, f.mem)}, nil case *pb.CommandGetXdbcTypeInfo: return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.XdbcTypeInfo, f.mem)}, nil case *pb.CommandGetSqlInfo: return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.SqlInfo, f.mem)}, nil case *pb.CommandGetPrimaryKeys: return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.PrimaryKeys, f.mem)}, nil case *pb.CommandGetExportedKeys: return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.ExportedKeys, f.mem)}, nil case *pb.CommandGetImportedKeys: return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.ImportedKeys, f.mem)}, nil case *pb.CommandGetCrossReference: return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.CrossReference, f.mem)}, nil } return nil, status.Errorf(codes.InvalidArgument, "requested command is invalid: %s", anycmd.GetTypeUrl()) } func (f *flightSqlServer) DoGet(request *flight.Ticket, stream flight.FlightService_DoGetServer) (err error) { var ( anycmd anypb.Any cmd proto.Message cc <-chan flight.StreamChunk sc *arrow.Schema ) if err = proto.Unmarshal(request.Ticket, &anycmd); err != nil { return status.Errorf(codes.InvalidArgument, "unable to parse ticket: %s", err.Error()) } if cmd, err = anycmd.UnmarshalNew(); err != nil { return status.Errorf(codes.InvalidArgument, "unable to unmarshal proto.Any: %s", err.Error()) } switch cmd := cmd.(type) { case *pb.TicketStatementQuery: sc, cc, err = f.srv.DoGetStatement(stream.Context(), cmd) case *pb.CommandPreparedStatementQuery: sc, cc, err = f.srv.DoGetPreparedStatement(stream.Context(), cmd) case *pb.CommandGetCatalogs: sc, cc, err = f.srv.DoGetCatalogs(stream.Context()) case *pb.CommandGetDbSchemas: sc, cc, err = f.srv.DoGetDBSchemas(stream.Context(), &getDBSchemas{cmd}) case *pb.CommandGetTables: sc, cc, err = f.srv.DoGetTables(stream.Context(), &getTables{cmd}) case *pb.CommandGetTableTypes: sc, cc, err = f.srv.DoGetTableTypes(stream.Context()) case *pb.CommandGetXdbcTypeInfo: sc, cc, err = f.srv.DoGetXdbcTypeInfo(stream.Context(), &getXdbcTypeInfo{cmd}) case *pb.CommandGetSqlInfo: sc, cc, err = f.srv.DoGetSqlInfo(stream.Context(), cmd) case *pb.CommandGetPrimaryKeys: sc, cc, err = f.srv.DoGetPrimaryKeys(stream.Context(), pkToTableRef(cmd)) case *pb.CommandGetExportedKeys: sc, cc, err = f.srv.DoGetExportedKeys(stream.Context(), exkToTableRef(cmd)) case *pb.CommandGetImportedKeys: sc, cc, err = f.srv.DoGetImportedKeys(stream.Context(), impkToTableRef(cmd)) case *pb.CommandGetCrossReference: sc, cc, err = f.srv.DoGetCrossReference(stream.Context(), toCrossTableRef(cmd)) default: return status.Error(codes.InvalidArgument, "requested command is invalid") } if err != nil { return err } wr := flight.NewRecordWriter(stream, ipc.WithSchema(sc)) defer wr.Close() for chunk := range cc { if chunk.Err != nil { return chunk.Err } wr.SetFlightDescriptor(chunk.Desc) if err = wr.WriteWithAppMetadata(chunk.Data, chunk.AppMetadata); err != nil { return err } chunk.Data.Release() } return err } type putMetadataWriter struct { stream flight.FlightService_DoPutServer } func (p *putMetadataWriter) WriteMetadata(appMetadata []byte) error { return p.stream.Send(&flight.PutResult{AppMetadata: appMetadata}) } func (f *flightSqlServer) DoPut(stream flight.FlightService_DoPutServer) error { rdr, err := flight.NewRecordReader(stream, ipc.WithAllocator(f.mem), ipc.WithDelayReadSchema(true)) if err != nil { return status.Errorf(codes.InvalidArgument, "failed to read input stream: %s", err.Error()) } defer rdr.Release() // flight descriptor should have come with the schema message request := rdr.LatestFlightDescriptor() var ( anycmd anypb.Any cmd proto.Message ) if err = proto.Unmarshal(request.Cmd, &anycmd); err != nil { return status.Errorf(codes.InvalidArgument, "unable to parse command: %s", err.Error()) } if cmd, err = anycmd.UnmarshalNew(); err != nil { return status.Errorf(codes.InvalidArgument, "could not unmarshal google.protobuf.Any: %s", err.Error()) } switch cmd := cmd.(type) { case *pb.CommandStatementUpdate: recordCount, err := f.srv.DoPutCommandStatementUpdate(stream.Context(), cmd) if err != nil { return err } result := pb.DoPutUpdateResult{RecordCount: recordCount} out := &flight.PutResult{} if out.AppMetadata, err = proto.Marshal(&result); err != nil { return status.Errorf(codes.Internal, "failed to marshal PutResult: %s", err.Error()) } return stream.Send(out) case *pb.CommandStatementSubstraitPlan: recordCount, err := f.srv.DoPutCommandSubstraitPlan(stream.Context(), &statementSubstraitPlan{cmd}) if err != nil { return err } result := pb.DoPutUpdateResult{RecordCount: recordCount} out := &flight.PutResult{} if out.AppMetadata, err = proto.Marshal(&result); err != nil { return status.Errorf(codes.Internal, "failed to marshal PutResult: %s", err.Error()) } return stream.Send(out) case *pb.CommandPreparedStatementQuery: handle, err := f.srv.DoPutPreparedStatementQuery(stream.Context(), cmd, rdr, &putMetadataWriter{stream}) if err != nil { return err } result := pb.DoPutPreparedStatementResult{PreparedStatementHandle: handle} out := &flight.PutResult{} if out.AppMetadata, err = proto.Marshal(&result); err != nil { return status.Errorf(codes.Internal, "failed to marshal PutResult: %s", err.Error()) } return stream.Send(out) case *pb.CommandPreparedStatementUpdate: recordCount, err := f.srv.DoPutPreparedStatementUpdate(stream.Context(), cmd, rdr) if err != nil { return err } result := pb.DoPutUpdateResult{RecordCount: recordCount} out := &flight.PutResult{} if out.AppMetadata, err = proto.Marshal(&result); err != nil { return status.Errorf(codes.Internal, "failed to marshal PutResult: %s", err.Error()) } return stream.Send(out) case *pb.CommandStatementIngest: // Even if there was an error, the server may have ingested some records. // For this reason we send PutResult{recordCount} no matter what, potentially followed by an error // if there was one. recordCount, rpcErr := f.srv.DoPutCommandStatementIngest(stream.Context(), cmd, rdr) result := pb.DoPutUpdateResult{RecordCount: recordCount} out := &flight.PutResult{} if out.AppMetadata, err = proto.Marshal(&result); err != nil { return status.Errorf(codes.Internal, "failed to marshal PutResult: %s", err.Error()) } // If we fail to send the recordCount, just return an error outright if err := stream.Send(out); err != nil { return err } // We successfully sent the recordCount. // Send the error if one occurred in the RPC, otherwise this is nil. return rpcErr default: return status.Error(codes.InvalidArgument, "the defined request is invalid") } } func (f *flightSqlServer) ListActions(_ *flight.Empty, stream flight.FlightService_ListActionsServer) error { actions := []string{ flight.CancelFlightInfoActionType, flight.RenewFlightEndpointActionType, CreatePreparedStatementActionType, ClosePreparedStatementActionType, BeginSavepointActionType, BeginTransactionActionType, CancelQueryActionType, CreatePreparedSubstraitPlanActionType, EndSavepointActionType, EndTransactionActionType, } for _, a := range actions { if err := stream.Send(&flight.ActionType{Type: a}); err != nil { return err } } return nil } func cancelStatusToCancelResult(status flight.CancelStatus) CancelResult { switch status { case flight.CancelStatusUnspecified: return CancelResultUnspecified case flight.CancelStatusCancelled: return CancelResultCancelled case flight.CancelStatusCancelling: return CancelResultCancelling case flight.CancelStatusNotCancellable: return CancelResultNotCancellable default: return CancelResultUnspecified } } func (f *flightSqlServer) DoAction(cmd *flight.Action, stream flight.FlightService_DoActionServer) error { var anycmd anypb.Any switch cmd.Type { case flight.CancelFlightInfoActionType: var ( request flight.CancelFlightInfoRequest result flight.CancelFlightInfoResult err error ) if err = proto.Unmarshal(cmd.Body, &request); err != nil { return status.Errorf(codes.InvalidArgument, "unable to unmarshal CancelFlightInfoRequest for CancelFlightInfo: %s", err.Error()) } result, err = f.srv.CancelFlightInfo(stream.Context(), &request) if err != nil { return err } out := &pb.Result{} out.Body, err = proto.Marshal(&result) if err != nil { return err } return stream.Send(out) case flight.RenewFlightEndpointActionType: var ( request flight.RenewFlightEndpointRequest err error ) if err = proto.Unmarshal(cmd.Body, &request); err != nil { return status.Errorf(codes.InvalidArgument, "unable to unmarshal FlightEndpoint for RenewFlightEndpoint: %s", err.Error()) } renewedEndpoint, err := f.srv.RenewFlightEndpoint(stream.Context(), &request) if err != nil { return err } out := &pb.Result{} out.Body, err = proto.Marshal(renewedEndpoint) if err != nil { return err } return stream.Send(out) case BeginSavepointActionType: if err := proto.Unmarshal(cmd.Body, &anycmd); err != nil { return status.Errorf(codes.InvalidArgument, "unable to parse command: %s", err.Error()) } var ( request pb.ActionBeginSavepointRequest result pb.ActionBeginSavepointResult id []byte err error ) if err = anycmd.UnmarshalTo(&request); err != nil { return status.Errorf(codes.InvalidArgument, "unable to unmarshal google.protobuf.Any: %s", err.Error()) } if id, err = f.srv.BeginSavepoint(stream.Context(), &request); err != nil { return err } result.SavepointId = id out, err := packActionResult(&result) if err != nil { return err } return stream.Send(out) case BeginTransactionActionType: if err := proto.Unmarshal(cmd.Body, &anycmd); err != nil { return status.Errorf(codes.InvalidArgument, "unable to parse command: %s", err.Error()) } var ( request pb.ActionBeginTransactionRequest result pb.ActionBeginTransactionResult id []byte err error ) if err = anycmd.UnmarshalTo(&request); err != nil { return status.Errorf(codes.InvalidArgument, "unable to unmarshal google.protobuf.Any: %s", err.Error()) } if id, err = f.srv.BeginTransaction(stream.Context(), &request); err != nil { return err } result.TransactionId = id out, err := packActionResult(&result) if err != nil { return err } return stream.Send(out) case CancelQueryActionType: if err := proto.Unmarshal(cmd.Body, &anycmd); err != nil { return status.Errorf(codes.InvalidArgument, "unable to parse command: %s", err.Error()) } var ( //nolint:staticcheck,SA1019 for backward compatibility request pb.ActionCancelQueryRequest //nolint:staticcheck,SA1019 for backward compatibility result pb.ActionCancelQueryResult info flight.FlightInfo err error ) if err = anycmd.UnmarshalTo(&request); err != nil { return status.Errorf(codes.InvalidArgument, "unable to unmarshal google.protobuf.Any: %s", err.Error()) } if err = proto.Unmarshal(request.Info, &info); err != nil { return status.Errorf(codes.InvalidArgument, "unable to unmarshal FlightInfo for CancelQuery: %s", err) } if cancel, ok := f.srv.(cancelQueryServer); ok { result.Result, err = cancel.CancelQuery(stream.Context(), &cancelQueryRequest{&info}) if err != nil { return err } } else { cancelFlightInfoRequest := flight.CancelFlightInfoRequest{Info: &info} cancelFlightInfoResult, err := f.srv.CancelFlightInfo(stream.Context(), &cancelFlightInfoRequest) if err != nil { return err } result.Result = cancelStatusToCancelResult(cancelFlightInfoResult.Status) } out, err := packActionResult(&result) if err != nil { return err } return stream.Send(out) case CreatePreparedStatementActionType: if err := proto.Unmarshal(cmd.Body, &anycmd); err != nil { return status.Errorf(codes.InvalidArgument, "unable to parse command: %s", err.Error()) } var ( request pb.ActionCreatePreparedStatementRequest result pb.ActionCreatePreparedStatementResult ret pb.Result ) if err := anycmd.UnmarshalTo(&request); err != nil { return status.Errorf(codes.InvalidArgument, "unable to unmarshal google.protobuf.Any: %s", err.Error()) } output, err := f.srv.CreatePreparedStatement(stream.Context(), &request) if err != nil { return err } result.PreparedStatementHandle = output.Handle if output.DatasetSchema != nil { result.DatasetSchema = flight.SerializeSchema(output.DatasetSchema, f.mem) } if output.ParameterSchema != nil { result.ParameterSchema = flight.SerializeSchema(output.ParameterSchema, f.mem) } if err := anycmd.MarshalFrom(&result); err != nil { return status.Errorf(codes.Internal, "unable to marshal final response: %s", err.Error()) } if ret.Body, err = proto.Marshal(&anycmd); err != nil { return status.Errorf(codes.Internal, "unable to marshal result: %s", err.Error()) } return stream.Send(&ret) case CreatePreparedSubstraitPlanActionType: if err := proto.Unmarshal(cmd.Body, &anycmd); err != nil { return status.Errorf(codes.InvalidArgument, "unable to parse command: %s", err.Error()) } var ( request pb.ActionCreatePreparedSubstraitPlanRequest result pb.ActionCreatePreparedStatementResult ret pb.Result ) if err := anycmd.UnmarshalTo(&request); err != nil { return status.Errorf(codes.InvalidArgument, "unable to unmarshal google.protobuf.Any: %s", err.Error()) } output, err := f.srv.CreatePreparedSubstraitPlan(stream.Context(), &createPreparedSubstraitPlanReq{&request}) if err != nil { return err } result.PreparedStatementHandle = output.Handle if output.DatasetSchema != nil { result.DatasetSchema = flight.SerializeSchema(output.DatasetSchema, f.mem) } if output.ParameterSchema != nil { result.ParameterSchema = flight.SerializeSchema(output.ParameterSchema, f.mem) } if err := anycmd.MarshalFrom(&result); err != nil { return status.Errorf(codes.Internal, "unable to marshal final response: %s", err.Error()) } if ret.Body, err = proto.Marshal(&anycmd); err != nil { return status.Errorf(codes.Internal, "unable to marshal result: %s", err.Error()) } return stream.Send(&ret) case ClosePreparedStatementActionType: if err := proto.Unmarshal(cmd.Body, &anycmd); err != nil { return status.Errorf(codes.InvalidArgument, "unable to parse command: %s", err.Error()) } var request pb.ActionClosePreparedStatementRequest if err := anycmd.UnmarshalTo(&request); err != nil { return status.Errorf(codes.InvalidArgument, "unable to unmarshal google.protobuf.Any: %s", err.Error()) } if err := f.srv.ClosePreparedStatement(stream.Context(), &request); err != nil { return err } return stream.Send(&pb.Result{}) case EndTransactionActionType: if err := proto.Unmarshal(cmd.Body, &anycmd); err != nil { return status.Errorf(codes.InvalidArgument, "unable to parse command: %s", err.Error()) } var request pb.ActionEndTransactionRequest if err := anycmd.UnmarshalTo(&request); err != nil { return status.Errorf(codes.InvalidArgument, "unable to unmarshal google.protobuf.Any: %s", err.Error()) } if err := f.srv.EndTransaction(stream.Context(), &request); err != nil { return err } return stream.Send(&pb.Result{}) case EndSavepointActionType: if err := proto.Unmarshal(cmd.Body, &anycmd); err != nil { return status.Errorf(codes.InvalidArgument, "unable to parse command: %s", err.Error()) } var request pb.ActionEndSavepointRequest if err := anycmd.UnmarshalTo(&request); err != nil { return status.Errorf(codes.InvalidArgument, "unable to unmarshal google.protobuf.Any: %s", err.Error()) } if err := f.srv.EndSavepoint(stream.Context(), &request); err != nil { return err } return stream.Send(&pb.Result{}) case flight.SetSessionOptionsActionType: var ( request flight.SetSessionOptionsRequest err error ) if err = proto.Unmarshal(cmd.Body, &request); err != nil { return status.Errorf(codes.InvalidArgument, "unable to unmarshal SetSessionOptionsRequest: %s", err.Error()) } response, err := f.srv.SetSessionOptions(stream.Context(), &request) if err != nil { return err } out := &pb.Result{} out.Body, err = proto.Marshal(response) if err != nil { return err } return stream.Send(out) case flight.GetSessionOptionsActionType: var ( request flight.GetSessionOptionsRequest err error ) if err = proto.Unmarshal(cmd.Body, &request); err != nil { return status.Errorf(codes.InvalidArgument, "unable to unmarshal GetSessionOptionsRequest: %s", err.Error()) } response, err := f.srv.GetSessionOptions(stream.Context(), &request) if err != nil { return err } out := &pb.Result{} out.Body, err = proto.Marshal(response) if err != nil { return err } return stream.Send(out) case flight.CloseSessionActionType: var ( request flight.CloseSessionRequest err error ) if err = proto.Unmarshal(cmd.Body, &request); err != nil { return status.Errorf(codes.InvalidArgument, "unable to unmarshal CloseSessionRequest: %s", err.Error()) } response, err := f.srv.CloseSession(stream.Context(), &request) if err != nil { return err } out := &pb.Result{} out.Body, err = proto.Marshal(response) if err != nil { return err } return stream.Send(out) default: return status.Error(codes.InvalidArgument, "the defined request is invalid.") } } var ( _ Server = (*BaseServer)(nil) )