go/adbc/driver/snowflake/driver.go (146 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package snowflake import ( "errors" "maps" "net/http" "runtime/debug" "strings" "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/apache/arrow-go/v18/arrow/memory" "github.com/snowflakedb/gosnowflake" ) const ( OptionDatabase = "adbc.snowflake.sql.db" OptionSchema = "adbc.snowflake.sql.schema" OptionWarehouse = "adbc.snowflake.sql.warehouse" OptionRole = "adbc.snowflake.sql.role" OptionRegion = "adbc.snowflake.sql.region" OptionAccount = "adbc.snowflake.sql.account" OptionProtocol = "adbc.snowflake.sql.uri.protocol" OptionPort = "adbc.snowflake.sql.uri.port" OptionHost = "adbc.snowflake.sql.uri.host" // Specify auth type to use for snowflake connection based on // what is supported by the snowflake driver. Default is // "auth_snowflake" (use OptionValueAuth* consts to specify desired // authentication type). OptionAuthType = "adbc.snowflake.sql.auth_type" // Login retry timeout EXCLUDING network roundtrip and reading http response // use format like http://pkg.go.dev/time#ParseDuration such as // "300ms", "1.5s" or "1m30s". ParseDuration accepts negative values // but the absolute value will be used. 
OptionLoginTimeout = "adbc.snowflake.sql.client_option.login_timeout" // request retry timeout EXCLUDING network roundtrip and reading http response // use format like http://pkg.go.dev/time#ParseDuration such as // "300ms", "1.5s" or "1m30s". ParseDuration accepts negative values // but the absolute value will be used. OptionRequestTimeout = "adbc.snowflake.sql.client_option.request_timeout" // JWT expiration after timeout // use format like http://pkg.go.dev/time#ParseDuration such as // "300ms", "1.5s" or "1m30s". ParseDuration accepts negative values // but the absolute value will be used. OptionJwtExpireTimeout = "adbc.snowflake.sql.client_option.jwt_expire_timeout" // Timeout for network round trip + reading http response // use format like http://pkg.go.dev/time#ParseDuration such as // "300ms", "1.5s" or "1m30s". ParseDuration accepts negative values // but the absolute value will be used. OptionClientTimeout = "adbc.snowflake.sql.client_option.client_timeout" // OptionUseHighPrecision controls the data type used for NUMBER columns // using a FIXED size data type. By default, this is enabled and NUMBER // columns will be returned as Decimal128 types using the indicated // precision and scale of the type. If disabled, then fixed-point data // with a scale of 0 will be returned as Int64 columns, and a non-zero // scale will return a Float64 column. 
OptionUseHighPrecision = "adbc.snowflake.sql.client_option.use_high_precision" OptionApplicationName = "adbc.snowflake.sql.client_option.app_name" OptionSSLSkipVerify = "adbc.snowflake.sql.client_option.tls_skip_verify" OptionOCSPFailOpenMode = "adbc.snowflake.sql.client_option.ocsp_fail_open_mode" // specify the token to use for OAuth or other forms of authentication OptionAuthToken = "adbc.snowflake.sql.client_option.auth_token" // specify the OKTAUrl to use for OKTA Authentication OptionAuthOktaUrl = "adbc.snowflake.sql.client_option.okta_url" // enable the session to persist even after the connection is closed OptionKeepSessionAlive = "adbc.snowflake.sql.client_option.keep_session_alive" // specify the RSA private key to use to sign the JWT // this should point to a file containing a PKCS1 private key to be // loaded. Commonly encoded in PEM blocks of type "RSA PRIVATE KEY" OptionJwtPrivateKey = "adbc.snowflake.sql.client_option.jwt_private_key" // parses a private key in PKCS #8, ASN.1 DER form. Specify the private key // value without having to load it from the file system. OptionJwtPrivateKeyPkcs8Value = "adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_value" // a passcode to use with encrypted private keys for JWT authentication OptionJwtPrivateKeyPkcs8Password = "adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_password" OptionDisableTelemetry = "adbc.snowflake.sql.client_option.disable_telemetry" // snowflake driver logging level OptionLogTracing = "adbc.snowflake.sql.client_option.tracing" // snowflake driver client logging config file OptionClientConfigFile = "adbc.snowflake.sql.client_option.config_file" // When true, the MFA token is cached in the credential manager. True by default // on Windows/OSX, false for Linux OptionClientRequestMFAToken = "adbc.snowflake.sql.client_option.cache_mfa_token" // When true, the ID token is cached in the credential manager. 
True by default // on Windows/OSX, false for Linux OptionClientStoreTempCred = "adbc.snowflake.sql.client_option.store_temp_creds" // auth types are implemented by the Snowflake driver in gosnowflake // general username password authentication OptionValueAuthSnowflake = "auth_snowflake" // use OAuth authentication for snowflake connection OptionValueAuthOAuth = "auth_oauth" // use an external browser to access a FED and perform SSO auth OptionValueAuthExternalBrowser = "auth_ext_browser" // use a native OKTA URL to perform SSO authentication on Okta OptionValueAuthOkta = "auth_okta" // use a JWT to perform authentication OptionValueAuthJwt = "auth_jwt" // use a username and password with mfa OptionValueAuthUserPassMFA = "auth_mfa" ) var ( infoVendorVersion string ) func init() { if info, ok := debug.ReadBuildInfo(); ok { for _, dep := range info.Deps { switch dep.Path { case "github.com/snowflakedb/gosnowflake": infoVendorVersion = dep.Version } } } // Disable some stray logs // https://github.com/snowflakedb/gosnowflake/pull/1332 _ = gosnowflake.GetLogger().SetLogLevel("warn") } func errToAdbcErr(code adbc.Status, err error) error { if err == nil { return nil } var e adbc.Error if errors.As(err, &e) { e.Code = code return e } var sferr *gosnowflake.SnowflakeError if errors.As(err, &sferr) { var sqlstate [5]byte copy(sqlstate[:], []byte(sferr.SQLState)) if sferr.SQLState == "42S02" { code = adbc.StatusNotFound } return adbc.Error{ Code: code, Msg: sferr.Error(), VendorCode: int32(sferr.Number), SqlState: sqlstate, } } return adbc.Error{ Msg: err.Error(), Code: code, } } func quoteTblName(name string) string { return "\"" + strings.ReplaceAll(name, "\"", "\"\"") + "\"" } type config struct { *gosnowflake.Config } // Option is a function type to set custom driver configurations. // // It is intended for configurations that cannot be provided from the standard options map, // e.g. the underlying HTTP transporter. 
type Option func(*config) error

// WithTransporter sets the custom transporter to use for the Snowflake connection.
// This allows to intercept HTTP requests and responses.
func WithTransporter(transporter http.RoundTripper) Option {
	return func(cfg *config) error {
		cfg.Transporter = transporter
		return nil
	}
}

// Driver is the Snowflake driver interface.
//
// It extends the base adbc.Driver to provide additional options
// when creating the Snowflake database.
type Driver interface {
	adbc.Driver

	// NewDatabaseWithOptions creates a new Snowflake database with the provided options.
	NewDatabaseWithOptions(map[string]string, ...Option) (adbc.Database, error)
}

// Compile-time check that driverImpl satisfies Driver.
var _ Driver = (*driverImpl)(nil)

// driverImpl is the Driver implementation; it embeds the shared driver base.
type driverImpl struct {
	driverbase.DriverImplBase
}

// NewDriver creates a new Snowflake driver using the given Arrow allocator.
//
// When the gosnowflake module version was found in the build info, it is
// registered as the vendor version. NewDriver panics if that registration
// fails.
func NewDriver(alloc memory.Allocator) Driver {
	info := driverbase.DefaultDriverInfo("Snowflake")
	if infoVendorVersion != "" {
		if err := info.RegisterInfoCode(adbc.InfoVendorVersion, infoVendorVersion); err != nil {
			panic(err)
		}
	}
	return &driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc)}
}

