go/adbc/driver/flightsql/timeouts.go (179 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package flightsql import ( "context" "fmt" "io" "math" "strconv" "strings" "time" "github.com/apache/arrow-adbc/go/adbc" "google.golang.org/grpc" "google.golang.org/grpc/backoff" "google.golang.org/grpc/metadata" ) type timeoutOption struct { grpc.EmptyCallOption // timeout for DoGet requests fetchTimeout time.Duration // timeout for GetFlightInfo requests queryTimeout time.Duration // timeout for DoPut or DoAction requests updateTimeout time.Duration // timeout for establishing a new connection connectTimeout time.Duration } func (t *timeoutOption) setTimeout(key string, value float64) error { if math.IsNaN(value) || math.IsInf(value, 0) || value < 0 { return adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] invalid timeout option value %s = %f: timeouts must be non-negative and finite", key, value), Code: adbc.StatusInvalidArgument, } } timeout := time.Duration(value * float64(time.Second)) switch key { case OptionTimeoutFetch: t.fetchTimeout = timeout case OptionTimeoutQuery: t.queryTimeout = timeout case OptionTimeoutUpdate: t.updateTimeout = timeout case OptionTimeoutConnect: t.connectTimeout = timeout default: return adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Unknown timeout option '%s'", key), Code: adbc.StatusNotImplemented, } } return nil } func (t *timeoutOption) setTimeoutString(key string, value string) error { timeout, err := strconv.ParseFloat(value, 64) if err != nil { return adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] invalid timeout option value %s = %s: %s", key, value, err.Error()), Code: adbc.StatusInvalidArgument, } } return t.setTimeout(key, timeout) } func (t *timeoutOption) connectParams() grpc.ConnectParams { return grpc.ConnectParams{ Backoff: backoff.DefaultConfig, MinConnectTimeout: t.connectTimeout, } } func getTimeout(method string, callOptions []grpc.CallOption) (time.Duration, bool) { for _, opt := range callOptions { if to, ok := opt.(timeoutOption); ok { var tm time.Duration switch { case strings.HasSuffix(method, "DoGet"): tm = to.fetchTimeout case strings.HasSuffix(method, "GetFlightInfo"): tm = to.queryTimeout case strings.HasSuffix(method, "DoPut") || strings.HasSuffix(method, "DoAction"): tm = to.updateTimeout } return tm, tm > 0 } } return 0, false } func unaryTimeoutInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { if tm, ok := getTimeout(method, opts); ok { ctx, cancel := context.WithTimeout(ctx, tm) defer cancel() return invoker(ctx, method, req, reply, cc, opts...) } return invoker(ctx, method, req, reply, cc, opts...) } type streamEventType int const ( receiveEndEvent streamEventType = iota errorEvent ) type streamEvent struct { Type streamEventType Err error } type wrappedClientStream struct { grpc.ClientStream desc *grpc.StreamDesc events chan streamEvent eventsDone chan struct{} } func (w *wrappedClientStream) RecvMsg(m any) error { err := w.ClientStream.RecvMsg(m) switch { case err == nil && !w.desc.ServerStreams: w.sendStreamEvent(receiveEndEvent, nil) case err == io.EOF: w.sendStreamEvent(receiveEndEvent, nil) case err != nil: w.sendStreamEvent(errorEvent, err) } return err } func (w *wrappedClientStream) SendMsg(m any) error { err := w.ClientStream.SendMsg(m) if err != nil { w.sendStreamEvent(errorEvent, err) } return err } func (w *wrappedClientStream) Header() (metadata.MD, error) { md, err := w.ClientStream.Header() if err != nil { w.sendStreamEvent(errorEvent, err) } return md, err } func (w *wrappedClientStream) CloseSend() error { err := w.ClientStream.CloseSend() if err != nil { w.sendStreamEvent(errorEvent, err) } return err } func (w *wrappedClientStream) sendStreamEvent(eventType streamEventType, err error) { select { case <-w.eventsDone: case w.events <- streamEvent{Type: eventType, Err: err}: } } func streamTimeoutInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { if tm, ok := getTimeout(method, opts); ok { ctx, cancel := context.WithTimeout(ctx, tm) s, err := streamer(ctx, desc, cc, method, opts...) if err != nil { defer cancel() return s, err } events, eventsDone := make(chan streamEvent), make(chan struct{}) go func() { defer close(eventsDone) defer cancel() for { select { case event := <-events: // split by event type in case we want to add more logging // or even adding in some telemetry in the future. // Errors will already be propagated by the RecvMsg, SendMsg // methods. switch event.Type { case receiveEndEvent: return case errorEvent: return } case <-ctx.Done(): return } } }() stream := &wrappedClientStream{ ClientStream: s, desc: desc, events: events, eventsDone: eventsDone, } return stream, nil } return streamer(ctx, desc, cc, method, opts...) }