arrow/flight/flightsql/driver/config.go (156 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package driver import ( "crypto/tls" "fmt" "net/url" "sync" "time" "github.com/google/uuid" ) // TLS configuration registry var ( tlsConfigRegistry = map[string]*tls.Config{ "skip-verify": {InsecureSkipVerify: true}, } tlsRegistryMutex sync.Mutex ) func RegisterTLSConfig(name string, cfg *tls.Config) error { tlsRegistryMutex.Lock() defer tlsRegistryMutex.Unlock() // Prevent name collisions if _, found := tlsConfigRegistry[name]; found { return ErrRegistryEntryExists } tlsConfigRegistry[name] = cfg return nil } func UnregisterTLSConfig(name string) error { tlsRegistryMutex.Lock() defer tlsRegistryMutex.Unlock() if _, found := tlsConfigRegistry[name]; !found { return ErrRegistryNoEntry } delete(tlsConfigRegistry, name) return nil } func GetTLSConfig(name string) (*tls.Config, bool) { tlsRegistryMutex.Lock() defer tlsRegistryMutex.Unlock() cfg, found := tlsConfigRegistry[name] return cfg, found } type DriverConfig struct { Address string Username string Password string Token string Timeout time.Duration Params map[string]string TLSEnabled bool TLSConfigName string TLSConfig *tls.Config } func NewDriverConfigFromDSN(dsn string) (*DriverConfig, error) { u, err := url.Parse(dsn) if err != nil { return nil, fmt.Errorf("invalid URL: %w", err) } // Sanity checks on the given connection string if u.Scheme != "flightsql" { return nil, fmt.Errorf("invalid scheme %q", u.Scheme) } if u.Path != "" { return nil, fmt.Errorf("unexpected path %q", u.Path) } // Extract the settings var username, password string if u.User != nil { username = u.User.Username() if v, set := u.User.Password(); set { password = v } } config := &DriverConfig{ Address: u.Host, Username: username, Password: password, Params: make(map[string]string), } // Determine the parameters for key, values := range u.Query() { // We only support single instances if len(values) > 1 { return nil, fmt.Errorf("too many values for %q", key) } var v string if len(values) > 0 { v = values[0] } switch key { case "token": config.Token = v case "timeout": config.Timeout, err = time.ParseDuration(v) if err != nil { return nil, err } case "tls": switch v { case "true", "enabled": config.TLSEnabled = true case "false", "disabled": config.TLSEnabled = false default: config.TLSEnabled = true config.TLSConfigName = v cfg, found := GetTLSConfig(config.TLSConfigName) if !found { return nil, fmt.Errorf("%q TLS %w", config.TLSConfigName, ErrRegistryNoEntry) } config.TLSConfig = cfg } default: config.Params[key] = v } } return config, nil } func (config *DriverConfig) DSN() string { u := url.URL{ Scheme: "flightsql", Host: config.Address, } if config.Username != "" { if config.Password == "" { u.User = url.User(config.Username) } else { u.User = url.UserPassword(config.Username, config.Password) } } // Set the parameters values := url.Values{} if config.Token != "" { values.Add("token", config.Token) } if config.Timeout > 0 { values.Add("timeout", config.Timeout.String()) } if config.TLSEnabled { switch config.TLSConfigName { case "skip-verify": values.Add("tls", "skip-verify") case "": // Use system defaults if no config is given if config.TLSConfig == nil { values.Add("tls", "enabled") break } // We got a custom TLS configuration but no name, create a unique one config.TLSConfigName = uuid.NewString() fallthrough default: values.Add("tls", config.TLSConfigName) if config.TLSConfig != nil { // Ignore the returned error as we do not care if the config // was registered before. If this fails and the config is not // yet registered, the driver will error out when parsing the // DSN. _ = RegisterTLSConfig(config.TLSConfigName, config.TLSConfig) } } } for k, v := range config.Params { values.Add(k, v) } // Check if we do have parameters at all and set them if len(values) > 0 { u.RawQuery = values.Encode() } return u.String() }