go/adbc/driver/snowflake/snowflake_database.go (442 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package snowflake import ( "context" "crypto/rsa" "crypto/x509" "encoding/pem" "errors" "fmt" "net/url" "os" "strconv" "strings" "time" "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/snowflakedb/gosnowflake" "github.com/youmark/pkcs8" ) var ( drv = gosnowflake.SnowflakeDriver{} authTypeMap = map[string]gosnowflake.AuthType{ OptionValueAuthSnowflake: gosnowflake.AuthTypeSnowflake, OptionValueAuthOAuth: gosnowflake.AuthTypeOAuth, OptionValueAuthExternalBrowser: gosnowflake.AuthTypeExternalBrowser, OptionValueAuthOkta: gosnowflake.AuthTypeOkta, OptionValueAuthJwt: gosnowflake.AuthTypeJwt, OptionValueAuthUserPassMFA: gosnowflake.AuthTypeUsernamePasswordMFA, } ) type databaseImpl struct { driverbase.DatabaseImplBase cfg *gosnowflake.Config useHighPrecision bool defaultAppName string } func (d *databaseImpl) GetOption(key string) (string, error) { switch key { case adbc.OptionKeyUsername: return d.cfg.User, nil case adbc.OptionKeyPassword: return d.cfg.Password, nil case OptionDatabase: return d.cfg.Database, nil case OptionSchema: return d.cfg.Schema, nil case OptionWarehouse: return d.cfg.Warehouse, nil case OptionRole: return d.cfg.Role, nil case OptionRegion: return d.cfg.Region, nil case OptionAccount: return d.cfg.Account, nil case OptionProtocol: return d.cfg.Protocol, nil case OptionHost: return d.cfg.Host, nil case OptionPort: return strconv.Itoa(d.cfg.Port), nil case OptionAuthType: return d.cfg.Authenticator.String(), nil case OptionLoginTimeout: return strconv.FormatFloat(d.cfg.LoginTimeout.Seconds(), 'f', -1, 64), nil case OptionRequestTimeout: return strconv.FormatFloat(d.cfg.RequestTimeout.Seconds(), 'f', -1, 64), nil case OptionJwtExpireTimeout: return strconv.FormatFloat(d.cfg.JWTExpireTimeout.Seconds(), 'f', -1, 64), nil case OptionClientTimeout: return strconv.FormatFloat(d.cfg.ClientTimeout.Seconds(), 'f', -1, 64), nil case OptionApplicationName: return d.cfg.Application, nil case OptionSSLSkipVerify: if d.cfg.DisableOCSPChecks { return adbc.OptionValueEnabled, nil } return adbc.OptionValueDisabled, nil case OptionOCSPFailOpenMode: return strconv.FormatUint(uint64(d.cfg.OCSPFailOpen), 10), nil case OptionAuthToken: return d.cfg.Token, nil case OptionAuthOktaUrl: return d.cfg.OktaURL.String(), nil case OptionKeepSessionAlive: if d.cfg.KeepSessionAlive { return adbc.OptionValueEnabled, nil } return adbc.OptionValueDisabled, nil case OptionDisableTelemetry: if d.cfg.DisableTelemetry { return adbc.OptionValueEnabled, nil } return adbc.OptionValueDisabled, nil case OptionClientRequestMFAToken: if d.cfg.ClientRequestMfaToken == gosnowflake.ConfigBoolTrue { return adbc.OptionValueEnabled, nil } return adbc.OptionValueDisabled, nil case OptionClientStoreTempCred: if d.cfg.ClientStoreTemporaryCredential == gosnowflake.ConfigBoolTrue { return adbc.OptionValueEnabled, nil } return adbc.OptionValueDisabled, nil case OptionLogTracing: return d.cfg.Tracing, nil case OptionClientConfigFile: return d.cfg.ClientConfigFile, nil case OptionUseHighPrecision: if d.useHighPrecision { return adbc.OptionValueEnabled, nil } return adbc.OptionValueDisabled, nil default: val, ok := d.cfg.Params[key] if ok { return *val, nil } } return d.DatabaseImplBase.GetOption(key) } func (d *databaseImpl) SetOption(key string, value string) error { return d.SetOptionInternal(key, value, nil) } func (d *databaseImpl) SetOptions(cnOptions map[string]string) error { uri, ok := cnOptions[adbc.OptionKeyURI] if ok { cfg, err := gosnowflake.ParseDSN(uri) if err != nil { return errToAdbcErr(adbc.StatusInvalidArgument, err) } d.cfg = cfg delete(cnOptions, adbc.OptionKeyURI) } else { d.cfg = &gosnowflake.Config{ Params: make(map[string]*string), } } // set default application name to track // unless user overrides it d.cfg.Application = d.defaultAppName for k, v := range cnOptions { v := v // copy into loop scope err := d.SetOptionInternal(k, v, &cnOptions) if err != nil { return err } } return nil } // SetOptionInternal sets the option for the database. // // cnOptions is nil if the option is being set post-initialiation. func (d *databaseImpl) SetOptionInternal(k string, v string, cnOptions *map[string]string) error { var err error var ok bool switch k { case adbc.OptionKeyUsername: d.cfg.User = v case adbc.OptionKeyPassword: d.cfg.Password = v case OptionDatabase: d.cfg.Database = v case OptionSchema: d.cfg.Schema = v case OptionWarehouse: d.cfg.Warehouse = v case OptionRole: d.cfg.Role = v case OptionRegion: d.cfg.Region = v case OptionAccount: d.cfg.Account = v case OptionProtocol: d.cfg.Protocol = v case OptionHost: d.cfg.Host = v case OptionPort: d.cfg.Port, err = strconv.Atoi(v) if err != nil { return adbc.Error{ Msg: "error encountered parsing Port option: " + err.Error(), Code: adbc.StatusInvalidArgument, } } case OptionAuthType: d.cfg.Authenticator, ok = authTypeMap[v] if !ok { return adbc.Error{ Msg: "invalid option value for " + OptionAuthType + ": '" + v + "'", Code: adbc.StatusInvalidArgument, } } case OptionLoginTimeout: dur, err := time.ParseDuration(v) if err != nil { return adbc.Error{ Msg: "could not parse duration for '" + OptionLoginTimeout + "': " + err.Error(), Code: adbc.StatusInvalidArgument, } } if dur < 0 { dur = -dur } d.cfg.LoginTimeout = dur case OptionRequestTimeout: dur, err := time.ParseDuration(v) if err != nil { return adbc.Error{ Msg: "could not parse duration for '" + OptionRequestTimeout + "': " + err.Error(), Code: adbc.StatusInvalidArgument, } } if dur < 0 { dur = -dur } d.cfg.RequestTimeout = dur case OptionJwtExpireTimeout: dur, err := time.ParseDuration(v) if err != nil { return adbc.Error{ Msg: "could not parse duration for '" + OptionJwtExpireTimeout + "': " + err.Error(), Code: adbc.StatusInvalidArgument, } } if dur < 0 { dur = -dur } d.cfg.JWTExpireTimeout = dur case OptionClientTimeout: dur, err := time.ParseDuration(v) if err != nil { return adbc.Error{ Msg: "could not parse duration for '" + OptionClientTimeout + "': " + err.Error(), Code: adbc.StatusInvalidArgument, } } if dur < 0 { dur = -dur } d.cfg.ClientTimeout = dur case OptionApplicationName: if !strings.HasPrefix(v, "[ADBC]") { v = d.defaultAppName + v } d.cfg.Application = v case OptionSSLSkipVerify: switch v { case adbc.OptionValueEnabled: d.cfg.DisableOCSPChecks = true case adbc.OptionValueDisabled: d.cfg.DisableOCSPChecks = false default: return adbc.Error{ Msg: fmt.Sprintf("Invalid value for database option '%s': '%s'", OptionSSLSkipVerify, v), Code: adbc.StatusInvalidArgument, } } case OptionOCSPFailOpenMode: switch v { case adbc.OptionValueEnabled: d.cfg.OCSPFailOpen = gosnowflake.OCSPFailOpenTrue case adbc.OptionValueDisabled: d.cfg.OCSPFailOpen = gosnowflake.OCSPFailOpenFalse default: return adbc.Error{ Msg: fmt.Sprintf("Invalid value for database option '%s': '%s'", OptionSSLSkipVerify, v), Code: adbc.StatusInvalidArgument, } } case OptionAuthToken: d.cfg.Token = v case OptionAuthOktaUrl: d.cfg.OktaURL, err = url.Parse(v) if err != nil { return adbc.Error{ Msg: fmt.Sprintf("error parsing URL for database option '%s': '%s'", k, v), Code: adbc.StatusInvalidArgument, } } case OptionKeepSessionAlive: switch v { case adbc.OptionValueEnabled: d.cfg.KeepSessionAlive = true case adbc.OptionValueDisabled: d.cfg.KeepSessionAlive = false default: return adbc.Error{ Msg: fmt.Sprintf("Invalid value for database option '%s': '%s'", OptionSSLSkipVerify, v), Code: adbc.StatusInvalidArgument, } } case OptionDisableTelemetry: switch v { case adbc.OptionValueEnabled: d.cfg.DisableTelemetry = true case adbc.OptionValueDisabled: d.cfg.DisableTelemetry = false default: return adbc.Error{ Msg: fmt.Sprintf("Invalid value for database option '%s': '%s'", OptionSSLSkipVerify, v), Code: adbc.StatusInvalidArgument, } } case OptionJwtPrivateKey: data, err := os.ReadFile(v) if err != nil { return adbc.Error{ Msg: "could not read private key file '" + v + "': " + err.Error(), Code: adbc.StatusInvalidArgument, } } var block []byte if strings.Contains(string(data), "PRIVATE KEY") { b, _ := pem.Decode(data) block = b.Bytes } else { block = data } var key *rsa.PrivateKey key, err = x509.ParsePKCS1PrivateKey(block) if err != nil && strings.Contains(err.Error(), "use ParsePKCS8PrivateKey instead") { var pkcs8Key any pkcs8Key, err = x509.ParsePKCS8PrivateKey(block) key, ok = pkcs8Key.(*rsa.PrivateKey) if !ok { err = errors.New("file does not contain an RSA private key") } } if err != nil { return adbc.Error{ Msg: "failed parsing private key file '" + v + "': " + err.Error(), Code: adbc.StatusInvalidArgument, } } d.cfg.PrivateKey = key case OptionJwtPrivateKeyPkcs8Value: block, _ := pem.Decode([]byte(v)) if block == nil { return adbc.Error{ Msg: "Failed to parse PEM block containing the private key", Code: adbc.StatusInvalidArgument, } } var parsedKey any switch block.Type { case "ENCRYPTED PRIVATE KEY": if cnOptions == nil { return adbc.Error{ Msg: "[Snowflake] unable to set private key post initialization", Code: adbc.StatusInvalidArgument, } } passcode, ok := (*cnOptions)[OptionJwtPrivateKeyPkcs8Password] if ok { parsedKey, err = pkcs8.ParsePKCS8PrivateKey(block.Bytes, []byte(passcode)) } else { return adbc.Error{ Msg: OptionJwtPrivateKeyPkcs8Password + " is not configured", Code: adbc.StatusInvalidArgument, } } case "PRIVATE KEY": parsedKey, err = pkcs8.ParsePKCS8PrivateKey(block.Bytes) default: return adbc.Error{ Msg: block.Type + " is not supported", Code: adbc.StatusInvalidArgument, } } if err != nil { return adbc.Error{ Msg: "[Snowflake] failed parsing PKCS8 private key: " + err.Error(), Code: adbc.StatusInvalidArgument, } } d.cfg.PrivateKey = parsedKey.(*rsa.PrivateKey) case OptionClientRequestMFAToken: switch v { case adbc.OptionValueEnabled: d.cfg.ClientRequestMfaToken = gosnowflake.ConfigBoolTrue case adbc.OptionValueDisabled: d.cfg.ClientRequestMfaToken = gosnowflake.ConfigBoolFalse default: return adbc.Error{ Msg: fmt.Sprintf("Invalid value for database option '%s': '%s'", OptionSSLSkipVerify, v), Code: adbc.StatusInvalidArgument, } } case OptionClientStoreTempCred: switch v { case adbc.OptionValueEnabled: d.cfg.ClientStoreTemporaryCredential = gosnowflake.ConfigBoolTrue case adbc.OptionValueDisabled: d.cfg.ClientStoreTemporaryCredential = gosnowflake.ConfigBoolFalse default: return adbc.Error{ Msg: fmt.Sprintf("Invalid value for database option '%s': '%s'", OptionSSLSkipVerify, v), Code: adbc.StatusInvalidArgument, } } case OptionLogTracing: d.cfg.Tracing = v case OptionClientConfigFile: d.cfg.ClientConfigFile = v case OptionUseHighPrecision: switch v { case adbc.OptionValueEnabled: d.useHighPrecision = true case adbc.OptionValueDisabled: d.useHighPrecision = false default: return adbc.Error{ Msg: fmt.Sprintf("Invalid value for database option '%s': '%s'", OptionUseHighPrecision, v), Code: adbc.StatusInvalidArgument, } } default: d.cfg.Params[k] = &v } return nil } func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { connector := gosnowflake.NewConnector(drv, *d.cfg) ctx = gosnowflake.WithArrowAllocator( gosnowflake.WithArrowBatches(ctx), d.Alloc) cn, err := connector.Connect(ctx) if err != nil { return nil, errToAdbcErr(adbc.StatusIO, err) } conn := &connectionImpl{ cn: cn.(snowflakeConn), db: d, ctor: connector, // default enable high precision // SetOption(OptionUseHighPrecision, adbc.OptionValueDisabled) to // get Int64/Float64 instead useHighPrecision: d.useHighPrecision, ConnectionImplBase: driverbase.NewConnectionImplBase(&d.DatabaseImplBase), } return driverbase.NewConnectionBuilder(conn). WithAutocommitSetter(conn). WithCurrentNamespacer(conn). WithTableTypeLister(conn). WithDriverInfoPreparer(conn). Connection(), nil } func (d *databaseImpl) Close() error { return nil } var ( _ adbc.PostInitOptions = (*databaseImpl)(nil) )