go/utils.go (664 lines of code) (raw):

// Copyright (c) 2022 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package athenadriver import ( "bytes" "database/sql" "database/sql/driver" "encoding/csv" "fmt" "math" "math/rand" "os" "regexp" "strconv" "strings" "time" "github.com/DATA-DOG/go-sqlmock" "github.com/aws/aws-sdk-go/service/athena" "github.com/jedib0t/go-pretty/v6/table" "github.com/xwb1989/sqlparser" ) // OutputStyles are all the styles we can choose to print query result var OutputStyles = [...]string{"StyleDefault", "StyleBold", "StyleColoredBright", "StyleColoredDark", "StyleColoredBlackOnBlueWhite", "StyleColoredBlackOnCyanWhite", "StyleColoredBlackOnGreenWhite", "StyleColoredBlackOnMagentaWhite", "StyleColoredBlackOnYellowWhite", "StyleColoredBlackOnRedWhite", "StyleColoredBlueWhiteOnBlack", "StyleColoredCyanWhiteOnBlack", "StyleColoredGreenWhiteOnBlack", "StyleColoredMagentaWhiteOnBlack", "StyleColoredRedWhiteOnBlack", "StyleColoredYellowWhiteOnBlack", "StyleDouble", "StyleLight", "StyleRounded", } // OutputFormats are all the formats we can choose to print query result var OutputFormats = [...]string{"csv", "html", "markdown", "table"} func scanNullString(v interface{}) (sql.NullString, error) { if v == nil { return sql.NullString{}, nil } vv, ok := v.(string) if !ok { return sql.NullString{}, fmt.Errorf("cannot convert %v (%T) to string", v, v) } return sql.NullString{Valid: true, String: vv}, nil } func mockRowsToSQLRows(mockRows *sqlmock.Rows) *sql.Rows { db, mock, _ := sqlmock.New() mock.ExpectQuery("SELECT_OK").WillReturnRows(mockRows) rows, _ := db.Query("SELECT_OK") return rows } // ColsToCSV is a convenient function to convert columns of sql.Rows to CSV format. func ColsToCSV(rows *sql.Rows) string { if rows == nil { return "" } columns, _ := rows.Columns() s := "" for i, v := range columns { s += v if i != len(columns)-1 { s += "," } else { s += "\n" } } return s } // RowsToCSV is to convert rows of sql.Rows to CSV format. func RowsToCSV(rows *sql.Rows) string { if rows == nil { return "" } columns, _ := rows.Columns() var buf bytes.Buffer csvWriter := csv.NewWriter(&buf) records := make([][]string, 0) for rows.Next() { rawResult := make([][]byte, len(columns)) row := make([]interface{}, len(columns)) for i := range rawResult { row[i] = &rawResult[i] // pointers to each string in the interface slice } // We don't consider malformed rows _ = rows.Scan(row...) s := make([]string, len(columns)) for i, cell := range rawResult { s[i] = string(cell) } records = append(records, s) } csvWriter.WriteAll(records) return buf.String() } // ColsRowsToCSV is a convenient function to convert columns and rows of sql.Rows to CSV format. func ColsRowsToCSV(rows *sql.Rows) string { s := ColsToCSV(rows) r := RowsToCSV(rows) return s + r } func getTableStyle(style string) table.Style { switch style { case "StyleColoredBright": return table.StyleColoredBright case "StyleBold": return table.StyleBold case "StyleColoredDark": return table.StyleColoredDark case "StyleColoredBlackOnBlueWhite": return table.StyleColoredBlackOnBlueWhite case "StyleColoredBlackOnCyanWhite": return table.StyleColoredBlackOnCyanWhite case "StyleColoredBlackOnGreenWhite": return table.StyleColoredBlackOnGreenWhite case "StyleColoredBlackOnMagentaWhite": return table.StyleColoredBlackOnMagentaWhite case "StyleColoredBlackOnYellowWhite": return table.StyleColoredBlackOnYellowWhite case "StyleColoredBlackOnRedWhite": return table.StyleColoredBlackOnRedWhite case "StyleColoredBlueWhiteOnBlack": return table.StyleColoredBlueWhiteOnBlack case "StyleColoredCyanWhiteOnBlack": return table.StyleColoredCyanWhiteOnBlack case "StyleColoredGreenWhiteOnBlack": return table.StyleColoredGreenWhiteOnBlack case "StyleColoredMagentaWhiteOnBlack": return table.StyleColoredMagentaWhiteOnBlack case "StyleColoredRedWhiteOnBlack": return table.StyleColoredRedWhiteOnBlack case "StyleColoredYellowWhiteOnBlack": return table.StyleColoredYellowWhiteOnBlack case "StyleDouble": return table.StyleDouble case "StyleLight": return table.StyleLight case "StyleRounded": return table.StyleRounded } return table.StyleDefault } func renderTable(renderType string, w table.Writer) string { switch renderType { case "markdown": return w.RenderMarkdown() case "table": return w.Render() case "html": return w.RenderHTML() } return w.RenderCSV() } // PrettyPrintSQLRows is to print rows beautifully func PrettyPrintSQLRows(rows *sql.Rows, style string, render string, page int) { t := table.NewWriter() t.SetOutputMirror(os.Stdout) if rows == nil { return } columns, _ := rows.Columns() for rows.Next() { rawResult := make([][]byte, len(columns)) row := make([]interface{}, len(columns)) for i := range rawResult { row[i] = &rawResult[i] // pointers to each string in the interface slice } // We don't consider malformed rows _ = rows.Scan(row...) s := make(table.Row, len(columns)) for i, cell := range rawResult { s[i] = string(cell) } t.AppendRow(s) } t.SetPageSize(page) t.SetStyle(getTableStyle(style)) renderTable(render, t) } // PrettyPrintSQLColsRows is to print rows beautifully with header func PrettyPrintSQLColsRows(rows *sql.Rows, style string, render string, page int) { t := table.NewWriter() t.SetOutputMirror(os.Stdout) if rows == nil { return } columns, _ := rows.Columns() if columns != nil && len(columns) > 0 { myrow := make(table.Row, len(columns)) for i, c := range columns { myrow[i] = c } t.AppendHeader(myrow) } for rows.Next() { rawResult := make([][]byte, len(columns)) row := make([]interface{}, len(columns)) for i := range rawResult { row[i] = &rawResult[i] // pointers to each string in the interface slice } // We don't consider malformed rows _ = rows.Scan(row...) s := make(table.Row, len(columns)) for i, cell := range rawResult { s[i] = string(cell) } t.AppendRow(s) } t.SetPageSize(page) t.SetStyle(getTableStyle(style)) renderTable(render, t) } // PrettyPrintCSV is to print rows in CSV format with default style func PrettyPrintCSV(rows *sql.Rows) { PrettyPrintSQLColsRows(rows, "StyleDefault", "csv", 1024) } // PrettyPrintMD is to print rows in markdown format with default style func PrettyPrintMD(rows *sql.Rows) { PrettyPrintSQLColsRows(rows, "StyleDefault", "markdown", 1024) } // PrettyPrintFancy is to print rows in table format with fancy style func PrettyPrintFancy(rows *sql.Rows) { PrettyPrintSQLColsRows(rows, "StyleColoredGreenWhiteOnBlack", "table", 1024) } // colInFirstPage is to check if this is a SELECT or VALUES statement. // Some Sample Queries are like: // // USING FUNCTION predict_customer_registration(age INTEGER) // // RETURNS DOUBLE TYPE // SAGEMAKER_INVOKE_ENDPOINT WITH (sagemaker_endpoint = 'xgboost-2019-09-20-04-49-29-303') // // SELECT predict_customer_registration(age) AS probability_of_enrolling, customer_id // // FROM "sampledb"."ml_test_dataset" // WHERE predict_customer_registration(age) < 0.5; // // USING FUNCTION decompress(col1 VARCHAR) // // RETURNS VARCHAR TYPE // LAMBDA_INVOKE WITH (lambda_name = 'MyAthenaUDFLambda') // // SELECT // // decompress('ewLLinKzEsPyXdKdc7PLShKLS5OTQEAUrEH9w=='); // // WITH // dataset AS ( // // SELECT // ARRAY ['hello', 'amazon', 'athena'] AS words, // ARRAY ['hi', 'alexa'] AS alexa // // ) // SELECT concat(words, alexa) AS welcome_msg FROM dataset func colInFirstPage(query string) bool { nQuery := strings.TrimSpace(strings.ToLower(query)) return strings.Index(nQuery, "select") == 0 || strings.Index(nQuery, "using") == 0 || strings.Index(nQuery, "with") == 0 || strings.Index(nQuery, "values") == 0 } func isReadOnlyStatement(query string) bool { nQuery := strings.TrimSpace(strings.ToLower(query)) return strings.Index(nQuery, "select") == 0 || strings.Index(nQuery, "using") == 0 || strings.Index(nQuery, "with") == 0 || strings.Index(nQuery, "desc") == 0 || strings.Index(nQuery, "show") == 0 || IsQID(query) } func isInsertStatement(query string) bool { nQuery := strings.TrimSpace(strings.ToLower(query)) return strings.Index(nQuery, "insert") == 0 } func newColumnInfo(colName string, colType interface{}) *athena.ColumnInfo { caseSensitive := false catalogName := "hive" nullable := "UNKNOWN" precision := int64(19) scale := int64(0) schemaName := "" tableName := "" if colType == nil { return &athena.ColumnInfo{ CaseSensitive: &caseSensitive, CatalogName: &catalogName, Label: &colName, Name: &colName, Nullable: &nullable, Precision: &precision, Scale: &scale, SchemaName: &schemaName, TableName: &tableName, Type: nil, } } ct := colType.(string) return &athena.ColumnInfo{ CaseSensitive: &caseSensitive, CatalogName: &catalogName, Label: &colName, Name: &colName, Nullable: &nullable, Precision: &precision, Scale: &scale, SchemaName: &schemaName, TableName: &tableName, Type: &ct, } } func newRow(colLen int, rData []string) *athena.Row { var nData = make([]*athena.Datum, colLen) for i := 0; i < colLen; i++ { nData[i] = &athena.Datum{VarCharValue: &rData[i]} } return &athena.Row{ Data: nData, } } func randString(l int) string { const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" s := make([]byte, l) for i := 0; i < l; i++ { s[i] = alphabet[rand.Intn(len(alphabet))] } return string(s) } func randomInt64(min int64, max int64) int64 { return min + rand.Int63n(max-min) } // https://golang.org/ref/spec#Numeric_types func randInt8() *string { s := strconv.Itoa(int(randomInt64(math.MinInt8, math.MaxInt8))) return &s } func randInt16() *string { s := strconv.Itoa(int(randomInt64(math.MinInt16, math.MaxInt16))) return &s } func randInt() *string { s := strconv.Itoa(int(randomInt64(math.MinInt32, math.MaxInt32))) return &s } func randUInt64() *string { s := strconv.FormatUint(rand.Uint64(), 10) return &s } func randFloat32() *string { s := strconv.FormatFloat(rand.Float64(), 'f', 6, 32) return &s } func randFloat64() *string { s := strconv.FormatFloat(rand.Float64(), 'f', 6, 64) return &s } func randStr() *string { s := randString(rand.Intn(10)) return &s } func randBool() *string { if rand.Intn(10)%2 == 0 { s := "true" return &s } s := "false" return &s } func randDate() *string { min := time.Date(1970, 1, 0, 0, 0, 0, 0, time.UTC).Unix() max := time.Date(2070, 1, 0, 0, 0, 0, 0, time.UTC).Unix() delta := max - min sec := rand.Int63n(delta) + min s := time.Unix(sec, 0).Format(DateUniXFormat) return &s } func randTimeStamp() *string { min := time.Date(1970, 1, 0, 0, 0, 0, 0, time.UTC).Unix() max := time.Date(2070, 1, 0, 0, 0, 0, 0, time.UTC).Unix() delta := max - min sec := rand.Int63n(delta) + min s := time.Unix(sec, 0).Format(TimestampUniXFormat) return &s } func genHeaderRow(columns []*athena.ColumnInfo) *athena.Row { colLen := len(columns) rData := make([]string, colLen) for i := 0; i < colLen; i++ { rData[i] = *columns[i].Name } return newRow(colLen, rData) } // randRow generates a row with random data aligned with type information in // athena.ColumnInfo func randRow(columns []*athena.ColumnInfo) *athena.Row { colLen := len(columns) row := &athena.Row{ Data: make([]*athena.Datum, colLen), } for j := 0; j < colLen; j++ { if columns[j].Type == nil { s := "a\tb" row.Data[j] = &athena.Datum{VarCharValue: &s} continue } switch *columns[j].Type { case "tinyint": row.Data[j] = &athena.Datum{VarCharValue: randInt8()} case "smallint": row.Data[j] = &athena.Datum{VarCharValue: randInt16()} case "integer": row.Data[j] = &athena.Datum{VarCharValue: randInt()} case "bigint": row.Data[j] = &athena.Datum{VarCharValue: randUInt64()} case "float", "real": row.Data[j] = &athena.Datum{VarCharValue: randFloat32()} case "double": row.Data[j] = &athena.Datum{VarCharValue: randFloat64()} case "json", "char", "varchar", "varbinary", "row", "string", "binary", "struct", "interval year to month", "interval day to second", "decimal", "ipaddress", "array", "map", "unknown": row.Data[j] = &athena.Datum{VarCharValue: randStr()} case "boolean": row.Data[j] = &athena.Datum{VarCharValue: randBool()} case "date": row.Data[j] = &athena.Datum{VarCharValue: randDate()} case "time", "time with time zone", "timestamp with time zone": row.Data[j] = &athena.Datum{VarCharValue: randTimeStamp()} case "timestamp": row.Data[j] = &athena.Datum{VarCharValue: randTimeStamp()} default: row.Data[j] = &athena.Datum{VarCharValue: randStr()} } } return row } func missingDataRow(columns []*athena.ColumnInfo) *athena.Row { colLen := len(columns) row := &athena.Row{ Data: make([]*athena.Datum, colLen), } for j := 0; j < colLen; j++ { switch *columns[j].Type { case "integer": row.Data[j] = &athena.Datum{VarCharValue: nil} default: row.Data[j] = nil } } return row } func genRow(rowData []*string) *athena.Row { row := &athena.Row{ Data: make([]*athena.Datum, len(rowData)), } for i := 0; i < len(rowData); i++ { row.Data[i] = &athena.Datum{VarCharValue: rowData[i]} } return row } // columnTypes must be from one of AthenaColumnTypes func newHeaderResultPage(columnNames []*string, columnTypes []string, rowsData [][]*string) *athena.GetQueryResultsOutput { columns := make([]*athena.ColumnInfo, len(columnNames)) for i := 0; i < len(columnNames); i++ { columns[i] = newColumnInfo(*columnNames[i], columnTypes[i]) } rowLen := len(rowsData) rows := make([]*athena.Row, rowLen+1) rows[0] = genHeaderRow(columns) for i := 1; i < rowLen+1; i++ { rows[i] = genRow(rowsData[i-1]) } return &athena.GetQueryResultsOutput{ NextToken: nil, ResultSet: &athena.ResultSet{ ResultSetMetadata: &athena.ResultSetMetadata{ ColumnInfo: columns, }, Rows: rows, }, } } func newHeaderlessResultPage(columnNames []*string, columnTypes []string, rowsData [][]*string) *athena.GetQueryResultsOutput { columns := make([]*athena.ColumnInfo, len(columnNames)) for i := 0; i < len(columnNames); i++ { columns[i] = newColumnInfo(*columnNames[i], columnTypes[i]) } rowLen := len(rowsData) rows := make([]*athena.Row, rowLen) for i := 0; i < rowLen; i++ { rows[i] = genRow(rowsData[i]) } return &athena.GetQueryResultsOutput{ NextToken: nil, ResultSet: &athena.ResultSet{ ResultSetMetadata: &athena.ResultSetMetadata{ ColumnInfo: columns, }, Rows: rows, }, } } func newRandomHeaderResultPage(columns []*athena.ColumnInfo, nextToken *string, rowLen int) *athena.GetQueryResultsOutput { rows := make([]*athena.Row, rowLen) rows[0] = genHeaderRow(columns) for i := 1; i < rowLen; i++ { rows[i] = randRow(columns) } return &athena.GetQueryResultsOutput{ NextToken: nextToken, ResultSet: &athena.ResultSet{ ResultSetMetadata: &athena.ResultSetMetadata{ ColumnInfo: columns, }, Rows: rows, }, } } func newRandomHeaderlessResultPage(columns []*athena.ColumnInfo, nextToken *string, rowLen int) *athena.GetQueryResultsOutput { rows := make([]*athena.Row, rowLen) for i := 0; i < rowLen; i++ { rows[i] = randRow(columns) } return &athena.GetQueryResultsOutput{ NextToken: nextToken, ResultSet: &athena.ResultSet{ ResultSetMetadata: &athena.ResultSetMetadata{ ColumnInfo: columns, }, Rows: rows, }, } } // escapeBytesBackslash escapes []byte with backslashes (\) // This escapes the contents of a string (provided as []byte) by adding backslashes before special // characters, and turning others into specific escape sequences, such as // turning newlines into \n and null bytes into \0. // // \xNN notation to define a string constant holding some peculiar byte values. // (Of course, bytes range from hexadecimal values 00 through FF, inclusive.) func escapeBytesBackslash(buf, v []byte) []byte { pos := len(buf) buf = reserveBuffer(buf, len(v)*2) for _, c := range v { switch c { case '\x00': buf[pos] = '\\' buf[pos+1] = '0' pos += 2 case '\n': buf[pos] = '\\' buf[pos+1] = 'n' pos += 2 case '\r': buf[pos] = '\\' buf[pos+1] = 'r' pos += 2 case '\x1a': buf[pos] = '\\' buf[pos+1] = 'Z' pos += 2 case '\'': // Single quotes can be escaped by adding another single quote. // https://docs.aws.amazon.com/athena/latest/ug/select.html // https://docs.aws.amazon.com/athena/latest/ug/data-types.html#data-types-considerations buf[pos] = '\'' buf[pos+1] = '\'' pos += 2 case '"': buf[pos] = '\\' buf[pos+1] = '"' pos += 2 case '\\': buf[pos] = '\\' buf[pos+1] = '\\' pos += 2 default: buf[pos] = c pos++ } } return buf[:pos] } // escapeStringBackslash is similar to escapeBytesBackslash but for string. func escapeStringBackslash(buf []byte, v string) []byte { return escapeBytesBackslash(buf, []byte(v)) } // reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize. // If cap(buf) is not enough, reallocate new buffer. func reserveBuffer(buf []byte, appendSize int) []byte { newSize := len(buf) + appendSize if cap(buf) < newSize { // Grow buffer exponentially newBuf := make([]byte, len(buf)*2+appendSize) copy(newBuf, buf) buf = newBuf } return buf[:newSize] } func namedValueToValue(named []driver.NamedValue) []driver.Value { args := make([]driver.Value, len(named)) for n, param := range named { args[n] = param.Value } return args } func valueToNamedValue(args []driver.Value) []driver.NamedValue { nameValues := make([]driver.NamedValue, len(args)) for i := 0; i < len(args); i++ { nameValues[i].Value = args[i] nameValues[i].Ordinal = i + 1 } return nameValues } func isQueryTimeOut(startOfStartQueryExecution time.Time, queryType string, serviceLimitOverride *ServiceLimitOverride) bool { ddlQueryTimeout := DDLQueryTimeout dmlQueryTimeout := DMLQueryTimeout if serviceLimitOverride != nil { if serviceLimitOverride.GetDDLQueryTimeout() > 0 { ddlQueryTimeout = serviceLimitOverride.GetDDLQueryTimeout() } if serviceLimitOverride.GetDMLQueryTimeout() > 0 { dmlQueryTimeout = serviceLimitOverride.GetDMLQueryTimeout() } } switch queryType { case "DDL": return time.Since(startOfStartQueryExecution) > time.Duration(ddlQueryTimeout)*time.Second case "DML": return time.Since(startOfStartQueryExecution) > time.Duration(dmlQueryTimeout)*time.Second case "UTILITY": return time.Since(startOfStartQueryExecution) > time.Duration(dmlQueryTimeout)*time.Second case "TIMEOUT_NOW": return true default: return time.Since(startOfStartQueryExecution) > time.Duration(ddlQueryTimeout)*time.Second } } // isQueryValid is to check the validity of Query, now only string length check. // https://docs.aws.amazon.com/athena/latest/ug/service-limits.html func isQueryValid(query string) bool { return len(query) < MAXQueryStringLength && len(query) > 4 } // GetFromEnvVal is to get environmental variable value by keys. // The return value is from whichever key is set according to the order in the slice. func GetFromEnvVal(keys []string) string { for _, k := range keys { if v := os.Getenv(k); len(v) != 0 { return v } } return "" } // printCost is to print query cost // https://aws.amazon.com/athena/pricing/ // getCost of 10MB: 5 / (1024. * 1024.) * 10 = 4.76837158203125e-05 func printCost(o *athena.GetQueryExecutionOutput) { if o == nil || o.QueryExecution == nil || o.QueryExecution.Statistics == nil { println("query cost: 0.0 USD, scanned data: 0 B, qid: NA") return } dataScannedBytes := o.QueryExecution.Statistics.DataScannedInBytes if dataScannedBytes == nil { println("query cost: 0.0 USD, scanned data: 0 B, qid: NA") } else if *dataScannedBytes == 0 { println("query cost: 0.0 USD, scanned data: 0 B, qid: " + *o.QueryExecution.QueryExecutionId) } else if *dataScannedBytes < 10*1024*1024 { fmt.Printf("query cost: %.20f USD, scanned data: %d B, qid: %s\n", getCost(*dataScannedBytes), *dataScannedBytes, *o.QueryExecution.QueryExecutionId) } else { fmt.Printf("query cost: %.20f USD, scanned data: %d B, qid: %s\n", getCost(*dataScannedBytes), *dataScannedBytes, *o.QueryExecution.QueryExecutionId) } } // getCost is return the USD cost upon data scanned in Bytes // https://aws.amazon.com/athena/pricing/ func getCost(data int64) float64 { if data == 0 { return 0.0 } else if data < int64(10*1024*1024) { return getPrice10MB() } else { return float64(data) * getPriceOneByte() } } var multiLineCommentPattern = regexp.MustCompile(`\/\*(.*)\*/\s*`) var oneLineCommentPattern = regexp.MustCompile(`(^\-\-[^\n]+|\s--[^\n]+)`) var getTableNamePattern = regexp.MustCompile(`(?i)\s+(?:from|join)\s+([\w.]+)`) var dualPattern = regexp.MustCompile(`from dual`) var qIDPattern = regexp.MustCompile(`^[0-9a-f-]{36}$`) // GetTableNamesInQuery is a pessimistic function to return tables involved in query in format of DB.TABLE // https://regoio.herokuapp.com/ // https://golang.org/pkg/regexp/syntax/ func GetTableNamesInQuery(query string) map[string]bool { query = multiLineCommentPattern.ReplaceAllString(query, "") query = oneLineCommentPattern.ReplaceAllString(query, "") matchedResults := getTableNamePattern.FindAllStringSubmatch(query, -1) tables := map[string]bool{} for _, matchedTableName := range matchedResults { if len(matchedTableName) == 2 { if strings.IndexByte(matchedTableName[1], '.') == -1 { tables[DefaultDBName+"."+matchedTableName[1]] = true } else { tables[matchedTableName[1]] = true } } } return tables } // GetTidySQL is to return a tidy SQL string func GetTidySQL(query string) string { query = multiLineCommentPattern.ReplaceAllString(query, "") query = oneLineCommentPattern.ReplaceAllString(query, "") stmt, err := sqlparser.Parse(query) if err == nil { q := sqlparser.String(stmt) // OtherRead represents a DESCRIBE, or EXPLAIN statement. // OtherAdmin represents a misc statement that relies on ADMIN privileges. if q == "otherread" || q == "otheradmin" || strings.Contains(q, " '$path' ") { return strings.Trim(query, " ") } query = dualPattern.ReplaceAllString(q, "") } return strings.Trim(query, " ") } // IsQID is to check if a query string is a Query ID // the hexadecimal Athena query ID like a44f8e61-4cbb-429a-b7ab-bea2c4a5caed // https://aws.amazon.com/premiumsupport/knowledge-center/access-download-athena-query-results/ func IsQID(q string) bool { return qIDPattern.MatchString(q) } // FormatString formats a string type query argument for Athena by escaping special characters and surrounding the // string with single quotes. Using FormatString allows for selective formatting of the query argument, if // typecasting or function calls are part of the query argument. // // Example usage: // query := "SELECT * FROM my_table WHERE description = ? AND created > ?" // // args := []any{ // aws.String(athenadriver.FormatString("The bunny's eating a carrot")), // aws.String(fmt.Sprintf("TIMESTAMP %s", athenadriver.FormatString("2024-07-01 00:00:00"))) // } func FormatString(v string) string { return fmt.Sprintf("'%s'", escapeBytesBackslash([]byte{}, []byte(v))) } // FormatBytes formats a byte slice query argument for Athena by escaping special characters and surrounding it with // single quotes. func FormatBytes(v []byte) []byte { buf := append([]byte{}, "_binary'"...) buf = escapeBytesBackslash(buf, v) buf = append(buf, '\'') return buf }