go/adbc/driver/flightsql/flightsql_statement.go (286 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package flightsql import ( "context" "fmt" "strconv" "strings" "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow/go/v13/arrow" "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/flight" "github.com/apache/arrow/go/v13/arrow/flight/flightsql" "github.com/apache/arrow/go/v13/arrow/memory" "github.com/bluele/gcache" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" ) const ( OptionStatementQueueSize = "adbc.rpc.result_queue_size" // Explicitly set substrait version for Flight SQL // substrait *does* include the version in the serialized plan // so this is not entirely necessary depending on the version // of substrait and the capabilities of the server. OptionStatementSubstraitVersion = "adbc.flight.sql.substrait.version" ) type sqlOrSubstrait struct { sqlQuery string substraitPlan []byte substraitVersion string } func (s *sqlOrSubstrait) setSqlQuery(query string) { s.sqlQuery = query s.substraitPlan = nil } func (s *sqlOrSubstrait) setSubstraitPlan(plan []byte) { s.sqlQuery = "" s.substraitPlan = plan } func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (*flight.FlightInfo, error) { if s.sqlQuery != "" { return cnxn.execute(ctx, s.sqlQuery, opts...) } else if s.substraitPlan != nil { return cnxn.executeSubstrait(ctx, flightsql.SubstraitPlan{Plan: s.substraitPlan, Version: s.substraitVersion}, opts...) } return nil, adbc.Error{ Code: adbc.StatusInvalidState, Msg: "[Flight SQL Statement] cannot call ExecuteQuery without a query or prepared statement", } } func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (int64, error) { if s.sqlQuery != "" { return cnxn.executeUpdate(ctx, s.sqlQuery, opts...) } else if s.substraitPlan != nil { return cnxn.executeSubstraitUpdate(ctx, flightsql.SubstraitPlan{Plan: s.substraitPlan, Version: s.substraitVersion}, opts...) } return -1, adbc.Error{ Code: adbc.StatusInvalidState, Msg: "[Flight SQL Statement] cannot call ExecuteUpdate without a query or prepared statement", } } func (s *sqlOrSubstrait) prepare(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { if s.sqlQuery != "" { return cnxn.prepare(ctx, s.sqlQuery, opts...) } else if s.substraitPlan != nil { return cnxn.prepareSubstrait(ctx, flightsql.SubstraitPlan{Plan: s.substraitPlan, Version: s.substraitVersion}, opts...) } return nil, adbc.Error{ Code: adbc.StatusInvalidState, Msg: "[FlightSQL Statement] must call SetSqlQuery before Prepare", } } type statement struct { alloc memory.Allocator cnxn *cnxn clientCache gcache.Cache hdrs metadata.MD query sqlOrSubstrait prepared *flightsql.PreparedStatement queueSize int timeouts timeoutOption } func (s *statement) closePreparedStatement() error { return s.prepared.Close(metadata.NewOutgoingContext(context.Background(), s.hdrs), s.timeouts) } // Close releases any relevant resources associated with this statement // and closes it (particularly if it is a prepared statement). // // A statement instance should not be used after Close is called. func (s *statement) Close() (err error) { if s.prepared != nil { err = s.closePreparedStatement() s.prepared = nil } if s.cnxn == nil { return adbc.Error{ Msg: "[Flight SQL Statement] cannot close already closed statement", Code: adbc.StatusInvalidState, } } s.clientCache = nil s.cnxn = nil return err } // SetOption sets a string option on this statement func (s *statement) SetOption(key string, val string) error { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) if val == "" { s.hdrs.Delete(name) } else { s.hdrs.Append(name, val) } return nil } switch key { case OptionTimeoutFetch: timeout, err := getTimeoutOptionValue(val) if err != nil { return adbc.Error{ Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", OptionTimeoutFetch, val, err.Error()), Code: adbc.StatusInvalidArgument, } } s.timeouts.fetchTimeout = timeout case OptionTimeoutQuery: timeout, err := getTimeoutOptionValue(val) if err != nil { return adbc.Error{ Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", OptionTimeoutFetch, val, err.Error()), Code: adbc.StatusInvalidArgument, } } s.timeouts.queryTimeout = timeout case OptionTimeoutUpdate: timeout, err := getTimeoutOptionValue(val) if err != nil { return adbc.Error{ Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", OptionTimeoutFetch, val, err.Error()), Code: adbc.StatusInvalidArgument, } } s.timeouts.updateTimeout = timeout case OptionStatementQueueSize: var err error var size int if size, err = strconv.Atoi(val); err != nil { return adbc.Error{ Msg: fmt.Sprintf("Invalid value for statement option '%s': '%s' is not a positive integer", OptionStatementQueueSize, val), Code: adbc.StatusInvalidArgument, } } else if size <= 0 { return adbc.Error{ Msg: fmt.Sprintf("Invalid value for statement option '%s': '%s' is not a positive integer", OptionStatementQueueSize, val), Code: adbc.StatusInvalidArgument, } } s.queueSize = size case OptionStatementSubstraitVersion: s.query.substraitVersion = val default: return adbc.Error{ Msg: "[Flight SQL] Unknown statement option '" + key + "'", Code: adbc.StatusNotImplemented, } } return nil } // SetSqlQuery sets the query string to be executed. // // The query can then be executed with any of the Execute methods. // For queries expected to be executed repeatedly, Prepare should be // called before execution. func (s *statement) SetSqlQuery(query string) error { if s.prepared != nil { if err := s.closePreparedStatement(); err != nil { return err } s.prepared = nil } s.query.setSqlQuery(query) return nil } // ExecuteQuery executes the current query or prepared statement // and returnes a RecordReader for the results along with the number // of rows affected if known, otherwise it will be -1. // // This invalidates any prior result sets on this statement. func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, nrec int64, err error) { ctx = metadata.NewOutgoingContext(ctx, s.hdrs) var info *flight.FlightInfo if s.prepared != nil { info, err = s.prepared.Execute(ctx, s.timeouts) } else { info, err = s.query.execute(ctx, s.cnxn, s.timeouts) } if err != nil { return nil, -1, adbcFromFlightStatus(err, "ExecuteQuery") } nrec = info.TotalRecords rdr, err = newRecordReader(ctx, s.alloc, s.cnxn.cl, info, s.clientCache, s.queueSize, s.timeouts) return } // ExecuteUpdate executes a statement that does not generate a result // set. It returns the number of rows affected if known, otherwise -1. func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) { ctx = metadata.NewOutgoingContext(ctx, s.hdrs) if s.prepared != nil { n, err = s.prepared.ExecuteUpdate(ctx, s.timeouts) } else { n, err = s.query.executeUpdate(ctx, s.cnxn, s.timeouts) } if err != nil { err = adbcFromFlightStatus(err, "ExecuteUpdate") } return } // Prepare turns this statement into a prepared statement to be executed // multiple times. This invalidates any prior result sets. func (s *statement) Prepare(ctx context.Context) error { ctx = metadata.NewOutgoingContext(ctx, s.hdrs) prep, err := s.query.prepare(ctx, s.cnxn, s.timeouts) if err != nil { return adbcFromFlightStatus(err, "Prepare") } s.prepared = prep return nil } // SetSubstraitPlan allows setting a serialized Substrait execution // plan into the query or for querying Substrait-related metadata. // // Drivers are not required to support both SQL and Substrait semantics. // If they do, it may be via converting between representations internally. // // Like SetSqlQuery, after this is called the query can be executed // using any of the Execute methods. If the query is expected to be // executed repeatedly, Prepare should be called first on the statement. func (s *statement) SetSubstraitPlan(plan []byte) error { if s.prepared != nil { if err := s.closePreparedStatement(); err != nil { return err } s.prepared = nil } s.query.setSubstraitPlan(plan) return nil } // Bind uses an arrow record batch to bind parameters to the query. // // This can be used for bulk inserts or for prepared statements. // The driver will call release on the passed in Record when it is done, // but it may not do this until the statement is closed or another // record is bound. func (s *statement) Bind(_ context.Context, values arrow.Record) error { // TODO: handle bulk insert situation if s.prepared == nil { return adbc.Error{ Msg: "[Flight SQL Statement] must call Prepare before calling Bind", Code: adbc.StatusInvalidState} } // calls retain s.prepared.SetParameters(values) return nil } // BindStream uses a record batch stream to bind parameters for this // query. This can be used for bulk inserts or prepared statements. // // The driver will call Release on the record reader, but may not do this // until Close is called. func (s *statement) BindStream(_ context.Context, stream array.RecordReader) error { if s.prepared == nil { return adbc.Error{ Msg: "[Flight SQL Statement] must call Prepare before calling Bind", Code: adbc.StatusInvalidState} } // calls retain s.prepared.SetRecordReader(stream) return nil } // GetParameterSchema returns an Arrow schema representation of // the expected parameters to be bound. // // This retrieves an Arrow Schema describing the number, names, and // types of the parameters in a parameterized statement. The fields // of the schema should be in order of the ordinal position of the // parameters; named parameters should appear only once. // // If the parameter does not have a name, or a name cannot be determined, // the name of the corresponding field in the schema will be an empty // string. If the type cannot be determined, the type of the corresponding // field will be NA (NullType). // // This should be called only after calling Prepare. // // This should return an error with StatusNotImplemented if the schema // cannot be determined. func (s *statement) GetParameterSchema() (*arrow.Schema, error) { if s.prepared == nil { return nil, adbc.Error{ Msg: "[Flight SQL Statement] must call Prepare before GetParameterSchema", Code: adbc.StatusInvalidState, } } ret := s.prepared.ParameterSchema() if ret == nil { return nil, adbc.Error{Code: adbc.StatusNotImplemented} } return ret, nil } // ExecutePartitions executes the current statement and gets the results // as a partitioned result set. // // It returns the Schema of the result set (if available, nil otherwise), // the collection of partition descriptors and the number of rows affected, // if known. If unknown, the number of rows affected will be -1. // // If the driver does not support partitioned results, this will return // an error with a StatusNotImplemented code. func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.Partitions, int64, error) { ctx = metadata.NewOutgoingContext(ctx, s.hdrs) var ( info *flight.FlightInfo out adbc.Partitions sc *arrow.Schema err error ) if s.prepared != nil { info, err = s.prepared.Execute(ctx, s.timeouts) } else { info, err = s.query.execute(ctx, s.cnxn, s.timeouts) } if err != nil { return nil, out, -1, adbcFromFlightStatus(err, "ExecutePartitions") } if len(info.Schema) > 0 { sc, err = flight.DeserializeSchema(info.Schema, s.alloc) if err != nil { return nil, out, -1, adbcFromFlightStatus(err, "ExecutePartitions: could not deserialize FlightInfo schema:") } } out.NumPartitions = uint64(len(info.Endpoint)) out.PartitionIDs = make([][]byte, out.NumPartitions) for i, e := range info.Endpoint { partition := proto.Clone(info).(*flight.FlightInfo) partition.Endpoint = []*flight.FlightEndpoint{e} data, err := proto.Marshal(partition) if err != nil { return sc, out, -1, adbc.Error{ Msg: err.Error(), Code: adbc.StatusInternal, } } out.PartitionIDs[i] = data } return sc, out, info.TotalRecords, nil }