// NewDatabase creates a new Snowflake database from the options map alone.
// It is equivalent to calling NewDatabaseWithOptions without any Option.
func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) {
	return d.NewDatabaseWithOptions(opts)
}

// NewDatabaseWithOptions creates a new Snowflake database from the options
// map and then applies each Option, in order, to the resulting configuration.
//
// It returns the first error reported by SetOptions or by an Option. Nil
// entries in optFuncs are skipped.
func (d *driverImpl) NewDatabaseWithOptions(opts map[string]string, optFuncs ...Option) (adbc.Database, error) {
	// Work on a copy so nothing done with the map below touches the caller's map.
	opts = maps.Clone(opts)
	dbImplBase := driverbase.NewDatabaseImplBase(&d.DriverImplBase)

	// The driver version only decorates the default application name, so a
	// missing or non-string value falls back to a placeholder instead of
	// panicking on an unchecked type assertion.
	driverVersion := "unknown"
	dv, _ := dbImplBase.DriverInfo.GetInfoForInfoCode(adbc.InfoDriverVersion)
	if version, isString := dv.(string); isString {
		driverVersion = version
	}
	defaultAppName := "[ADBC][Go-" + driverVersion + "]"

	db := &databaseImpl{
		DatabaseImplBase: dbImplBase,
		useHighPrecision: true,
		defaultAppName:   defaultAppName,
	}
	if err := db.SetOptions(opts); err != nil {
		return nil, err
	}

	// The wrapper shares db.cfg, so the options modify the database's own
	// configuration rather than a copy.
	cfg := &config{Config: db.cfg}
	for _, opt := range optFuncs {
		if opt == nil {
			// Calling a nil function value would panic.
			continue
		}
		if err := opt(cfg); err != nil {
			return nil, err
		}
	}

	return driverbase.NewDatabase(db), nil
}