lib/configfx/module.go (180 lines of code) (raw):
// Copyright (c) 2022 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package configfx
import (
"context"
"flag"
"fmt"
"io/ioutil"
"os"
"os/user"
"path/filepath"
"strings"
secret "github.com/uber/athenadriver/examples/constants"
drv "github.com/uber/athenadriver/go"
"go.uber.org/config"
"go.uber.org/fx"
)
// Module is the fx option that registers this package's constructor (new)
// with the main app. Through Result, it makes an AthenaDriverConfig available
// for dependency injection.
var Module = fx.Provide(new)
// Params defines the dependencies (inputs) that fx injects into the
// constructor new.
type Params struct {
fx.In
// LC is the application lifecycle; new appends its start/stop hooks to it.
LC fx.Lifecycle
}
// ReaderOutputConfig represents the output section of the configuration file.
// new populates it from the "athenareader.output" key and then overrides
// individual fields with the matching command-line flags.
type ReaderOutputConfig struct {
// Render is the output format. It can be overridden with the -o flag, whose
// help text lists table, markdown, csv and html as options.
Render string `yaml:"render"`
// Page is the pagination size, read from the "pagesize" key.
Page int `yaml:"pagesize"`
// Style is the output rendering style. It can be overridden with the -y flag.
Style string `yaml:"style"`
// Rowonly, when true, displays rows only, without the leading column-info
// row. It can be overridden with the -r flag.
Rowonly bool `yaml:"rowonly"`
// Moneywise, when true, enables moneywise mode on the driver config so the
// query cost is displayed. It can be overridden with the -m flag.
Moneywise bool `yaml:"moneywise"`
// Fastfail controls fast-fail behavior when there are multiple queries.
// NOTE(review): new always assigns this field after reading the config file
// (the -f value when passed, true otherwise), so the "fastfail" key in the
// file currently has no lasting effect.
Fastfail bool `yaml:"fastfail"`
}
// ReaderInputConfig represents the input section of the configuration file.
// new populates it from the "athenareader.input" key and then overrides
// individual fields with the matching command-line flags.
type ReaderInputConfig struct {
// Bucket is the Athena result-set output bucket. It is passed to
// drv.NewDefaultConfig and can be overridden with the -b flag.
Bucket string `yaml:"bucket"`
// Region is the AWS region passed to drv.NewDefaultConfig.
Region string `yaml:"region"`
// Database is the name of the database to query. It can be overridden with
// the -d flag and is applied to the driver config via SetDB.
Database string `yaml:"database"`
// Admin enables write mode. When it is false, new marks the driver config
// as read-only. It can be overridden with the -a flag.
Admin bool `yaml:"admin"`
}
// AthenaDriverConfig is the complete athenareader configuration assembled by
// new from the configuration file and the command-line flags.
type AthenaDriverConfig struct {
// OutputConfig holds the output section of the config.
OutputConfig ReaderOutputConfig
// InputConfig holds the input section of the config.
InputConfig ReaderInputConfig
// QueryString holds the queries to run. When the -q value names an existing
// file, it is that file's content split on blank lines ("\n\n"); otherwise
// it is the single -q value itself.
QueryString []string
// DrvConfig is the driver configuration created by drv.NewDefaultConfig and
// then adjusted (output bucket, database, moneywise, read-only) by new.
DrvConfig *drv.Config
}
// Result defines the values that the constructor new provides to the fx
// dependency graph.
type Result struct {
fx.Out
// MyConfig is the current AthenaDriver config built by new.
MyConfig AthenaDriverConfig
}
// init runs once when the package is loaded and invokes setUpFlagUsage with a
// background context.
// NOTE(review): setUpFlagUsage is defined elsewhere in this package; presumably
// it installs the custom command-line usage text — confirm against its source.
func init() {
	ctx := context.Background()
	setUpFlagUsage(ctx)
}
// new is the fx constructor of this package: it assembles the
// AthenaDriverConfig from athenareader.config and the command-line flags.
//
// Steps, in order:
//  1. register the lifecycle hooks on p.LC;
//  2. register and parse the command-line flags (-v prints the driver version
//     and exits the process);
//  3. locate and load athenareader.config (see loadConfigProvider);
//  4. populate the output and input sections from the config file;
//  5. resolve the -q value into one or more queries (see readQueries);
//  6. create the driver config and apply, on top of the config-file values,
//     every flag for which isFlagPassed reports true.
//
// The flags are registered on the process-wide default flag set, so new must
// run at most once per process (the flag package panics on redefinition).
//
// It returns an error when the config file cannot be parsed or populated, when
// a query file exists but cannot be read, or when the driver config cannot be
// created. On error the returned Result is the zero value.
func new(p Params) (Result, error) {
	mc := AthenaDriverConfig{
		QueryString: make([]string, 0),
	}

	p.LC.Append(fx.Hook{
		OnStart: func(ctx context.Context) error {
			return nil
		},
		OnStop: func(ctx context.Context) error {
			// Best-effort cleanup of the process environment on shutdown; a
			// failure to unset must not turn a clean stop into an error.
			_ = os.Unsetenv("AWS_SDK_LOAD_CONFIG")
			return nil
		},
	})

	bucket := flag.String("b", secret.OutputBucket, "Athena resultset output bucket")
	database := flag.String("d", "default", "The database you want to query")
	query := flag.String("q", "select 1", "The SQL query string or a file containing SQL string")
	rowOnly := flag.Bool("r", false, "Display rows only, don't show the first row as columninfo")
	moneyWise := flag.Bool("m", false, "Enable moneywise mode to display the query cost as the first line of the output")
	versionFlag := flag.Bool("v", false, "Print the current version and exit")
	admin := flag.Bool("a", false, "Enable admin mode, so database write(create/drop) is allowed at athenadriver level")
	style := flag.String("y", "default", "Output rendering style")
	format := flag.String("o", "csv", "Output format(options: table, markdown, csv, html)")
	fastFail := flag.Bool("f", true, "fast fail when where are multiple queries")
	flag.Parse()

	if *versionFlag {
		println("Current build version: v" + drv.DriverVersion)
		os.Exit(0)
	}

	provider, err := loadConfigProvider()
	if err != nil {
		return Result{}, fmt.Errorf("loading athenareader.config: %w", err)
	}
	// A malformed section must surface as an error instead of silently
	// leaving the struct at its zero values.
	if err := provider.Get("athenareader.output").Populate(&mc.OutputConfig); err != nil {
		return Result{}, fmt.Errorf("reading athenareader.output section: %w", err)
	}
	if err := provider.Get("athenareader.input").Populate(&mc.InputConfig); err != nil {
		return Result{}, fmt.Errorf("reading athenareader.input section: %w", err)
	}

	queries, err := readQueries(*query)
	if err != nil {
		return Result{}, err
	}
	mc.QueryString = queries

	// The driver config is created from the config-file bucket and region;
	// a -b flag replaces the bucket afterwards.
	mc.DrvConfig, err = drv.NewDefaultConfig(mc.InputConfig.Bucket, mc.InputConfig.Region, secret.AccessID, secret.SecretAccessKey)
	if err != nil {
		return Result{}, err
	}

	// Flags win over the config file, but only when explicitly passed.
	if isFlagPassed("b") {
		mc.InputConfig.Bucket = *bucket
		mc.DrvConfig.SetOutputBucket(mc.InputConfig.Bucket)
	}
	if isFlagPassed("d") {
		mc.InputConfig.Database = *database
	}
	if isFlagPassed("r") {
		mc.OutputConfig.Rowonly = *rowOnly
	}
	if isFlagPassed("m") {
		mc.OutputConfig.Moneywise = *moneyWise
	}
	// NOTE(review): Fastfail is forced to true whenever -f is absent, so the
	// "fastfail" value from the config file is always overwritten. Kept as is
	// to preserve existing behavior.
	if isFlagPassed("f") {
		mc.OutputConfig.Fastfail = *fastFail
	} else {
		mc.OutputConfig.Fastfail = true
	}
	if isFlagPassed("a") {
		mc.InputConfig.Admin = *admin
	}
	if isFlagPassed("y") {
		mc.OutputConfig.Style = *style
	}
	if isFlagPassed("o") {
		mc.OutputConfig.Render = *format
	}

	// Push the final, merged settings into the driver config.
	if mc.OutputConfig.Moneywise {
		mc.DrvConfig.SetMoneyWise(true)
	}
	mc.DrvConfig.SetDB(mc.InputConfig.Database)
	if !mc.InputConfig.Admin {
		mc.DrvConfig.SetReadOnly(true)
	}

	return Result{
		MyConfig: mc,
	}, nil
}

