in go/connection.go [299:525]
// QueryContext runs query against Athena and returns its result rows. Its
// signature matches database/sql/driver.QueryerContext.
//
// The query is processed in these stages:
//
//  1. Pseudo commands. A query starting with "pc:" is a driver pseudo command
//     (PCGetQID, PCGetQIDStatus, PCStopQID, PCGetDriverVersion). An unknown
//     pseudo command is rejected with an error.
//  2. Read-only guard. When the connection config is read-only, statements for
//     which isReadOnlyStatement reports false are rejected.
//  3. Parameter interpolation and validation (interpolateParams, isQueryValid).
//  4. Workgroup resolution. A non-default workgroup is looked up; when it is
//     missing it is created remotely if the config allows it, otherwise an
//     error is returned. A workgroup that is not enabled is rejected.
//  5. If the remaining query text is a query ID (IsQID), the status/stop pseudo
//     commands are served, otherwise the call is delegated to cachedQuery.
//  6. Otherwise a new query execution is started and polled until it succeeds,
//     fails, is cancelled, ctx is done, or isQueryTimeOut reports a timeout.
//
// On ctx cancellation the running query is stopped and ctx.Err() is returned.
func (c *Connection) QueryContext(ctx context.Context, query string, namedArgs []driver.NamedValue) (driver.Rows, error) {
	obs := c.connector.tracer

	// Stage 1: strip the "pc:" prefix and identify the pseudo command.
	pseudoCommand := ""
	if strings.HasPrefix(query, "pc:") {
		query = strings.Trim(query[3:], " ")
		switch {
		// The trailing space keeps PCGetQID from matching a longer command
		// that merely starts with the same text.
		case strings.HasPrefix(query, PCGetQID+" "):
			pseudoCommand = PCGetQID
		case strings.HasPrefix(query, PCGetQIDStatus+" "):
			pseudoCommand = PCGetQIDStatus
		case strings.HasPrefix(query, PCStopQID+" "):
			pseudoCommand = PCStopQID
		case strings.HasPrefix(query, PCGetDriverVersion):
			return c.getHeaderlessSingleRowResultPage(ctx, DriverVersion)
		default:
			// Constant format string: the user-supplied text must never be
			// interpreted as formatting verbs.
			return nil, fmt.Errorf("pseudo command %q doesn't exist", query)
		}
		// What follows the command name is its argument (SQL text or a query ID).
		query = strings.Trim(query[len(pseudoCommand):], " ")
	}

	// Stage 2: refuse non-read-only statements on a read-only connection.
	if c.connector.config.IsReadOnly() && !isReadOnlyStatement(query) {
		obs.Scope().Counter(DriverName + ".failure.querycontext.writeviolation").Inc(1)
		obs.Log(WarnLevel, "write db violation", zap.String("query", query))
		return nil, errors.New("writing to Athena database is disallowed in read-only mode")
	}

	// Stage 3: interpolate parameters and validate the resulting query.
	prepStart := time.Now()
	args := namedValueToValue(namedArgs)
	queryWithPlaceholders := query // sent to Athena together with ExecutionParameters
	if len(namedArgs) > 0 {
		var err error
		query, err = c.interpolateParams(query, args)
		if err != nil {
			return nil, err
		}
		obs.Scope().Counter(DriverName + ".prepared.querycontext").Inc(1)
	}
	if !isQueryValid(query) {
		return nil, ErrInvalidQuery
	}

	// Stage 4: resolve the workgroup. The default workgroup is used without a lookup.
	wg := c.connector.config.GetWorkgroup()
	if wg.Name == "" {
		wg.Name = DefaultWGName
	} else if wg.Name != DefaultWGName {
		athenaWG, wgErr := getWG(ctx, c.athenaAPI, wg.Name)
		if wgErr != nil {
			obs.Scope().Counter(DriverName + ".failure.querycontext.getwg").Inc(1)
			obs.Log(WarnLevel, "Didn't find workgroup "+wg.Name+" due to: "+wgErr.Error())
			// Only a "not found" failure can be repaired by creating the workgroup.
			if reqerr, ok := wgErr.(awserr.RequestFailure); !ok || reqerr.Message() != "WorkGroup is not found." {
				return nil, wgErr
			}
			if !c.connector.config.IsWGRemoteCreationAllowed() {
				obs.Log(WarnLevel, "workgroup "+wg.Name+" doesn't exist and remote creation is disabled.")
				return nil,
					fmt.Errorf("workgroup %q doesn't exist and workgroup remote creation is disabled, due to: %v", wg.Name, wgErr)
			}
			if createErr := wg.CreateWGRemotely(c.athenaAPI); createErr != nil {
				obs.Scope().Counter(DriverName + ".failure.querycontext.createwgremotely").Inc(1)
				return nil, createErr
			}
			obs.Log(DebugLevel, "workgroup "+wg.Name+" is created successfully.")
		} else {
			// aws.StringValue yields "" for a nil State, which is treated as not enabled.
			if aws.StringValue(athenaWG.State) != athena.WorkGroupStateEnabled {
				obs.Log(WarnLevel, "workgroup "+wg.Name+" is disabled.")
				obs.Scope().Counter(DriverName + ".failure.querycontext.wgdisabled").Inc(1)
				return nil, fmt.Errorf("workgroup %q is disabled", wg.Name)
			}
			obs.Log(DebugLevel, "workgroup "+wg.Name+" is enabled.")
		}
	}
	timeWorkgroup := time.Since(prepStart)
	startOfStartQueryExecution := time.Now()
	obs.Scope().Timer(DriverName + ".query.workgroup").Record(timeWorkgroup)

	// Stage 5 (case 1): the query text is itself a query ID.
	if IsQID(query) {
		switch pseudoCommand {
		case PCGetQIDStatus:
			statusResp, err := c.athenaAPI.GetQueryExecutionWithContext(ctx, &athena.GetQueryExecutionInput{
				QueryExecutionId: aws.String(query),
			})
			if err != nil {
				obs.Log(ErrorLevel, "GetQueryExecutionWithContext failed",
					zap.String("workgroup", wg.Name),
					zap.String("queryID", query),
					zap.String("error", err.Error()))
				obs.Scope().Counter(DriverName + ".failure.querycontext.getqueryexecutionwithcontext").Inc(1)
				return nil, err
			}
			return c.getHeaderlessSingleRowResultPage(ctx, aws.StringValue(statusResp.QueryExecution.Status.State))
		case PCStopQID:
			// NOTE(review): this request uses context.Background(), so the
			// caller's ctx cannot cancel it — confirm that is intended.
			_, err := c.athenaAPI.StopQueryExecutionWithContext(context.Background(), &athena.StopQueryExecutionInput{
				QueryExecutionId: aws.String(query),
			})
			if err != nil {
				obs.Log(ErrorLevel, "StopQueryExecution failed",
					zap.String("workgroup", wg.Name),
					zap.String("queryID", query),
					zap.String("error", err.Error()))
				obs.Scope().Counter(DriverName + ".failure.querycontext.stopqueryexecution.failed").Inc(1)
				return nil, err
			}
			return c.getHeaderlessSingleRowResultPage(ctx, "OK")
		}
		return c.cachedQuery(ctx, query)
	}

	// Stage 6 (case 2): start a new query execution.
	executionParams, err := c.buildExecutionParams(args)
	if err != nil {
		return nil, err
	}
	// NOTE(review): StartQueryExecution is called without ctx, so cancellation
	// is only observed once polling begins — confirm whether the WithContext
	// variant can be used here.
	resp, err := c.athenaAPI.StartQueryExecution(&athena.StartQueryExecutionInput{
		QueryString:         aws.String(queryWithPlaceholders),
		ExecutionParameters: executionParams,
		QueryExecutionContext: &athena.QueryExecutionContext{
			Database: aws.String(c.connector.config.GetDB()),
		},
		ResultConfiguration: &athena.ResultConfiguration{
			OutputLocation: aws.String(c.connector.config.GetOutputBucket()),
		},
		WorkGroup: aws.String(wg.Name),
	})
	if err != nil {
		// For PCGetQID a failed start is reported as a single row holding the
		// AWS request ID rather than as an error.
		if pseudoCommand == PCGetQID {
			if reqerr, ok := err.(awserr.RequestFailure); ok {
				return c.getHeaderlessSingleRowResultPage(ctx, reqerr.RequestID())
			}
		}
		return nil, err
	}
	timeStartQueryExecution := time.Since(startOfStartQueryExecution)
	pollStart := time.Now()
	obs.Scope().Timer(DriverName + ".query.startqueryexecution").Record(timeStartQueryExecution)
	queryID := aws.StringValue(resp.QueryExecutionId)
	if pseudoCommand == PCGetQID {
		// The caller only wants the query ID; do not wait for the result.
		return c.getHeaderlessSingleRowResultPage(ctx, queryID)
	}

	// Poll the execution state until it reaches a terminal state.
pollLoop:
	for {
		pollInterval := c.connector.config.GetResultPollIntervalSeconds()
		statusResp, err := c.athenaAPI.GetQueryExecutionWithContext(ctx, &athena.GetQueryExecutionInput{
			QueryExecutionId: aws.String(queryID),
		})
		if err != nil {
			obs.Log(ErrorLevel, "GetQueryExecutionWithContext failed",
				zap.String("workgroup", wg.Name),
				zap.String("queryID", queryID),
				zap.String("error", err.Error()))
			obs.Scope().Counter(DriverName + ".failure.querycontext.getqueryexecutionwithcontext").Inc(1)
			return nil, err
		}
		execution := statusResp.QueryExecution
		switch aws.StringValue(execution.Status.State) {
		case athena.QueryExecutionStateCancelled:
			timeCanceled := time.Since(pollStart)
			obs.Log(ErrorLevel, "QueryExecutionStateCancelled",
				zap.String("workgroup", wg.Name),
				zap.String("queryID", queryID))
			obs.Scope().Timer(DriverName + ".query.canceled").Record(timeCanceled)
			if c.connector.config.IsMoneyWise() {
				printCost(statusResp)
			}
			return nil, context.Canceled
		case athena.QueryExecutionStateFailed:
			// StateChangeReason is a pointer and may be absent; never dereference it blindly.
			reason := aws.StringValue(execution.Status.StateChangeReason)
			if reason == "" {
				reason = "query " + queryID + " failed without a state change reason"
			}
			timeQueryExecutionStateFailed := time.Since(pollStart)
			obs.Log(ErrorLevel, "QueryExecutionStateFailed",
				zap.String("workgroup", wg.Name),
				zap.String("queryID", queryID),
				zap.String("reason", reason))
			obs.Scope().Timer(DriverName + ".query.queryexecutionstatefailed").Record(timeQueryExecutionStateFailed)
			return nil, errors.New(reason)
		case athena.QueryExecutionStateSucceeded:
			if c.connector.config.IsMoneyWise() {
				printCost(statusResp)
			}
			timeQueryExecutionStateSucceeded := time.Since(pollStart)
			obs.Scope().Timer(DriverName + ".query.queryexecutionstatesucceeded").Record(timeQueryExecutionStateSucceeded)
			break pollLoop
		default:
			// Any other state (e.g. queued or running): keep waiting.
		}

		select {
		case <-ctx.Done():
			// ctx is already done, so the stop request needs a fresh context.
			_, err := c.athenaAPI.
				StopQueryExecutionWithContext(context.Background(), &athena.StopQueryExecutionInput{
					QueryExecutionId: aws.String(queryID),
				})
			if err != nil {
				obs.Log(ErrorLevel, "StopQueryExecution failed",
					zap.String("workgroup", wg.Name),
					zap.String("queryID", queryID),
					zap.String("query", query),
					zap.String("error", err.Error()))
				obs.Scope().Counter(DriverName + ".failure.querycontext.stopqueryexecution.failed").Inc(1)
				return nil, err
			}
			if c.connector.config.IsMoneyWise() {
				// Best effort: the lookup error is ignored so that cost
				// reporting cannot mask the cancellation.
				// NOTE(review): printCost receives nil when the lookup fails —
				// confirm it tolerates that.
				statusRespFinal, _ := c.athenaAPI.GetQueryExecutionWithContext(context.Background(), &athena.GetQueryExecutionInput{
					QueryExecutionId: aws.String(queryID),
				})
				printCost(statusRespFinal)
			}
			obs.Scope().Counter(DriverName + ".failure.querycontext.stopqueryexecution.succeeded").Inc(1)
			timeStopQueryExecution := time.Since(pollStart)
			obs.Scope().Timer(DriverName + ".query.StopQueryExecution").Record(timeStopQueryExecution)
			obs.Log(ErrorLevel, "query canceled", zap.String("queryID", queryID))
			return nil, ctx.Err()
		case <-time.After(pollInterval):
			// StatementType is a pointer and may be absent; "" is passed in that case.
			if isQueryTimeOut(startOfStartQueryExecution, aws.StringValue(execution.StatementType), c.connector.config.GetServiceLimitOverride()) {
				obs.Log(ErrorLevel, "Query timeout failure",
					zap.String("workgroup", wg.Name),
					zap.String("queryID", queryID),
					zap.String("query", query))
				obs.Scope().Counter(DriverName + ".failure.querycontext.timeout").Inc(1)
				return nil, ErrQueryTimeout
			}
		}
	}
	return NewRows(ctx, c.athenaAPI, queryID, c.connector.config, obs)
}