go/connection.go (473 lines of code) (raw):

// Copyright (c) 2022 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package athenadriver import ( "context" "database/sql" "database/sql/driver" "errors" "fmt" "strconv" "strings" "time" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/athena/athenaiface" "go.uber.org/zap" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/athena" ) // timestampFormatDriverMicro is the string format we transform Go time.Time objects into. This is not meant for // TIMESTAMP columns, as Athena timestamp columns only have a millisecond granularity. const timestampFormatDriverMicro = "2006-01-02 15:04:05.000000" // Connection is a connection to AWS Athena. It is not used concurrently by multiple goroutines. // Connection is assumed to be stateful. type Connection struct { athenaAPI athenaiface.AthenaAPI connector *SQLConnector numInput int } // buildExecutionParams converts Go data types into strings for query arguments in parameterized queries. func (c *Connection) buildExecutionParams(args []driver.Value) ([]*string, error) { executionParams := []*string{} for _, arg := range args { if arg == nil { val := "NULL" executionParams = append(executionParams, aws.String(val)) continue } // type switches of arg to handle different query parameter types val := "" switch v := arg.(type) { case int64: val = strconv.FormatInt(v, 10) case uint64: val = strconv.FormatUint(v, 10) case float64: val = strconv.FormatFloat(v, 'g', -1, 64) case bool: if v { val = "1" } else { val = "0" } case time.Time: // Note: time.Time objects are transformed into strings for a STRING/CHAR/VARCHAR type column. // To maintain compatibility with the current interpolateParams() behavior, this function produces a string // up to microsecond granularity, which Athena does not support in TIMESTAMP columns (up to milliseconds). // For DATE/TIME/TIMESTAMP, it is better to pass in string arguments with a typecast. Refer to the string // case below. // Matches interpolateParams() behavior. val = "'0000-00-00'" // Special-cased. if !v.IsZero() { v := v.In(time.UTC) v = v.Add(time.Nanosecond * 500) // To round under microsecond dateFormat := timestampFormatDriverMicro if v.Nanosecond()/1000 == 0 { // Omit microseconds if that part is zero dateFormat = time.DateTime } val = fmt.Sprintf("'%s'", v.Format(dateFormat)) } case []byte: // Note: Different from interpolateParams() behavior. // Like the string case below, enclosing in single quotes would prevent typecasting or function calls in // execution parameters. Prior to passing in query arguments, Format* functions in utils.go can be used. val = string(v) case string: // Note: Different from interpolateParams() behavior. // For parameterized queries, typecasting or function calls go in the execution parameters. For example, // `WHERE created = TIMESTAMP '2024-07-01 00:00:00'` should be formatted as: `WHERE created = ?` (query) and // `TIMESTAMP '2024-07-01 00:00:00.000'` (arg). Therefore, we cannot simply enclose the full string with // single quotes here. Users should use the Format* functions in utils.go to format input string arguments. val = v default: return []*string{}, ErrQueryUnknownType } executionParams = append(executionParams, aws.String(val)) } return executionParams, nil } func (c *Connection) interpolateParams(query string, args []driver.Value) (string, error) { c.numInput = len(args) // Number of ? should be same to len(args) if strings.Count(query, "?") != c.numInput { return "", ErrInvalidQuery } queryBuffer := make([]byte, MAXQueryStringLength) queryBuffer = queryBuffer[:0] argPos := 0 for i := 0; i < len(query); i++ { q := strings.IndexByte(query[i:], '?') if q == -1 { queryBuffer = append(queryBuffer, query[i:]...) break } queryBuffer = append(queryBuffer, query[i:i+q]...) i += q arg := args[argPos] argPos++ if arg == nil { queryBuffer = append(queryBuffer, "NULL"...) continue } // type switches of arg to handle different query parameter types switch v := arg.(type) { case int64: queryBuffer = strconv.AppendInt(queryBuffer, v, 10) case uint64: queryBuffer = strconv.AppendUint(queryBuffer, v, 10) case float64: queryBuffer = strconv.AppendFloat(queryBuffer, v, 'g', -1, 64) case bool: if v { queryBuffer = append(queryBuffer, '1') } else { queryBuffer = append(queryBuffer, '0') } case time.Time: if v.IsZero() { queryBuffer = append(queryBuffer, "'0000-00-00'"...) } else { v := v.In(time.UTC) v = v.Add(time.Nanosecond * 500) // To round under microsecond year := v.Year() year100 := year / 100 year1 := year % 100 month := v.Month() day := v.Day() hour := v.Hour() minute := v.Minute() second := v.Second() micro := v.Nanosecond() / 1000 queryBuffer = append(queryBuffer, []byte{ '\'', digits10[year100], digits01[year100], digits10[year1], digits01[year1], '-', digits10[month], digits01[month], '-', digits10[day], digits01[day], ' ', digits10[hour], digits01[hour], ':', digits10[minute], digits01[minute], ':', digits10[second], digits01[second], }...) if micro != 0 { micro10000 := micro / 10000 micro100 := micro / 100 % 100 micro1 := micro % 100 queryBuffer = append(queryBuffer, []byte{ '.', digits10[micro10000], digits01[micro10000], digits10[micro100], digits01[micro100], digits10[micro1], digits01[micro1], }...) } queryBuffer = append(queryBuffer, '\'') } case []byte: queryBuffer = append(queryBuffer, "_binary'"...) queryBuffer = escapeBytesBackslash(queryBuffer, v) queryBuffer = append(queryBuffer, '\'') case string: queryBuffer = append(queryBuffer, '\'') queryBuffer = escapeStringBackslash(queryBuffer, v) queryBuffer = append(queryBuffer, '\'') default: return "", ErrQueryUnknownType } if len(queryBuffer)+4 > 10*MAXQueryStringLength { return "", ErrQueryBufferOF } } return string(queryBuffer), nil } // CheckNamedValue is to implement interface driver.NamedValueChecker. func (c *Connection) CheckNamedValue(nv *driver.NamedValue) (err error) { nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value) return } // ExecContext executes a query that doesn't return rows, such as an INSERT or UPDATE. func (c *Connection) ExecContext(ctx context.Context, query string, namedArgs []driver.NamedValue) (driver.Result, error) { var obs = c.connector.tracer var err error args := namedValueToValue(namedArgs) if len(namedArgs) > 0 { query, err = c.interpolateParams(query, args) if err != nil { return nil, err } obs.Scope().Counter(DriverName + ".execcontext").Inc(1) } if !isQueryValid(query) { return nil, ErrInvalidQuery } rows, err := c.QueryContext(ctx, query, []driver.NamedValue{}) if err != nil { return nil, err } var rowAffected int64 = 0 r := rows.(*Rows) if r != nil && r.ResultOutput != nil && r.ResultOutput.UpdateCount != nil { rowAffected = *r.ResultOutput.UpdateCount } var lastInsertedID int64 = -1 result := AthenaResult{ lastInsertedID: lastInsertedID, rowAffected: rowAffected, } return result, nil } func (c *Connection) cachedQuery(ctx context.Context, QID string) (driver.Rows, error) { if c.connector.config.IsMoneyWise() { dataScanned := int64(0) printCost(&athena.GetQueryExecutionOutput{ QueryExecution: &athena.QueryExecution{ QueryExecutionId: &QID, Statistics: &athena.QueryExecutionStatistics{ DataScannedInBytes: &dataScanned, }, }, }) } wg := c.connector.config.GetWorkgroup() if wg.Name == "" { wg.Name = DefaultWGName } return NewRows(ctx, c.athenaAPI, QID, c.connector.config, c.connector.tracer) } func (c *Connection) getHeaderlessSingleRowResultPage(ctx context.Context, qid string) (driver.Rows, error) { r, err := NewNonOpsRows(ctx, c.athenaAPI, qid, c.connector.config, c.connector.tracer) colName := "_col0" columnNames := []*string{&colName} columnTypes := []string{"string"} data := make([][]*string, 1) data[0] = []*string{&qid} r.ResultOutput = newHeaderlessResultPage(columnNames, columnTypes, data) return r, err } // QueryContext is implemented to be called by `DB.Query` (QueryerContext interface). // // "QueryerContext is an optional interface that may be implemented by a Conn. // If a Conn does not implement QueryerContext, the sql package's DB.Query // will fall back to Queryer; if the Conn does not implement Queryer either, // DB.Query will first prepare a query, execute the statement, and then // close the statement." // // With QueryContext implemented, we don't need Queryer. // QueryerContext must honor the context timeout and return when the context is canceled. func (c *Connection) QueryContext(ctx context.Context, query string, namedArgs []driver.NamedValue) (driver.Rows, error) { var obs = c.connector.tracer var pseudoCommand = "" if strings.HasPrefix(query, "pc:") { query = strings.Trim(query[3:], " ") if pseudoCommand = PCGetQID; strings.HasPrefix(query, pseudoCommand+" ") { query = strings.Trim(query[len(pseudoCommand):], " ") } else if pseudoCommand = PCGetQIDStatus; strings.HasPrefix(query, pseudoCommand+" ") { query = strings.Trim(query[len(pseudoCommand):], " ") } else if pseudoCommand = PCStopQID; strings.HasPrefix(query, pseudoCommand+" ") { query = strings.Trim(query[len(pseudoCommand):], " ") } else if pseudoCommand = PCGetDriverVersion; strings.HasPrefix(query, pseudoCommand) { return c.getHeaderlessSingleRowResultPage(ctx, DriverVersion) } else { return nil, fmt.Errorf("pseudo command " + query + "doesn't exist") } } if c.connector.config.IsReadOnly() { if !isReadOnlyStatement(query) { obs.Scope().Counter(DriverName + ".failure.querycontext.writeviolation").Inc(1) obs.Log(WarnLevel, "write db violation", zap.String("query", query)) return nil, fmt.Errorf("writing to Athena database is disallowed in read-only mode") } } now := time.Now() args := namedValueToValue(namedArgs) queryWithPlaceholders := query // For parameterized queries var err error if len(namedArgs) > 0 { query, err = c.interpolateParams(query, args) if err != nil { return nil, err } obs.Scope().Counter(DriverName + ".prepared.querycontext").Inc(1) } if !isQueryValid(query) { return nil, ErrInvalidQuery } wg := c.connector.config.GetWorkgroup() if wg.Name == "" { wg.Name = DefaultWGName } else if wg.Name != DefaultWGName { athenaWG, err := getWG(ctx, c.athenaAPI, wg.Name) if err != nil { obs.Scope().Counter(DriverName + ".failure.querycontext.getwg").Inc(1) obs.Log(WarnLevel, "Didn't find workgroup "+wg.Name+" due to: "+err.Error()) if reqerr, ok := err.(awserr.RequestFailure); !ok || reqerr.Message() != "WorkGroup is not found." { return nil, err } if c.connector.config.IsWGRemoteCreationAllowed() { err = wg.CreateWGRemotely(c.athenaAPI) if err != nil { obs.Scope().Counter(DriverName + ".failure.querycontext.createwgremotely").Inc(1) return nil, err } obs.Log(DebugLevel, "workgroup "+wg.Name+" is created successfully.") } else { obs.Log(WarnLevel, "workgroup "+DefaultWGName+" is used for "+wg.Name+".") return nil, fmt.Errorf("workgroup %q doesn't exist and workgroup remote creation is disabled, due to: %v", wg.Name, err.Error()) } } else { if *athenaWG.State != athena.WorkGroupStateEnabled { obs.Log(WarnLevel, "workgroup "+DefaultWGName+" is disabled.") obs.Scope().Counter(DriverName + ".failure.querycontext.wgdisabled").Inc(1) return nil, fmt.Errorf("workgroup %q is disabled", wg.Name) } obs.Log(DebugLevel, "workgroup "+DefaultWGName+" is enabled.") } } timeWorkgroup := time.Since(now) startOfStartQueryExecution := time.Now() obs.Scope().Timer(DriverName + ".query.workgroup").Record(timeWorkgroup) // case 1 - query directly using QID if IsQID(query) { if pseudoCommand == PCGetQIDStatus { statusResp, err := c.athenaAPI.GetQueryExecutionWithContext(ctx, &athena.GetQueryExecutionInput{ QueryExecutionId: aws.String(query), }) if err != nil { obs.Log(ErrorLevel, "GetQueryExecutionWithContext failed", zap.String("workgroup", wg.Name), zap.String("queryID", query), zap.String("error", err.Error())) obs.Scope().Counter(DriverName + ".failure.querycontext.getqueryexecutionwithcontext").Inc(1) return nil, err } return c.getHeaderlessSingleRowResultPage(ctx, *statusResp.QueryExecution.Status.State) } if pseudoCommand == PCStopQID { _, err := c.athenaAPI.StopQueryExecutionWithContext(context.Background(), &athena.StopQueryExecutionInput{ QueryExecutionId: aws.String(query), }) if err != nil { obs.Log(ErrorLevel, "StopQueryExecution failed", zap.String("workgroup", wg.Name), zap.String("queryID", query), zap.String("query", query)) obs.Scope().Counter(DriverName + ".failure.querycontext.stopqueryexecution.failed").Inc(1) return nil, err } return c.getHeaderlessSingleRowResultPage(ctx, "OK") } return c.cachedQuery(ctx, query) } // case 2 - TODO executionParams, err := c.buildExecutionParams(args) if err != nil { return nil, err } resp, err := c.athenaAPI.StartQueryExecution(&athena.StartQueryExecutionInput{ QueryString: aws.String(queryWithPlaceholders), ExecutionParameters: executionParams, QueryExecutionContext: &athena.QueryExecutionContext{ Database: aws.String(c.connector.config.GetDB()), }, ResultConfiguration: &athena.ResultConfiguration{ OutputLocation: aws.String(c.connector.config.GetOutputBucket()), }, WorkGroup: aws.String(wg.Name), }) if err != nil { if pseudoCommand == PCGetQID { if reqerr, ok := err.(awserr.RequestFailure); ok { return c.getHeaderlessSingleRowResultPage(ctx, reqerr.RequestID()) } } return nil, err } timeStartQueryExecution := time.Since(startOfStartQueryExecution) now = time.Now() obs.Scope().Timer(DriverName + ".query.startqueryexecution").Record(timeStartQueryExecution) queryID := *resp.QueryExecutionId if pseudoCommand == PCGetQID { return c.getHeaderlessSingleRowResultPage(ctx, queryID) } WAITING_FOR_RESULT: for { pollInterval := c.connector.config.GetResultPollIntervalSeconds() statusResp, err := c.athenaAPI.GetQueryExecutionWithContext(ctx, &athena.GetQueryExecutionInput{ QueryExecutionId: aws.String(queryID), }) if err != nil { obs.Log(ErrorLevel, "GetQueryExecutionWithContext failed", zap.String("workgroup", wg.Name), zap.String("queryID", queryID), zap.String("error", err.Error())) obs.Scope().Counter(DriverName + ".failure.querycontext.getqueryexecutionwithcontext").Inc(1) return nil, err } //statementType = statusResp.QueryExecution.StatementType switch *statusResp.QueryExecution.Status.State { case athena.QueryExecutionStateCancelled: timeCanceled := time.Since(now) obs.Log(ErrorLevel, "QueryExecutionStateCancelled", zap.String("workgroup", wg.Name), zap.String("queryID", queryID)) obs.Scope().Timer(DriverName + ".query.canceled").Record(timeCanceled) if c.connector.config.IsMoneyWise() { printCost(statusResp) } return nil, context.Canceled case athena.QueryExecutionStateFailed: reason := *statusResp.QueryExecution.Status.StateChangeReason timeQueryExecutionStateFailed := time.Since(now) obs.Log(ErrorLevel, "QueryExecutionStateFailed", zap.String("workgroup", wg.Name), zap.String("queryID", queryID), zap.String("reason", reason)) obs.Scope().Timer(DriverName + ".query.queryexecutionstatefailed").Record(timeQueryExecutionStateFailed) return nil, errors.New(reason) case athena.QueryExecutionStateSucceeded: if c.connector.config.IsMoneyWise() { printCost(statusResp) } timeQueryExecutionStateSucceeded := time.Since(now) obs.Scope().Timer(DriverName + ".query.queryexecutionstatesucceeded").Record(timeQueryExecutionStateSucceeded) break WAITING_FOR_RESULT // for athena.QueryExecutionStateQueued and athena.QueryExecutionStateRunning default: } select { case <-ctx.Done(): _, err := c.athenaAPI. StopQueryExecutionWithContext(context.Background(), &athena.StopQueryExecutionInput{ QueryExecutionId: aws.String(queryID), }) if err != nil { obs.Log(ErrorLevel, "StopQueryExecution failed", zap.String("workgroup", wg.Name), zap.String("queryID", queryID), zap.String("query", query)) obs.Scope().Counter(DriverName + ".failure.querycontext.stopqueryexecution.failed").Inc(1) return nil, err } if c.connector.config.IsMoneyWise() { statusRespFinal, _ := c.athenaAPI.GetQueryExecutionWithContext(context.Background(), &athena.GetQueryExecutionInput{ QueryExecutionId: aws.String(queryID), }) printCost(statusRespFinal) } obs.Scope().Counter(DriverName + ".failure.querycontext.stopqueryexecution.succeeded").Inc(1) timeStopQueryExecution := time.Since(now) obs.Scope().Timer(DriverName + ".query.StopQueryExecution").Record(timeStopQueryExecution) obs.Log(ErrorLevel, "query canceled", zap.String("queryID", queryID)) return nil, ctx.Err() case <-time.After(pollInterval): if isQueryTimeOut(startOfStartQueryExecution, *statusResp.QueryExecution.StatementType, c.connector.config.GetServiceLimitOverride()) { obs.Log(ErrorLevel, "Query timeout failure", zap.String("workgroup", wg.Name), zap.String("queryID", queryID), zap.String("query", query)) obs.Scope().Counter(DriverName + ".failure.querycontext.timeout").Inc(1) return nil, ErrQueryTimeout } continue } } return NewRows(ctx, c.athenaAPI, queryID, c.connector.config, obs) } // Ping implements driver.Pinger interface. // Ping is a good first step in a health check: If the Ping succeeds, // make a simple query, then make a complex query which depends on proper // DB scheme. This will make troubleshooting simpler as the error now is: // "We've got network connectivity, we can Ping the DB, so we have valid // credentials for a SELECT xxx; but ...". func (c *Connection) Ping(ctx context.Context) error { rows, err := c.QueryContext(ctx, "SELECT 1", nil) if err != nil { return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger } defer rows.Close() return nil } // Prepare is inherited from Conn interface. func (c *Connection) Prepare(query string) (driver.Stmt, error) { if !isQueryValid(query) { return nil, ErrInvalidQuery } stmt := &Statement{ connection: c, query: query, closed: false, numInput: strings.Count(query, "?"), } return stmt, nil } // Begin is from Conn interface, but no implementation for AWS Athena. func (c *Connection) Begin() (driver.Tx, error) { return nil, ErrAthenaTransactionUnsupported } // BeginTx is to replace Begin as it is deprecated. func (c *Connection) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { return nil, ErrAthenaTransactionUnsupported } // Close is from Conn interface, but no implementation for AWS Athena. // Because the sql package maintains a free pool of // connections and only calls Close when there's a surplus of // idle connections, it shouldn't be necessary for drivers to // do their own connection caching. func (c *Connection) Close() error { c.connector = nil c.athenaAPI = nil c.numInput = -1 return nil } var _ driver.QueryerContext = (*Connection)(nil) var _ driver.ExecerContext = (*Connection)(nil)