// loadConfigProvider locates athenareader.config and loads it as a YAML
// provider. Lookup order: the home directory, the current directory, the
// athenadriver source tree under GOPATH (GOPATH defaults to <home>/go), and
// finally a download from the athenadriver GitHub repository into the home
// directory.
//
// The process exits with status 1 when GOPATH is unset and <home>/go does not
// exist, or when the download fails. The returned error is the YAML loading
// error, if any.
func loadConfigProvider() (*config.YAML, error) {
	const configName = "athenareader.config"
	homeConfig := filepath.Join(homeDir(), configName)

	// giveUp reports that no config file could be found and terminates the
	// process.
	giveUp := func() {
		// Getwd's error is ignored on purpose: the directory only enriches
		// the message printed right before exiting.
		d, _ := os.Getwd()
		println("could not find athenareader.config in home directory or current directory " + d)
		os.Exit(1)
	}

	if _, err := os.Stat(homeConfig); err == nil {
		return config.NewYAML(config.File(homeConfig))
	}
	if _, err := os.Stat(configName); err == nil {
		return config.NewYAML(config.File(configName))
	}

	goPath := os.Getenv("GOPATH")
	if goPath == "" {
		goPath = filepath.Join(homeDir(), "go")
		// NOTE(review): a missing <home>/go exits here without attempting the
		// download fallback below; kept to preserve existing behavior.
		if _, err := os.Stat(goPath); err != nil {
			giveUp()
		}
	}

	srcConfig := filepath.Join(goPath, "src", "github.com", "uber", "athenadriver", "athenareader", configName)
	if _, err := os.Stat(srcConfig); err == nil {
		// Install a copy in the home directory for later runs. The result of
		// copyFile is not checked; the provider is built from the source-tree
		// file either way.
		copyFile(srcConfig, homeConfig)
		return config.NewYAML(config.File(srcConfig))
	}

	if err := downloadFile(homeConfig,
		"https://raw.githubusercontent.com/uber/athenadriver/master/athenareader/athenareader.config"); err != nil {
		giveUp()
	}
	return config.NewYAML(config.File(homeConfig))
}

