go/adbc/driver/snowflake/driver.go (351 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package snowflake
import (
"context"
"crypto/x509"
"database/sql"
"errors"
"fmt"
"net/url"
"os"
"runtime/debug"
"strconv"
"strings"
"time"
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v13/arrow/memory"
"github.com/snowflakedb/gosnowflake"
"golang.org/x/exp/maps"
)
const (
	// Driver and vendor names; presumably reported for the driver-info codes
	// listed in infoSupportedCodes (the reporting code is not in this file).
	infoDriverName = "ADBC Snowflake Driver - Go"
	infoVendorName = "Snowflake"

	// OptionDatabase selects the Snowflake database (gosnowflake Config.Database).
	OptionDatabase = "adbc.snowflake.sql.db"
	// OptionSchema selects the Snowflake schema (gosnowflake Config.Schema).
	OptionSchema = "adbc.snowflake.sql.schema"
	// OptionWarehouse selects the virtual warehouse (gosnowflake Config.Warehouse).
	OptionWarehouse = "adbc.snowflake.sql.warehouse"
	// OptionRole selects the role to use (gosnowflake Config.Role).
	OptionRole = "adbc.snowflake.sql.role"
	// OptionRegion sets the account region (gosnowflake Config.Region).
	OptionRegion = "adbc.snowflake.sql.region"
	// OptionAccount sets the account identifier (gosnowflake Config.Account).
	OptionAccount = "adbc.snowflake.sql.account"
	// OptionProtocol sets the URI protocol (gosnowflake Config.Protocol).
	OptionProtocol = "adbc.snowflake.sql.uri.protocol"
	// OptionPort sets the URI port (gosnowflake Config.Port); the value must
	// parse as a decimal integer.
	OptionPort = "adbc.snowflake.sql.uri.port"
	// OptionHost sets the URI host (gosnowflake Config.Host).
	OptionHost = "adbc.snowflake.sql.uri.host"
	// Specify auth type to use for snowflake connection based on
	// what is supported by the snowflake driver. Default is
	// "auth_snowflake" (use OptionValueAuth* consts to specify desired
	// authentication type). Any other value is rejected.
	OptionAuthType = "adbc.snowflake.sql.auth_type"
	// Login retry timeout EXCLUDING network roundtrip and reading http response
	// use format like http://pkg.go.dev/time#ParseDuration such as
	// "300ms", "1.5s" or "1m30s". ParseDuration accepts negative values
	// but the absolute value will be used.
	OptionLoginTimeout = "adbc.snowflake.sql.client_option.login_timeout"
	// request retry timeout EXCLUDING network roundtrip and reading http response
	// use format like http://pkg.go.dev/time#ParseDuration such as
	// "300ms", "1.5s" or "1m30s". ParseDuration accepts negative values
	// but the absolute value will be used.
	OptionRequestTimeout = "adbc.snowflake.sql.client_option.request_timeout"
	// JWT expiration after timeout
	// use format like http://pkg.go.dev/time#ParseDuration such as
	// "300ms", "1.5s" or "1m30s". ParseDuration accepts negative values
	// but the absolute value will be used.
	OptionJwtExpireTimeout = "adbc.snowflake.sql.client_option.jwt_expire_timeout"
	// Timeout for network round trip + reading http response
	// use format like http://pkg.go.dev/time#ParseDuration such as
	// "300ms", "1.5s" or "1m30s". ParseDuration accepts negative values
	// but the absolute value will be used.
	OptionClientTimeout = "adbc.snowflake.sql.client_option.client_timeout"
	// OptionApplicationName sets the client application name
	// (gosnowflake Config.Application).
	OptionApplicationName = "adbc.snowflake.sql.client_option.app_name"
	// OptionSSLSkipVerify toggles gosnowflake Config.InsecureMode. Accepts
	// adbc.OptionValueEnabled or adbc.OptionValueDisabled.
	OptionSSLSkipVerify = "adbc.snowflake.sql.client_option.tls_skip_verify"
	// OptionOCSPFailOpenMode toggles gosnowflake Config.OCSPFailOpen. Accepts
	// adbc.OptionValueEnabled or adbc.OptionValueDisabled.
	OptionOCSPFailOpenMode = "adbc.snowflake.sql.client_option.ocsp_fail_open_mode"
	// specify the token to use for OAuth or other forms of authentication
	OptionAuthToken = "adbc.snowflake.sql.client_option.auth_token"
	// specify the Okta URL to use for Okta authentication; the value must
	// parse as a URL
	OptionAuthOktaUrl = "adbc.snowflake.sql.client_option.okta_url"
	// enable the session to persist even after the connection is closed.
	// Accepts adbc.OptionValueEnabled or adbc.OptionValueDisabled.
	OptionKeepSessionAlive = "adbc.snowflake.sql.client_option.keep_session_alive"
	// specify the RSA private key to use to sign the JWT
	// this should point to a file containing a PKCS1 private key to be
	// loaded. Commonly encoded in PEM blocks of type "RSA PRIVATE KEY"
	// NOTE(review): the file contents are handed unchanged to
	// x509.ParsePKCS1PrivateKey, which expects ASN.1 DER bytes; confirm
	// whether PEM-encoded key files are actually accepted.
	OptionJwtPrivateKey = "adbc.snowflake.sql.client_option.jwt_private_key"
	// OptionDisableTelemetry toggles gosnowflake Config.DisableTelemetry.
	// Accepts adbc.OptionValueEnabled or adbc.OptionValueDisabled.
	OptionDisableTelemetry = "adbc.snowflake.sql.client_option.disable_telemetry"
	// snowflake driver logging level (gosnowflake Config.Tracing)
	OptionLogTracing = "adbc.snowflake.sql.client_option.tracing"
	// When true, the MFA token is cached in the credential manager. True by default
	// on Windows/OSX, false for Linux
	OptionClientRequestMFAToken = "adbc.snowflake.sql.client_option.cache_mfa_token"
	// When true, the ID token is cached in the credential manager. True by default
	// on Windows/OSX, false for Linux
	OptionClientStoreTempCred = "adbc.snowflake.sql.client_option.store_temp_creds"

	// Accepted values for OptionAuthType. Each maps to a gosnowflake
	// authenticator type (see authTypeMap).

	// general username password authentication
	OptionValueAuthSnowflake = "auth_snowflake"
	// use OAuth authentication for snowflake connection
	OptionValueAuthOAuth = "auth_oauth"
	// use an external browser to access a FED and perform SSO auth
	OptionValueAuthExternalBrowser = "auth_ext_browser"
	// use a native OKTA URL to perform SSO authentication on Okta
	OptionValueAuthOkta = "auth_okta"
	// use a JWT to perform authentication
	OptionValueAuthJwt = "auth_jwt"
	// use a username and password with mfa
	OptionValueAuthUserPassMFA = "auth_mfa"
)
// Driver metadata populated once by init.
var (
	// infoDriverVersion is the module version of this driver taken from the
	// binary's build info, or a placeholder when unavailable.
	infoDriverVersion string
	// infoDriverArrowVersion is the module version of the Arrow Go library
	// taken from the binary's build info, or a placeholder when unavailable.
	infoDriverArrowVersion string
	// infoSupportedCodes lists the adbc.InfoCode values this driver supports.
	infoSupportedCodes []adbc.InfoCode
)
// init records the driver and Arrow module versions from the binary's build
// information and fixes the set of info codes this driver reports.
//
// Build information may be missing or lack the dependency list (notably in
// test binaries, see https://github.com/golang/go/issues/33976); in that case
// the versions fall back to a placeholder string.
func init() {
	const unknownVersion = "(unknown or development build)"

	var modules []*debug.Module
	if info, ok := debug.ReadBuildInfo(); ok {
		modules = info.Deps
	}

	for _, mod := range modules {
		if mod.Path == "github.com/apache/arrow-adbc/go/adbc/driver/snowflake" {
			infoDriverVersion = mod.Version
		} else if strings.HasPrefix(mod.Path, "github.com/apache/arrow/go/") {
			// Any major version of the Arrow Go module counts; the last match wins.
			infoDriverArrowVersion = mod.Version
		}
	}

	if len(infoDriverVersion) == 0 {
		infoDriverVersion = unknownVersion
	}
	if len(infoDriverArrowVersion) == 0 {
		infoDriverArrowVersion = unknownVersion
	}

	infoSupportedCodes = []adbc.InfoCode{
		adbc.InfoDriverName,
		adbc.InfoDriverVersion,
		adbc.InfoDriverArrowVersion,
		adbc.InfoVendorName,
	}
}
// errToAdbcErr converts err into an adbc.Error carrying the status code.
//
// A nil err yields nil. If err is, or wraps, an adbc.Error, that error is
// returned with its Code overwritten by code. If err is, or wraps, a
// *gosnowflake.SnowflakeError, its message, error number (as VendorCode) and
// SQLSTATE are carried over; the SQLSTATE is truncated or zero-padded to five
// bytes. Any other error becomes an adbc.Error whose Msg is err.Error().
func errToAdbcErr(code adbc.Status, err error) error {
	if err == nil {
		return nil
	}

	var (
		adbcErr adbc.Error
		sfErr   *gosnowflake.SnowflakeError
	)
	switch {
	case errors.As(err, &adbcErr):
		adbcErr.Code = code
		return adbcErr
	case errors.As(err, &sfErr):
		converted := adbc.Error{
			Code:       code,
			Msg:        sfErr.Error(),
			VendorCode: int32(sfErr.Number),
		}
		copy(converted.SqlState[:], sfErr.SQLState)
		return converted
	default:
		return adbc.Error{
			Code: code,
			Msg:  err.Error(),
		}
	}
}
// Driver is the ADBC driver for Snowflake, built on top of gosnowflake.
type Driver struct {
	// Alloc is the Arrow memory allocator given to databases created by
	// NewDatabase; when nil, memory.DefaultAllocator is used.
	Alloc memory.Allocator
}
// NewDatabase creates a Snowflake database handle configured from opts.
//
// opts is cloned before use, so the caller's map is never modified
// (SetOptions deletes the URI key from the map it receives). If d.Alloc is
// nil, memory.DefaultAllocator is used for Arrow data.
//
// If any option is invalid, the error from SetOptions is returned together
// with a nil Database rather than a half-configured one.
func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
	alloc := d.Alloc
	if alloc == nil {
		alloc = memory.DefaultAllocator
	}

	db := &database{alloc: alloc}
	if err := db.SetOptions(maps.Clone(opts)); err != nil {
		return nil, err
	}
	return db, nil
}
var (
	// drv is the gosnowflake driver passed to gosnowflake.NewConnector in Open.
	drv = gosnowflake.SnowflakeDriver{}
	// authTypeMap translates OptionAuthType values (the OptionValueAuth*
	// constants) into gosnowflake authenticator types. SetOptions rejects any
	// value that is not a key of this map.
	authTypeMap = map[string]gosnowflake.AuthType{
		OptionValueAuthSnowflake:       gosnowflake.AuthTypeSnowflake,
		OptionValueAuthOAuth:           gosnowflake.AuthTypeOAuth,
		OptionValueAuthExternalBrowser: gosnowflake.AuthTypeExternalBrowser,
		OptionValueAuthOkta:            gosnowflake.AuthTypeOkta,
		OptionValueAuthJwt:             gosnowflake.AuthTypeJwt,
		OptionValueAuthUserPassMFA:     gosnowflake.AuthTypeUsernamePasswordMFA,
	}
)
// database is the adbc.Database implementation returned by Driver.NewDatabase.
type database struct {
	// cfg is the gosnowflake configuration built by SetOptions and copied into
	// a new connector by Open.
	cfg *gosnowflake.Config
	// alloc is the Arrow allocator attached to the connection context in Open.
	alloc memory.Allocator
}
// SetOptions builds the gosnowflake connection configuration for this
// database from cnOptions.
//
// If cnOptions contains adbc.OptionKeyURI, its value is parsed with
// gosnowflake.ParseDSN and used as the base configuration, and the key is
// deleted from cnOptions (so callers should pass a map they own). Otherwise an
// empty configuration is the base. Every remaining option is applied on top of
// the base. Keys that are not recognized are passed through unchanged as
// gosnowflake session parameters (Config.Params). Any configuration built by a
// previous call is discarded.
//
// Boolean options accept adbc.OptionValueEnabled / adbc.OptionValueDisabled.
// Duration options accept time.ParseDuration syntax; a negative duration is
// replaced by its absolute value.
//
// Returns an adbc.Error with StatusInvalidArgument if the URI or any option
// value cannot be parsed, or if the JWT private key file cannot be read or
// parsed.
func (d *database) SetOptions(cnOptions map[string]string) error {
	if uri, ok := cnOptions[adbc.OptionKeyURI]; ok {
		cfg, err := gosnowflake.ParseDSN(uri)
		if err != nil {
			return errToAdbcErr(adbc.StatusInvalidArgument, err)
		}
		d.cfg = cfg
		delete(cnOptions, adbc.OptionKeyURI)
	} else {
		d.cfg = &gosnowflake.Config{}
	}
	// Unrecognized options are stored in Params below, so the map must exist
	// however the base configuration was obtained.
	if d.cfg.Params == nil {
		d.cfg.Params = make(map[string]*string)
	}

	invalidArg := func(msg string) error {
		return adbc.Error{Msg: msg, Code: adbc.StatusInvalidArgument}
	}

	// parseDuration parses val for option key and returns its absolute value.
	parseDuration := func(key, val string) (time.Duration, error) {
		dur, err := time.ParseDuration(val)
		if err != nil {
			return 0, invalidArg("could not parse duration for '" + key + "': " + err.Error())
		}
		if dur < 0 {
			dur = -dur
		}
		return dur, nil
	}

	// parseBool maps the ADBC enabled/disabled values for option key to a bool.
	// The error names the offending key itself rather than a fixed option.
	parseBool := func(key, val string) (bool, error) {
		switch val {
		case adbc.OptionValueEnabled:
			return true, nil
		case adbc.OptionValueDisabled:
			return false, nil
		default:
			return false, invalidArg(fmt.Sprintf("Invalid value for database option '%s': '%s'", key, val))
		}
	}

	for k, v := range cnOptions {
		switch k {
		case adbc.OptionKeyUsername:
			d.cfg.User = v
		case adbc.OptionKeyPassword:
			d.cfg.Password = v
		case OptionDatabase:
			d.cfg.Database = v
		case OptionSchema:
			d.cfg.Schema = v
		case OptionWarehouse:
			d.cfg.Warehouse = v
		case OptionRole:
			d.cfg.Role = v
		case OptionRegion:
			d.cfg.Region = v
		case OptionAccount:
			d.cfg.Account = v
		case OptionProtocol:
			d.cfg.Protocol = v
		case OptionHost:
			d.cfg.Host = v
		case OptionPort:
			port, err := strconv.Atoi(v)
			if err != nil {
				return invalidArg("error encountered parsing Port option: " + err.Error())
			}
			d.cfg.Port = port
		case OptionAuthType:
			auth, ok := authTypeMap[v]
			if !ok {
				return invalidArg("invalid option value for " + OptionAuthType + ": '" + v + "'")
			}
			d.cfg.Authenticator = auth
		case OptionLoginTimeout:
			dur, err := parseDuration(k, v)
			if err != nil {
				return err
			}
			d.cfg.LoginTimeout = dur
		case OptionRequestTimeout:
			dur, err := parseDuration(k, v)
			if err != nil {
				return err
			}
			d.cfg.RequestTimeout = dur
		case OptionJwtExpireTimeout:
			dur, err := parseDuration(k, v)
			if err != nil {
				return err
			}
			d.cfg.JWTExpireTimeout = dur
		case OptionClientTimeout:
			dur, err := parseDuration(k, v)
			if err != nil {
				return err
			}
			d.cfg.ClientTimeout = dur
		case OptionApplicationName:
			d.cfg.Application = v
		case OptionSSLSkipVerify:
			enabled, err := parseBool(k, v)
			if err != nil {
				return err
			}
			d.cfg.InsecureMode = enabled
		case OptionOCSPFailOpenMode:
			enabled, err := parseBool(k, v)
			if err != nil {
				return err
			}
			if enabled {
				d.cfg.OCSPFailOpen = gosnowflake.OCSPFailOpenTrue
			} else {
				d.cfg.OCSPFailOpen = gosnowflake.OCSPFailOpenFalse
			}
		case OptionAuthToken:
			d.cfg.Token = v
		case OptionAuthOktaUrl:
			oktaURL, err := url.Parse(v)
			if err != nil {
				return invalidArg(fmt.Sprintf("error parsing URL for database option '%s': '%s': %s", k, v, err))
			}
			d.cfg.OktaURL = oktaURL
		case OptionKeepSessionAlive:
			enabled, err := parseBool(k, v)
			if err != nil {
				return err
			}
			d.cfg.KeepSessionAlive = enabled
		case OptionDisableTelemetry:
			enabled, err := parseBool(k, v)
			if err != nil {
				return err
			}
			d.cfg.DisableTelemetry = enabled
		case OptionJwtPrivateKey:
			data, err := os.ReadFile(v)
			if err != nil {
				return invalidArg("could not read private key file '" + v + "': " + err.Error())
			}
			// NOTE(review): ParsePKCS1PrivateKey expects ASN.1 DER bytes; a
			// PEM-encoded file would need pem.Decode first. Confirm expected format.
			key, err := x509.ParsePKCS1PrivateKey(data)
			if err != nil {
				return invalidArg("failed parsing private key file '" + v + "': " + err.Error())
			}
			d.cfg.PrivateKey = key
		case OptionClientRequestMFAToken:
			enabled, err := parseBool(k, v)
			if err != nil {
				return err
			}
			if enabled {
				d.cfg.ClientRequestMfaToken = gosnowflake.ConfigBoolTrue
			} else {
				d.cfg.ClientRequestMfaToken = gosnowflake.ConfigBoolFalse
			}
		case OptionClientStoreTempCred:
			enabled, err := parseBool(k, v)
			if err != nil {
				return err
			}
			if enabled {
				d.cfg.ClientStoreTemporaryCredential = gosnowflake.ConfigBoolTrue
			} else {
				d.cfg.ClientStoreTemporaryCredential = gosnowflake.ConfigBoolFalse
			}
		case OptionLogTracing:
			d.cfg.Tracing = v
		default:
			// Store a pointer to a fresh copy: before Go 1.22 the range variable
			// v is reused across iterations, so &v would make every Params entry
			// alias the last value iterated.
			val := v
			d.cfg.Params[k] = &val
		}
	}
	return nil
}
// Open establishes a new Snowflake connection using the configuration built
// by SetOptions.
//
// The context used to connect requests Arrow batch results
// (gosnowflake.WithArrowBatches) allocated from the database's allocator. The
// returned connection also holds a database/sql handle opened on the same
// connector.
//
// Returns an adbc.Error with StatusIO if connecting fails, or StatusInternal
// if gosnowflake hands back a connection that does not implement
// snowflakeConn. In that case the connection is closed rather than leaked.
func (d *database) Open(ctx context.Context) (adbc.Connection, error) {
	connector := gosnowflake.NewConnector(drv, *d.cfg)

	ctx = gosnowflake.WithArrowAllocator(
		gosnowflake.WithArrowBatches(ctx), d.alloc)
	cn, err := connector.Connect(ctx)
	if err != nil {
		return nil, errToAdbcErr(adbc.StatusIO, err)
	}

	// Check the assertion instead of panicking inside library code if the
	// gosnowflake connection type ever stops satisfying snowflakeConn.
	sfConn, ok := cn.(snowflakeConn)
	if !ok {
		// Best effort cleanup; the type mismatch is the error worth reporting.
		_ = cn.Close()
		return nil, adbc.Error{
			Msg:  fmt.Sprintf("unexpected connection type %T returned by gosnowflake", cn),
			Code: adbc.StatusInternal,
		}
	}

	return &cnxn{cn: sfConn, db: d, ctor: connector, sqldb: sql.OpenDB(connector)}, nil
}