go/adbc/driver/flightsql/flightsql_database.go (483 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package flightsql import ( "context" "crypto/tls" "crypto/x509" "fmt" "net/url" "strconv" "strings" "sync" "time" "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/flight" "github.com/apache/arrow-go/v18/arrow/flight/flightsql" "github.com/bluele/gcache" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" ) type dbDialOpts struct { opts []grpc.DialOption maxMsgSize int authority string } func (d *dbDialOpts) rebuild() { d.opts = []grpc.DialOption{ grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(d.maxMsgSize), grpc.MaxCallSendMsgSize(d.maxMsgSize)), } if d.authority != "" { d.opts = append(d.opts, grpc.WithAuthority(d.authority)) } } type databaseImpl struct { driverbase.DatabaseImplBase uri *url.URL creds credentials.TransportCredentials user, pass string hdrs metadata.MD timeout timeoutOption dialOpts dbDialOpts enableCookies bool options map[string]string userDialOpts []grpc.DialOption oauthToken credentials.PerRPCCredentials } func (d *databaseImpl) SetOptions(cnOptions map[string]string) error { var tlsConfig tls.Config for k, v := range cnOptions { d.options[k] = v } if authority, ok := cnOptions[OptionAuthority]; ok { d.dialOpts.authority = authority delete(cnOptions, OptionAuthority) } mtlsCert := cnOptions[OptionMTLSCertChain] mtlsKey := cnOptions[OptionMTLSPrivateKey] switch { case mtlsCert != "" && mtlsKey != "": cert, err := tls.X509KeyPair([]byte(mtlsCert), []byte(mtlsKey)) if err != nil { return adbc.Error{ Msg: fmt.Sprintf("Invalid mTLS certificate: %#v", err), Code: adbc.StatusInvalidArgument, } } tlsConfig.Certificates = []tls.Certificate{cert} delete(cnOptions, OptionMTLSCertChain) delete(cnOptions, OptionMTLSPrivateKey) case mtlsCert != "": return adbc.Error{ Msg: fmt.Sprintf("Must provide both '%s' and '%s', only provided '%s'", OptionMTLSCertChain, OptionMTLSPrivateKey, OptionMTLSCertChain), Code: adbc.StatusInvalidArgument, } case mtlsKey != "": return adbc.Error{ Msg: fmt.Sprintf("Must provide both '%s' and '%s', only provided '%s'", OptionMTLSCertChain, OptionMTLSPrivateKey, OptionMTLSPrivateKey), Code: adbc.StatusInvalidArgument, } } if hostname, ok := cnOptions[OptionSSLOverrideHostname]; ok { tlsConfig.ServerName = hostname delete(cnOptions, OptionSSLOverrideHostname) } if val, ok := cnOptions[OptionSSLSkipVerify]; ok { switch val { case adbc.OptionValueEnabled: tlsConfig.InsecureSkipVerify = true case adbc.OptionValueDisabled: tlsConfig.InsecureSkipVerify = false default: return adbc.Error{ Msg: fmt.Sprintf("Invalid value for database option '%s': '%s'", OptionSSLSkipVerify, val), Code: adbc.StatusInvalidArgument, } } delete(cnOptions, OptionSSLSkipVerify) } if cert, ok := cnOptions[OptionSSLRootCerts]; ok { cp := x509.NewCertPool() if !cp.AppendCertsFromPEM([]byte(cert)) { return adbc.Error{ Msg: fmt.Sprintf("Invalid value for database option '%s': failed to append certificates", OptionSSLRootCerts), Code: adbc.StatusInvalidArgument, } } tlsConfig.RootCAs = cp delete(cnOptions, OptionSSLRootCerts) } d.creds = credentials.NewTLS(&tlsConfig) if auth, ok := cnOptions[OptionAuthorizationHeader]; ok { d.hdrs.Set("authorization", auth) delete(cnOptions, OptionAuthorizationHeader) } const authConflictError = "Authentication conflict: Use either Authorization header OR username/password parameter" if u, ok := cnOptions[adbc.OptionKeyUsername]; ok { if d.hdrs.Len() > 0 { return adbc.Error{ Msg: authConflictError, Code: adbc.StatusInvalidArgument, } } d.user = u delete(cnOptions, adbc.OptionKeyUsername) } if p, ok := cnOptions[adbc.OptionKeyPassword]; ok { if d.hdrs.Len() > 0 { return adbc.Error{ Msg: authConflictError, Code: adbc.StatusInvalidArgument, } } d.pass = p delete(cnOptions, adbc.OptionKeyPassword) } if flow, ok := cnOptions[OptionKeyOauthFlow]; ok { if d.hdrs.Len() > 0 { return adbc.Error{ Msg: authConflictError, Code: adbc.StatusInvalidArgument, } } var err error switch flow { case ClientCredentials: d.oauthToken, err = newClientCredentials(cnOptions) case TokenExchange: d.oauthToken, err = newTokenExchangeFlow(cnOptions) default: return adbc.Error{ Msg: fmt.Sprintf("oauth flow not implemented: %s", flow), Code: adbc.StatusNotImplemented, } } if err != nil { return err } delete(cnOptions, OptionKeyOauthFlow) } var err error if tv, ok := cnOptions[OptionTimeoutFetch]; ok { if err = d.timeout.setTimeoutString(OptionTimeoutFetch, tv); err != nil { return err } delete(cnOptions, OptionTimeoutFetch) } if tv, ok := cnOptions[OptionTimeoutQuery]; ok { if err = d.timeout.setTimeoutString(OptionTimeoutQuery, tv); err != nil { return err } delete(cnOptions, OptionTimeoutQuery) } if tv, ok := cnOptions[OptionTimeoutUpdate]; ok { if err = d.timeout.setTimeoutString(OptionTimeoutUpdate, tv); err != nil { return err } delete(cnOptions, OptionTimeoutUpdate) } if tv, ok := cnOptions[OptionTimeoutConnect]; ok { if err = d.timeout.setTimeoutString(OptionTimeoutConnect, tv); err != nil { return err } delete(cnOptions, OptionTimeoutConnect) } // gRPC deprecated this and explicitly recommends against it delete(cnOptions, OptionWithBlock) if val, ok := cnOptions[OptionWithMaxMsgSize]; ok { var err error var size int if size, err = strconv.Atoi(val); err != nil { return adbc.Error{ Msg: fmt.Sprintf("Invalid value for database option '%s': '%s' is not a positive integer", OptionWithMaxMsgSize, val), Code: adbc.StatusInvalidArgument, } } else if size <= 0 { return adbc.Error{ Msg: fmt.Sprintf("Invalid value for database option '%s': '%s' is not a positive integer", OptionWithMaxMsgSize, val), Code: adbc.StatusInvalidArgument, } } d.dialOpts.maxMsgSize = size delete(cnOptions, OptionWithMaxMsgSize) } d.dialOpts.rebuild() if val, ok := cnOptions[OptionCookieMiddleware]; ok { switch val { case adbc.OptionValueEnabled: d.enableCookies = true case adbc.OptionValueDisabled: d.enableCookies = false default: return d.ErrorHelper.Errorf(adbc.StatusInvalidArgument, "Invalid value for database option '%s': '%s'", OptionCookieMiddleware, val) } delete(cnOptions, OptionCookieMiddleware) } for key, val := range cnOptions { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { d.hdrs.Append(strings.TrimPrefix(key, OptionRPCCallHeaderPrefix), val) continue } return d.ErrorHelper.Errorf(adbc.StatusInvalidArgument, "[Flight SQL] Unknown database option '%s'", key) } return nil } func (d *databaseImpl) GetOption(key string) (string, error) { switch key { case OptionTimeoutFetch: return d.timeout.fetchTimeout.String(), nil case OptionTimeoutQuery: return d.timeout.queryTimeout.String(), nil case OptionTimeoutUpdate: return d.timeout.updateTimeout.String(), nil case OptionTimeoutConnect: return d.timeout.connectTimeout.String(), nil } if val, ok := d.options[key]; ok { return val, nil } return d.DatabaseImplBase.GetOption(key) } func (d *databaseImpl) GetOptionInt(key string) (int64, error) { switch key { case OptionTimeoutFetch: fallthrough case OptionTimeoutQuery: fallthrough case OptionTimeoutUpdate: fallthrough case OptionTimeoutConnect: val, err := d.GetOptionDouble(key) if err != nil { return 0, err } return int64(val), nil } return d.DatabaseImplBase.GetOptionInt(key) } func (d *databaseImpl) GetOptionDouble(key string) (float64, error) { switch key { case OptionTimeoutFetch: return d.timeout.fetchTimeout.Seconds(), nil case OptionTimeoutQuery: return d.timeout.queryTimeout.Seconds(), nil case OptionTimeoutUpdate: return d.timeout.updateTimeout.Seconds(), nil case OptionTimeoutConnect: return d.timeout.connectTimeout.Seconds(), nil } return d.DatabaseImplBase.GetOptionDouble(key) } func (d *databaseImpl) SetOption(key, value string) error { // We can't change most options post-init switch key { case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate, OptionTimeoutConnect: return d.timeout.setTimeoutString(key, value) } if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { d.hdrs.Set(strings.TrimPrefix(key, OptionRPCCallHeaderPrefix), value) } return d.DatabaseImplBase.SetOption(key, value) } func (d *databaseImpl) SetOptionInt(key string, value int64) error { switch key { case OptionTimeoutFetch: fallthrough case OptionTimeoutQuery: fallthrough case OptionTimeoutUpdate: fallthrough case OptionTimeoutConnect: return d.timeout.setTimeout(key, float64(value)) } return d.DatabaseImplBase.SetOptionInt(key, value) } func (d *databaseImpl) SetOptionDouble(key string, value float64) error { switch key { case OptionTimeoutFetch: fallthrough case OptionTimeoutQuery: fallthrough case OptionTimeoutUpdate: fallthrough case OptionTimeoutConnect: return d.timeout.setTimeout(key, value) } return d.DatabaseImplBase.SetOptionDouble(key, value) } func (d *databaseImpl) Close() error { return nil } func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddle *bearerAuthMiddleware, cookies flight.CookieMiddleware) (*flightsql.Client, error) { middleware := []flight.ClientMiddleware{ { Unary: makeUnaryLoggingInterceptor(d.Logger), Stream: makeStreamLoggingInterceptor(d.Logger), }, flight.CreateClientMiddleware(authMiddle), { Unary: unaryTimeoutInterceptor, Stream: streamTimeoutInterceptor, }, } if d.enableCookies { middleware = append(middleware, flight.CreateClientMiddleware(cookies)) } uri, err := url.Parse(loc) if err != nil { return nil, adbc.Error{Msg: fmt.Sprintf("Invalid URI '%s': %s", loc, err), Code: adbc.StatusInvalidArgument} } creds := d.creds target := uri.Host switch uri.Scheme { case "grpc", "grpc+tcp": creds = insecure.NewCredentials() case "grpc+unix": creds = insecure.NewCredentials() target = "unix:" + uri.Path } dv, _ := d.DriverInfo.GetInfoForInfoCode(adbc.InfoDriverVersion) driverVersion := dv.(string) dialOpts := append(d.dialOpts.opts, grpc.WithConnectParams(d.timeout.connectParams()), grpc.WithTransportCredentials(creds), grpc.WithUserAgent("ADBC Flight SQL Driver "+driverVersion)) dialOpts = append(dialOpts, d.userDialOpts...) if d.oauthToken != nil { dialOpts = append(dialOpts, grpc.WithPerRPCCredentials(d.oauthToken)) } d.Logger.DebugContext(ctx, "new client", "location", loc) cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...) if err != nil { return nil, adbc.Error{ Msg: err.Error(), Code: adbc.StatusIO, } } cl.Alloc = d.Alloc // Authorization header is already set, continue if len(authMiddle.hdrs.Get("authorization")) > 0 { d.Logger.DebugContext(ctx, "reusing auth token", "location", loc) return cl, nil } var authValue string if d.user != "" || d.pass != "" { var header, trailer metadata.MD ctx, err = cl.Client.AuthenticateBasicToken(ctx, d.user, d.pass, grpc.Header(&header), grpc.Trailer(&trailer), d.timeout) if err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "AuthenticateBasicToken") } if md, ok := metadata.FromOutgoingContext(ctx); ok { authValue = md.Get("Authorization")[0] } } if authValue != "" { authMiddle.SetHeader(authValue) } return cl, nil } type support struct { transactions bool } func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { authMiddle := &bearerAuthMiddleware{hdrs: d.hdrs.Copy()} var cookies flight.CookieMiddleware if d.enableCookies { cookies = flight.NewCookieMiddleware() } cl, err := getFlightClient(ctx, d.uri.String(), d, authMiddle, cookies) if err != nil { return nil, err } cache := gcache.New(20).LRU(). Expiration(5 * time.Minute). LoaderFunc(func(loc interface{}) (interface{}, error) { uri, ok := loc.(string) if !ok { return nil, adbc.Error{Msg: fmt.Sprintf("Location must be a string, got %#v", uri), Code: adbc.StatusInternal} } var cookieMiddleware flight.CookieMiddleware // if cookies are enabled, start by cloning the existing cookies if d.enableCookies { cookieMiddleware = cookies.Clone() } // use the existing auth token if there is one cl, err := getFlightClient(context.Background(), uri, d, &bearerAuthMiddleware{hdrs: authMiddle.hdrs.Copy()}, cookieMiddleware) if err != nil { return nil, err } cl.Alloc = d.Alloc return cl, nil }). EvictedFunc(func(_, client interface{}) { conn := client.(*flightsql.Client) err := conn.Close() if err != nil { d.Logger.Debug("failed to close client", "error", err.Error()) } }).Build() var cnxnSupport support info, err := cl.GetSqlInfo(ctx, []flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerTransaction}, d.timeout) // ignore this if it fails if err == nil { const int32code = 3 for _, endpoint := range info.Endpoint { rdr, err := doGet(ctx, cl, endpoint, cache, d.timeout) if err != nil { continue } defer rdr.Release() for rdr.Next() { rec := rdr.Record() codes := rec.Column(0).(*array.Uint32) values := rec.Column(1).(*array.DenseUnion) int32Value := values.Field(int32code).(*array.Int32) for i := 0; i < int(rec.NumRows()); i++ { switch codes.Value(i) { case uint32(flightsql.SqlInfoFlightSqlServerTransaction): if values.TypeCode(i) != int32code { continue } idx := values.ValueOffset(i) if !int32Value.IsValid(int(idx)) { continue } value := int32Value.Value(int(idx)) cnxnSupport.transactions = value == int32(flightsql.SqlTransactionTransaction) || value == int32(flightsql.SqlTransactionSavepoint) } } } } } conn := &connectionImpl{ cl: cl, db: d, clientCache: cache, hdrs: make(metadata.MD), timeouts: d.timeout, supportInfo: cnxnSupport, ConnectionImplBase: driverbase.NewConnectionImplBase(&d.DatabaseImplBase), } return driverbase.NewConnectionBuilder(conn). WithDriverInfoPreparer(conn). WithAutocommitSetter(conn). WithCurrentNamespacer(conn). Connection(), nil } type bearerAuthMiddleware struct { mutex sync.RWMutex hdrs metadata.MD } func (b *bearerAuthMiddleware) StartCall(ctx context.Context) context.Context { md, _ := metadata.FromOutgoingContext(ctx) b.mutex.RLock() defer b.mutex.RUnlock() return metadata.NewOutgoingContext(ctx, metadata.Join(md, b.hdrs)) } func (b *bearerAuthMiddleware) HeadersReceived(ctx context.Context, md metadata.MD) { // apache/arrow-adbc#584 headers := md.Get("authorization") if len(headers) > 0 { b.mutex.Lock() defer b.mutex.Unlock() b.hdrs.Set("authorization", headers...) } } func (b *bearerAuthMiddleware) SetHeader(authValue string) { b.mutex.Lock() defer b.mutex.Unlock() b.hdrs.Set("authorization", authValue) }