// readQueries turns the -q argument into the list of queries to run. When the
// argument (after expand) names an existing path, the file content is split on
// blank lines ("\n\n") into separate queries; otherwise the argument itself is
// the single query. An existing path that cannot be read is reported as an
// error rather than silently producing an empty query list.
func readQueries(arg string) ([]string, error) {
	filePath := expand(arg)
	if _, err := os.Stat(filePath); err != nil {
		// Not a file on disk: treat the argument as a literal SQL string.
		return []string{arg}, nil
	}
	b, err := ioutil.ReadFile(filePath)
	if err != nil {
		return nil, fmt.Errorf("reading query file %q: %w", filePath, err)
	}
	return strings.Split(string(b), "\n\n"), nil
}
// expand replaces a leading "~" in path with the current user's home
// directory.
//
// Only "~" on its own or "~" followed by a path separator is expanded. Any
// other input is returned unchanged, including "~name/..." forms and plain
// file names that merely start with "~", which must not be redirected into
// the current user's home.
//
// The home directory comes from user.Current, falling back to os.UserHomeDir.
// When neither yields a directory the path is returned unchanged, so callers
// see the original value instead of an unrelated location.
func expand(path string) string {
	if len(path) == 0 || path[0] != '~' {
		return path
	}
	if len(path) > 1 && !os.IsPathSeparator(path[1]) {
		return path
	}

	var home string
	if usr, err := user.Current(); err == nil {
		home = usr.HomeDir
	}
	if home == "" {
		if dir, err := os.UserHomeDir(); err == nil {
			home = dir
		}
	}
	if home == "" {
		return path
	}
	return filepath.Join(home, path[1:])
}