go/adbc/driver/flightsql/flightsql_statement.go (560 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package flightsql import ( "context" "fmt" "math" "strconv" "strings" "sync/atomic" "time" "unsafe" "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/flight" "github.com/apache/arrow-go/v18/arrow/flight/flightsql" "github.com/apache/arrow-go/v18/arrow/memory" "github.com/bluele/gcache" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" ) const ( OptionStatementQueueSize = "adbc.rpc.result_queue_size" // Explicitly set substrait version for Flight SQL // substrait *does* include the version in the serialized plan // so this is not entirely necessary depending on the version // of substrait and the capabilities of the server. OptionStatementSubstraitVersion = "adbc.flight.sql.substrait.version" ) func atomicLoadFloat64(x *float64) float64 { return math.Float64frombits(atomic.LoadUint64((*uint64)(unsafe.Pointer(x)))) } func atomicStoreFloat64(x *float64, v float64) { atomic.StoreUint64((*uint64)(unsafe.Pointer(x)), math.Float64bits(v)) } type sqlOrSubstrait struct { sqlQuery string substraitPlan []byte substraitVersion string } func (s *sqlOrSubstrait) setSqlQuery(query string) { s.sqlQuery = query s.substraitPlan = nil } func (s *sqlOrSubstrait) setSubstraitPlan(plan []byte) { s.sqlQuery = "" s.substraitPlan = plan } func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn *connectionImpl, opts ...grpc.CallOption) (*flight.FlightInfo, error) { if s.sqlQuery != "" { return cnxn.execute(ctx, s.sqlQuery, opts...) } else if s.substraitPlan != nil { return cnxn.executeSubstrait(ctx, flightsql.SubstraitPlan{Plan: s.substraitPlan, Version: s.substraitVersion}, opts...) } return nil, adbc.Error{ Code: adbc.StatusInvalidState, Msg: "[Flight SQL Statement] cannot call ExecuteQuery without a query or prepared statement", } } func (s *sqlOrSubstrait) executeSchema(ctx context.Context, cnxn *connectionImpl, opts ...grpc.CallOption) (*arrow.Schema, error) { var ( res *flight.SchemaResult err error ) if s.sqlQuery != "" { res, err = cnxn.executeSchema(ctx, s.sqlQuery, opts...) } else if s.substraitPlan != nil { res, err = cnxn.executeSubstraitSchema(ctx, flightsql.SubstraitPlan{Plan: s.substraitPlan, Version: s.substraitVersion}, opts...) } else { return nil, adbc.Error{ Code: adbc.StatusInvalidState, Msg: "[Flight SQL Statement] cannot call ExecuteQuery without a query or prepared statement", } } if err != nil { return nil, err } return flight.DeserializeSchema(res.Schema, cnxn.cl.Alloc) } func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn *connectionImpl, opts ...grpc.CallOption) (int64, error) { if s.sqlQuery != "" { return cnxn.executeUpdate(ctx, s.sqlQuery, opts...) } else if s.substraitPlan != nil { return cnxn.executeSubstraitUpdate(ctx, flightsql.SubstraitPlan{Plan: s.substraitPlan, Version: s.substraitVersion}, opts...) } return -1, adbc.Error{ Code: adbc.StatusInvalidState, Msg: "[Flight SQL Statement] cannot call ExecuteUpdate without a query or prepared statement", } } func (s *sqlOrSubstrait) poll(ctx context.Context, cnxn *connectionImpl, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { if s.sqlQuery != "" { return cnxn.poll(ctx, s.sqlQuery, retryDescriptor, opts...) } else if s.substraitPlan != nil { return cnxn.pollSubstrait(ctx, flightsql.SubstraitPlan{Plan: s.substraitPlan, Version: s.substraitVersion}, retryDescriptor, opts...) } return nil, adbc.Error{ Code: adbc.StatusInvalidState, Msg: "[Flight SQL] cannot call ExecuteQuery without a query or prepared statement", } } func (s *sqlOrSubstrait) prepare(ctx context.Context, cnxn *connectionImpl, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { if s.sqlQuery != "" { return cnxn.prepare(ctx, s.sqlQuery, opts...) } else if s.substraitPlan != nil { return cnxn.prepareSubstrait(ctx, flightsql.SubstraitPlan{Plan: s.substraitPlan, Version: s.substraitVersion}, opts...) } return nil, adbc.Error{ Code: adbc.StatusInvalidState, Msg: "[FlightSQL Statement] must call SetSqlQuery before Prepare", } } type incrementalState struct { schema *arrow.Schema previousInfo *flight.FlightInfo retryDescriptor *flight.FlightDescriptor complete bool } type statement struct { alloc memory.Allocator cnxn *connectionImpl clientCache gcache.Cache hdrs metadata.MD query sqlOrSubstrait prepared *flightsql.PreparedStatement queueSize int timeouts timeoutOption incrementalState *incrementalState progress float64 // may seem redundant, but incrementalState isn't locked lastInfo atomic.Pointer[flight.FlightInfo] } func (s *statement) closePreparedStatement() error { var header, trailer metadata.MD err := s.prepared.Close(metadata.NewOutgoingContext(context.Background(), s.hdrs), grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) return adbcFromFlightStatusWithDetails(err, header, trailer, "ClosePreparedStatement") } func (s *statement) clearIncrementalQuery() error { // retryDescriptor != nil ==> query is in progress if s.incrementalState != nil { if !s.incrementalState.complete && s.incrementalState.retryDescriptor != nil { return adbc.Error{ Code: adbc.StatusInvalidState, Msg: "[Flight SQL] Cannot disable incremental execution while a query is in progress, finish execution first", } } s.incrementalState = &incrementalState{} s.lastInfo.Store(nil) } return nil } func (s *statement) poll(ctx context.Context, opts ...grpc.CallOption) (*flight.PollInfo, error) { if s.prepared != nil { return s.prepared.ExecutePoll(ctx, s.incrementalState.retryDescriptor, opts...) } return s.query.poll(ctx, s.cnxn, s.incrementalState.retryDescriptor, opts...) } // Close releases any relevant resources associated with this statement // and closes it (particularly if it is a prepared statement). // // A statement instance should not be used after Close is called. func (s *statement) Close() (err error) { if s.prepared != nil { err = s.closePreparedStatement() s.prepared = nil } if s.cnxn == nil { return adbc.Error{ Msg: "[Flight SQL Statement] cannot close already closed statement", Code: adbc.StatusInvalidState, } } s.clientCache = nil s.cnxn = nil return err } func (s *statement) GetOption(key string) (string, error) { switch key { case OptionStatementSubstraitVersion: return s.query.substraitVersion, nil case OptionTimeoutFetch: return s.timeouts.fetchTimeout.String(), nil case OptionTimeoutQuery: return s.timeouts.queryTimeout.String(), nil case OptionTimeoutUpdate: return s.timeouts.updateTimeout.String(), nil case adbc.OptionKeyIncremental: if s.incrementalState != nil { return adbc.OptionValueEnabled, nil } return adbc.OptionValueDisabled, nil } if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) values := s.hdrs.Get(name) if len(values) > 0 { return values[0], nil } } return "", adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), Code: adbc.StatusNotFound, } } func (s *statement) GetOptionBytes(key string) ([]byte, error) { switch key { case OptionLastFlightInfo: info := s.lastInfo.Load() if info == nil { return []byte{}, nil } serialized, err := proto.Marshal(info) if err != nil { return nil, adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Could not serialize result for '%s': %s", key, err.Error()), Code: adbc.StatusInternal, } } return serialized, nil } return nil, adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), Code: adbc.StatusNotFound, } } func (s *statement) GetOptionInt(key string) (int64, error) { switch key { case OptionTimeoutFetch: fallthrough case OptionTimeoutQuery: fallthrough case OptionTimeoutUpdate: val, err := s.GetOptionDouble(key) if err != nil { return 0, err } return int64(val), nil } return 0, adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), Code: adbc.StatusNotFound, } } func (s *statement) GetOptionDouble(key string) (float64, error) { switch key { case OptionTimeoutFetch: return s.timeouts.fetchTimeout.Seconds(), nil case OptionTimeoutQuery: return s.timeouts.queryTimeout.Seconds(), nil case OptionTimeoutUpdate: return s.timeouts.updateTimeout.Seconds(), nil case adbc.OptionKeyProgress: return atomicLoadFloat64(&s.progress), nil case adbc.OptionKeyMaxProgress: return 1.0, nil } return 0, adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), Code: adbc.StatusNotFound, } } // SetOption sets a string option on this statement func (s *statement) SetOption(key string, val string) error { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) if val == "" { s.hdrs.Delete(name) } else { s.hdrs.Append(name, val) } return nil } switch key { case OptionTimeoutFetch: fallthrough case OptionTimeoutQuery: fallthrough case OptionTimeoutUpdate: return s.timeouts.setTimeoutString(key, val) case OptionStatementQueueSize: var err error var size int if size, err = strconv.Atoi(val); err != nil { return adbc.Error{ Msg: fmt.Sprintf("Invalid value for statement option '%s': '%s' is not a positive integer", OptionStatementQueueSize, val), Code: adbc.StatusInvalidArgument, } } return s.SetOptionInt(key, int64(size)) case OptionStatementSubstraitVersion: s.query.substraitVersion = val case adbc.OptionKeyIncremental: switch val { case adbc.OptionValueEnabled: if err := s.clearIncrementalQuery(); err != nil { return err } s.incrementalState = &incrementalState{} case adbc.OptionValueDisabled: if err := s.clearIncrementalQuery(); err != nil { return err } s.incrementalState = nil default: return adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Invalid statement option value %s=%s", key, val), Code: adbc.StatusInvalidArgument, } } default: return adbc.Error{ Msg: "[Flight SQL] Unknown statement option '" + key + "'", Code: adbc.StatusNotImplemented, } } return nil } func (s *statement) SetOptionBytes(key string, value []byte) error { return adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), Code: adbc.StatusNotImplemented, } } func (s *statement) SetOptionInt(key string, value int64) error { switch key { case OptionStatementQueueSize: if value <= 0 { return adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Invalid value for statement option '%s': '%d' is not a positive integer", OptionStatementQueueSize, value), Code: adbc.StatusInvalidArgument, } } s.queueSize = int(value) return nil } return s.SetOptionDouble(key, float64(value)) } func (s *statement) SetOptionDouble(key string, value float64) error { switch key { case OptionTimeoutFetch: fallthrough case OptionTimeoutQuery: fallthrough case OptionTimeoutUpdate: return s.timeouts.setTimeout(key, value) } return adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), Code: adbc.StatusNotImplemented, } } // SetSqlQuery sets the query string to be executed. // // The query can then be executed with any of the Execute methods. // For queries expected to be executed repeatedly, Prepare should be // called before execution. func (s *statement) SetSqlQuery(query string) error { if s.prepared != nil { if err := s.closePreparedStatement(); err != nil { return err } s.prepared = nil } if err := s.clearIncrementalQuery(); err != nil { return err } s.query.setSqlQuery(query) return nil } // ExecuteQuery executes the current query or prepared statement // and returnes a RecordReader for the results along with the number // of rows affected if known, otherwise it will be -1. // // This invalidates any prior result sets on this statement. func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, nrec int64, err error) { if err := s.clearIncrementalQuery(); err != nil { return nil, -1, err } ctx = metadata.NewOutgoingContext(ctx, s.hdrs) var info *flight.FlightInfo var header, trailer metadata.MD opts := append([]grpc.CallOption{}, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) if s.prepared != nil { info, err = s.prepared.Execute(ctx, opts...) } else { info, err = s.query.execute(ctx, s.cnxn, opts...) } if err != nil { return nil, -1, adbcFromFlightStatusWithDetails(err, header, trailer, "ExecuteQuery") } nrec = info.TotalRecords rdr, err = newRecordReader(ctx, s.alloc, s.cnxn.cl, info, s.clientCache, s.queueSize, s.timeouts) return } // ExecuteUpdate executes a statement that does not generate a result // set. It returns the number of rows affected if known, otherwise -1. func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) { if err := s.clearIncrementalQuery(); err != nil { return -1, err } ctx = metadata.NewOutgoingContext(ctx, s.hdrs) var header, trailer metadata.MD opts := append([]grpc.CallOption{}, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) if s.prepared != nil { n, err = s.prepared.ExecuteUpdate(ctx, opts...) } else { n, err = s.query.executeUpdate(ctx, s.cnxn, opts...) } if err != nil { err = adbcFromFlightStatusWithDetails(err, header, trailer, "ExecuteQuery") } return } // Prepare turns this statement into a prepared statement to be executed // multiple times. This invalidates any prior result sets. func (s *statement) Prepare(ctx context.Context) error { ctx = metadata.NewOutgoingContext(ctx, s.hdrs) var header, trailer metadata.MD prep, err := s.query.prepare(ctx, s.cnxn, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) if err != nil { return adbcFromFlightStatusWithDetails(err, header, trailer, "Prepare") } s.prepared = prep return nil } // SetSubstraitPlan allows setting a serialized Substrait execution // plan into the query or for querying Substrait-related metadata. // // Drivers are not required to support both SQL and Substrait semantics. // If they do, it may be via converting between representations internally. // // Like SetSqlQuery, after this is called the query can be executed // using any of the Execute methods. If the query is expected to be // executed repeatedly, Prepare should be called first on the statement. func (s *statement) SetSubstraitPlan(plan []byte) error { if s.prepared != nil { if err := s.closePreparedStatement(); err != nil { return err } s.prepared = nil } if err := s.clearIncrementalQuery(); err != nil { return err } s.query.setSubstraitPlan(plan) return nil } // Bind uses an arrow record batch to bind parameters to the query. // // This can be used for bulk inserts or for prepared statements. // The driver will call release on the passed in Record when it is done, // but it may not do this until the statement is closed or another // record is bound. func (s *statement) Bind(_ context.Context, values arrow.Record) error { // TODO: handle bulk insert situation if s.prepared == nil { return adbc.Error{ Msg: "[Flight SQL Statement] must call Prepare before calling Bind", Code: adbc.StatusInvalidState} } // calls retain s.prepared.SetParameters(values) return nil } // BindStream uses a record batch stream to bind parameters for this // query. This can be used for bulk inserts or prepared statements. // // The driver will call Release on the record reader, but may not do this // until Close is called. func (s *statement) BindStream(_ context.Context, stream array.RecordReader) error { if s.prepared == nil { return adbc.Error{ Msg: "[Flight SQL Statement] must call Prepare before calling Bind", Code: adbc.StatusInvalidState} } // calls retain s.prepared.SetRecordReader(stream) return nil } // GetParameterSchema returns an Arrow schema representation of // the expected parameters to be bound. // // This retrieves an Arrow Schema describing the number, names, and // types of the parameters in a parameterized statement. The fields // of the schema should be in order of the ordinal position of the // parameters; named parameters should appear only once. // // If the parameter does not have a name, or a name cannot be determined, // the name of the corresponding field in the schema will be an empty // string. If the type cannot be determined, the type of the corresponding // field will be NA (NullType). // // This should be called only after calling Prepare. // // This should return an error with StatusNotImplemented if the schema // cannot be determined. func (s *statement) GetParameterSchema() (*arrow.Schema, error) { if s.prepared == nil { return nil, adbc.Error{ Msg: "[Flight SQL Statement] must call Prepare before GetParameterSchema", Code: adbc.StatusInvalidState, } } ret := s.prepared.ParameterSchema() if ret == nil { return nil, adbc.Error{Code: adbc.StatusNotImplemented} } return ret, nil } // ExecutePartitions executes the current statement and gets the results // as a partitioned result set. // // It returns the Schema of the result set (if available, nil otherwise), // the collection of partition descriptors and the number of rows affected, // if known. If unknown, the number of rows affected will be -1. // // If the driver does not support partitioned results, this will return // an error with a StatusNotImplemented code. func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.Partitions, int64, error) { ctx = metadata.NewOutgoingContext(ctx, s.hdrs) var ( info *flight.FlightInfo poll *flight.PollInfo out adbc.Partitions sc *arrow.Schema err error ) var header, trailer metadata.MD if s.incrementalState != nil { if s.incrementalState.complete { schema := s.incrementalState.schema totalRecords := s.incrementalState.previousInfo.TotalRecords // Reset the statement for reuse s.incrementalState = &incrementalState{} atomicStoreFloat64(&s.progress, 0.0) s.lastInfo.Store(nil) return schema, adbc.Partitions{}, totalRecords, nil } backoff := 100 * time.Millisecond for { // Keep polling until the query completes or we get new partitions poll, err = s.poll(ctx, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) if err != nil { break } info = poll.GetInfo() if info == nil { // The server is misbehaving // XXX: should we also issue a query cancellation? s.incrementalState = &incrementalState{} atomicStoreFloat64(&s.progress, 0.0) return nil, adbc.Partitions{}, -1, adbc.Error{ Msg: "[Flight SQL] Server returned a PollInfo with no FlightInfo", Code: adbc.StatusInternal, } } info = proto.Clone(info).(*flight.FlightInfo) // We only return the new endpoints each time if s.incrementalState.previousInfo != nil { offset := len(s.incrementalState.previousInfo.Endpoint) if offset >= len(info.Endpoint) { info.Endpoint = []*flight.FlightEndpoint{} } else { info.Endpoint = info.Endpoint[offset:] } } s.incrementalState.previousInfo = poll.GetInfo() s.incrementalState.retryDescriptor = poll.GetFlightDescriptor() atomicStoreFloat64(&s.progress, poll.GetProgress()) s.lastInfo.Store(poll.GetInfo()) if s.incrementalState.retryDescriptor == nil { // Query is finished s.incrementalState.complete = true break } else if len(info.Endpoint) > 0 { // Query made progress break } // Back off before next poll time.Sleep(backoff) backoff *= 2 if backoff > 5000*time.Millisecond { backoff = 5000 * time.Millisecond } } // Special case: the query completed but there were no new endpoints. We // return 0 new partitions, and also reset the statement (because // returning 0 partitions implies completion) if s.incrementalState.complete && len(info.Endpoint) == 0 { s.incrementalState = &incrementalState{} atomicStoreFloat64(&s.progress, 0.0) s.lastInfo.Store(nil) } } else if s.prepared != nil { info, err = s.prepared.Execute(ctx, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) } else { info, err = s.query.execute(ctx, s.cnxn, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) } if err != nil { return nil, out, -1, adbcFromFlightStatusWithDetails(err, header, trailer, "ExecutePartitions") } if len(info.Schema) > 0 { sc, err = flight.DeserializeSchema(info.Schema, s.alloc) if err != nil { return nil, out, -1, adbcFromFlightStatus(err, "ExecutePartitions: could not deserialize FlightInfo schema:") } } if s.incrementalState != nil { s.incrementalState.schema = sc } out.NumPartitions = uint64(len(info.Endpoint)) out.PartitionIDs = make([][]byte, out.NumPartitions) for i, e := range info.Endpoint { partition := proto.Clone(info).(*flight.FlightInfo) partition.Endpoint = []*flight.FlightEndpoint{e} data, err := proto.Marshal(partition) if err != nil { return sc, out, -1, adbc.Error{ Msg: err.Error(), Code: adbc.StatusInternal, } } out.PartitionIDs[i] = data } return sc, out, info.TotalRecords, nil } // ExecuteSchema gets the schema of the result set of a query without executing it. func (s *statement) ExecuteSchema(ctx context.Context) (schema *arrow.Schema, err error) { ctx = metadata.NewOutgoingContext(ctx, s.hdrs) if s.prepared != nil { schema = s.prepared.DatasetSchema() if schema == nil { err = adbc.Error{ Msg: "[Flight SQL Statement] Database server did not provide schema for prepared statement", Code: adbc.StatusNotImplemented, } } return } var header, trailer metadata.MD schema, err = s.query.executeSchema(ctx, s.cnxn, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) if err != nil { err = adbcFromFlightStatusWithDetails(err, header, trailer, "ExecuteSchema") } return }