go/config.go (340 lines of code) (raw):

// Copyright (c) 2022 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package athenadriver import ( "net/url" "regexp" "strings" "strconv" "time" ) // Config is for AWS Athena Driver Config. // Be noted this is different from aws.Config. 
type Config struct { dsn url.URL `yaml:"dns"` values url.Values `yaml:"values"` } var reSecretAccessKey = regexp.MustCompile(`secretAccessKey=[^&]+`) var reAccessID = regexp.MustCompile(`accessID=[^&]+`) var reSessionToken = regexp.MustCompile(`sessionToken=[^&]+`) var ( credAccessEnvKey = []string{ "AWS_ACCESS_KEY_ID", "AWS_ACCESS_KEY", } credSecretEnvKey = []string{ "AWS_SECRET_ACCESS_KEY", "AWS_SECRET_KEY", } credSessionEnvKey = []string{ "AWS_SESSION_TOKEN", } regionEnvKeys = []string{ "AWS_REGION", "AWS_DEFAULT_REGION", // Only read if AWS_SDK_LOAD_CONFIG is also set } stsRegionalEndpointKey = []string{ "AWS_STS_REGIONAL_ENDPOINTS", } ) // NewDefaultConfig is to new a Config with some default values. func NewDefaultConfig(outputBucket string, region string, accessID string, secretAccessKey string) (*Config, error) { conf := NewNoOpsConfig() err := conf.SetOutputBucket(outputBucket) if err != nil { return nil, err } err = conf.SetRegion(region) if err != nil { return nil, err } err = conf.SetAccessID(accessID) if err != nil { return nil, err } err = conf.SetSecretAccessKey(secretAccessKey) conf.SetResultPollIntervalSeconds(PoolInterval) return conf, err } // NewNoOpsConfig is to create a noop version of driver Config WITHOUT credentials. func NewNoOpsConfig() *Config { a := Config{ dsn: url.URL{}, } a.dsn.Scheme = "s3" a.values = make(map[string][]string, 32) a.values.Set("db", DefaultDBName) a.values.Set("region", DefaultRegion) a.SetMissingAsEmptyString(true) a.SetWGRemoteCreationAllowed(true) return &a } // NewConfig is to create Config from a string. func NewConfig(s string) (*Config, error) { u, err := url.Parse(s) if err != nil { return nil, err } a := Config{ dsn: *u, } a.values, err = url.ParseQuery(u.RawQuery) if !a.isValid() { return nil, ErrConfigInvalidConfig } return &a, err } func (c *Config) isValid() bool { return c.dsn.Scheme == "s3" && c.values.Get("region") != "" } // String is to return the string form of DSN. 
func (c *Config) String() string {
	return c.dsn.String()
}

// Stringify writes the current options back into the DSN query string and
// returns the full DSN, in the spirit of JSON.stringify().
// See https://www.w3schools.com/js/js_json_stringify.asp
func (c *Config) Stringify() string {
	encoded := c.values.Encode()
	c.dsn.RawQuery = encoded
	return c.String()
}

// SafeStringify is like Stringify, but the values of the secretAccessKey,
// accessID and sessionToken options are replaced with "*".
func (c *Config) SafeStringify() string {
	masked := c.Stringify()
	for _, rule := range []struct {
		pattern     *regexp.Regexp
		replacement string
	}{
		{reSecretAccessKey, `secretAccessKey=*`},
		{reAccessID, `accessID=*`},
		{reSessionToken, `sessionToken=*`},
	} {
		masked = rule.pattern.ReplaceAllString(masked, rule.replacement)
	}
	return masked
}

// SetOutputBucket sets the S3 location ("s3://bucket[/prefix]") where query
// results are written. It returns ErrConfigOutputLocation when o does not
// start with "s3://". The bucket name itself is not validated here.
//
// Since March 1, 2018, S3 bucket names in US East (N. Virginia) follow the same
// rules as every other AWS Region: new bucket names may no longer contain
// uppercase letters or underscores.
// https://docs.aws.amazon.com/AmazonS3/latest/dev/BucketRestrictions.html#bucketnamingrules
func (c *Config) SetOutputBucket(o string) error {
	const scheme = "s3://"
	if !strings.HasPrefix(o, scheme) {
		return ErrConfigOutputLocation
	}
	location := strings.TrimPrefix(o, scheme)
	bucket, prefix := location, ""
	if slash := strings.Index(location, "/"); slash >= 0 {
		bucket, prefix = location[:slash], location[slash+1:]
	}
	c.dsn.Host = bucket
	c.dsn.Path = prefix
	return nil
}

// SetRegion sets the AWS region. It returns ErrConfigRegion when o is empty.
func (c *Config) SetRegion(o string) error {
	if o == "" {
		return ErrConfigRegion
	}
	c.values.Set("region", o)
	return nil
}

// GetRegion returns the region option. When it is unset, GetRegion falls back
// to the AWS_REGION / AWS_DEFAULT_REGION environment variables.
func (c *Config) GetRegion() string {
	region := c.values.Get("region")
	if region == "" {
		region = GetFromEnvVal(regionEnvKeys)
	}
	return region
}

// SetUser stores o as the DSN user name, with an empty password.
func (c *Config) SetUser(o string) {
	c.dsn.User = url.UserPassword(o, "")
}

// SetDB sets the database name option.
func (c *Config) SetDB(o string) {
	c.values.Set("db", o)
}

// GetDB returns the database name option, or DefaultDBName when it is unset.
func (c *Config) GetDB() string { if val := c.values.Get("db"); val != "" { return val } return DefaultDBName } // SetResultPollIntervalSeconds is a setter of Overriding poll interval. func (c *Config) SetResultPollIntervalSeconds(n int) { c.values.Set("resultPollIntervalSeconds", strconv.Itoa(n)) } // GetResultPollIntervalSeconds is getter of resultPollIntervalSeconds. func (c *Config) GetResultPollIntervalSeconds() time.Duration { if val := c.values.Get("resultPollIntervalSeconds"); val != "" { n, err := strconv.Atoi(val) if err != nil { return time.Duration(PoolInterval) * time.Second } return time.Duration(n) * time.Second } return time.Duration(PoolInterval) * time.Second } // SetWorkGroup is a setter of WorkGroup. func (c *Config) SetWorkGroup(w *Workgroup) error { if w == nil { return ErrConfigWGPointer } c.values.Set("workgroupName", w.Name) if w.Tags != nil { tagsString := c.values.Get("tag") for _, tag := range w.Tags.Get() { tagsString += "|" + *tag.Key + "`" + *tag.Value } c.values.Set("tag", tagsString) } if w.Config == nil { w.Config = GetDefaultWGConfig() } c.values.Set("workgroupConfig", w.Config.String()) return nil } // SetAccessID is a setter of AWS Access ID. func (c *Config) SetAccessID(o string) error { if len(o) == 0 { return ErrConfigAccessIDRequired } c.values.Set("accessID", o) return nil } // GetAccessID is a getter of AWS Access ID. It will try to get access ID from: // 1. string stored in c.values // 2. environmental variable ${AWS_ACCESS_KEY_ID} or ${AWS_ACCESS_KEY} func (c *Config) GetAccessID() string { if val := c.values.Get("accessID"); val != "" { return val } return GetFromEnvVal(credAccessEnvKey) } // SetSecretAccessKey is a setter of AWS Access Key. func (c *Config) SetSecretAccessKey(o string) error { if len(o) == 0 { return ErrConfigAccessKeyRequired } c.values.Set("secretAccessKey", o) return nil } // GetSecretAccessKey is a getter of AWS Access Key. 
func (c *Config) GetSecretAccessKey() string { if val := c.values.Get("secretAccessKey"); val != "" { return val } return GetFromEnvVal(credSecretEnvKey) } // SetSessionToken is a setter of AWS Session Token. func (c *Config) SetSessionToken(o string) { c.values.Set("sessionToken", o) } // GetSessionToken is a getter of AWS Session Token. func (c *Config) GetSessionToken() string { if val := c.values.Get("sessionToken"); val != "" { return val } return GetFromEnvVal(credSessionEnvKey) } // GetUser is getter of User. func (c *Config) GetUser() string { return c.dsn.User.Username() } // GetOutputBucket is getter of OutputBucket. func (c *Config) GetOutputBucket() string { if strings.HasPrefix(c.dsn.Path, "/") { return c.dsn.Scheme + "://" + c.dsn.Host + c.dsn.Path } return c.dsn.Scheme + "://" + c.dsn.Host + "/" + c.dsn.Path } // GetWorkgroup is getter of Workgroup. func (c *Config) GetWorkgroup() Workgroup { tagString := c.values.Get("tag") if len(tagString) == 0 { wg := Workgroup{ Name: c.values.Get("workgroupName"), Config: GetDefaultWGConfig(), Tags: NewWGTags(), } return wg } tags := strings.Split(tagString[1:], "|") t := NewWGTags() for _, tag := range tags { ts := strings.Split(tag, "`") t.AddTag(ts[0], ts[1]) } wg := Workgroup{ Name: c.values.Get("workgroupName"), Config: GetDefaultWGConfig(), Tags: t, } return wg } // IsMissingAsEmptyString return true if missing value is set to be returned as empty string. func (c *Config) IsMissingAsEmptyString() bool { return c.values.Get("missingAsEmptyString") == "true" } // IsMissingAsDefault return true if missing value is set to be returned as default data. func (c *Config) IsMissingAsDefault() bool { return c.values.Get("missingAsDefault") == "true" } // IsMissingAsNil return true if missing value is set to be returned as nil. func (c *Config) IsMissingAsNil() bool { return c.values.Get("missingAsNil") == "true" } // SetMissingAsEmptyString is to set if missing value is returned as empty string. 
func (c *Config) SetMissingAsEmptyString(b bool) { missingAsEmptyString := "true" if !b { missingAsEmptyString = "false" } c.values.Set("missingAsEmptyString", missingAsEmptyString) } // SetMissingAsDefault is to set if missing value is returned as default data. func (c *Config) SetMissingAsDefault(b bool) { if b { c.values.Set("missingAsDefault", "true") } else { c.values.Set("missingAsDefault", "false") } } // SetMissingAsNil is to set if missing value is returned as nil. func (c *Config) SetMissingAsNil(b bool) { if b { c.values.Set("missingAsNil", "true") } else { c.values.Set("missingAsNil", "false") } } // CheckColumnMasked is to check if a specific column has been masked by some value. // https://stackoverflow.com/questions/30285169/replace-the-empty-or-null-value-with-specific-value-in-hive-query-result/30289503 func (c *Config) CheckColumnMasked(columnName string) (string, bool) { if val, ok := c.values["masked_"+columnName]; ok { return val[0], true } return "", false } // SetMaskedColumnValue is to set masked value for some column. func (c *Config) SetMaskedColumnValue(columnName string, value string) { c.values.Set("masked_"+columnName, value) } // IsWGRemoteCreationAllowed is to check if we are allowed to create workgroup with API from client. func (c *Config) IsWGRemoteCreationAllowed() bool { return c.values.Get("WGRemoteCreation") == "true" } // SetWGRemoteCreationAllowed is to set if we are allowed to create workgroup with API from client. func (c *Config) SetWGRemoteCreationAllowed(b bool) { if b { c.values.Set("WGRemoteCreation", "true") } else { c.values.Set("WGRemoteCreation", "false") } } // IsLoggingEnabled is to check if driver level logging enabled. func (c *Config) IsLoggingEnabled() bool { return c.values.Get("LoggingEnabled") != "false" } // SetLogging is to set if driver level logging enabled. 
func (c *Config) SetLogging(b bool) { if b { c.values.Set("LoggingEnabled", "true") } else { c.values.Set("LoggingEnabled", "false") } } // IsMetricsEnabled is to check if driver level metrics enabled. func (c *Config) IsMetricsEnabled() bool { return c.values.Get("MetricsEnabled") == "true" } // SetMetrics is to set if driver level logging enabled. func (c *Config) SetMetrics(b bool) { if b { c.values.Set("MetricsEnabled", "true") } else { c.values.Set("MetricsEnabled", "false") } } // SetReadOnly is to set if only SELECT/SHOW/DESC are allowed func (c *Config) SetReadOnly(b bool) { if b { c.values.Set("ReadOnly", "true") } else { c.values.Set("ReadOnly", "false") } } // IsReadOnly is to check if only SELECT/SHOW/DESC are allowed func (c *Config) IsReadOnly() bool { return c.values.Get("ReadOnly") == "true" } // SetMoneyWise is to set if we are in the moneywise mode func (c *Config) SetMoneyWise(b bool) { if b { c.values.Set("MoneyWise", "true") } else { c.values.Set("MoneyWise", "false") } } // IsMoneyWise is to check if we are in the moneywise mode func (c *Config) IsMoneyWise() bool { return c.values.Get("MoneyWise") == "true" } // SetAWSProfile is to manually set the credential provider // https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/configuring-sdk.html func (c *Config) SetAWSProfile(profile string) { c.values.Set("AWSProfile", profile) } // GetAWSProfile is to get the credential provider name manually set by user func (c *Config) GetAWSProfile() string { return c.values.Get("AWSProfile") } // SetServiceLimitOverride is to set values from a ServiceLimitOverride func (c *Config) SetServiceLimitOverride(serviceLimitOverride ServiceLimitOverride) { for k, v := range serviceLimitOverride.GetAsStringMap() { c.values.Set(k, v) } } // GetServiceLimitOverride is to get the ServiceLimitOverride manually set by a user func (c *Config) GetServiceLimitOverride() *ServiceLimitOverride { serviceLimitOverride := NewServiceLimitOverride() 
serviceLimitOverride.SetFromValues(c.values) return serviceLimitOverride }