func()

in go/connection.go [299:525]


// QueryContext runs query on Athena and returns its result rows. Its
// signature matches database/sql/driver.QueryerContext.
//
// A query starting with "pc:" is a pseudo command rather than SQL. After the
// "pc:" prefix and surrounding spaces are removed, the text must start with:
//   - PCGetQID, a space and a statement: the statement is started and a
//     single-row page holding its query execution ID is returned without
//     waiting for the query to finish.
//   - PCGetQIDStatus, a space and a query ID: a single-row page holding the
//     current state of that query execution is returned.
//   - PCStopQID, a space and a query ID: that query execution is stopped and a
//     single-row page holding "OK" is returned.
//   - PCGetDriverVersion: a single-row page holding DriverVersion is returned.
//
// Any other pseudo command is an error. For PCGetQIDStatus and PCStopQID the
// remainder is only treated as a query ID when IsQID accepts it. Otherwise it
// is started like an ordinary query.
//
// A plain query that IsQID accepts is handed to cachedQuery instead of being
// started.
//
// In read-only mode (config IsReadOnly), statements rejected by
// isReadOnlyStatement fail before anything is sent to Athena.
//
// When namedArgs is non-empty, the values are interpolated into a local copy
// of query. That copy is used for validation, QID detection and logging. The
// statement sent to Athena is the original placeholder form, with
// ExecutionParameters built from the same values.
//
// The configured workgroup is resolved first. An empty name means
// DefaultWGName. Any other non-default workgroup must exist (or be created,
// when remote creation is allowed) and must be enabled.
//
// Once started, the query is polled every GetResultPollIntervalSeconds until
// one of the following happens:
//   - It succeeds, and rows are returned via NewRows.
//   - It fails, and the error carries Athena's StateChangeReason.
//   - Athena cancels it, and context.Canceled is returned.
//   - ctx is done, the query is stopped and ctx.Err() is returned.
//   - isQueryTimeOut reports a timeout, and ErrQueryTimeout is returned.
func (c *Connection) QueryContext(ctx context.Context, query string, namedArgs []driver.NamedValue) (driver.Rows, error) {
	obs := c.connector.tracer
	pseudoCommand := ""
	if strings.HasPrefix(query, "pc:") {
		query = strings.Trim(query[3:], " ")
		// Each condition assigns pseudoCommand before testing it, so once a
		// branch matches, pseudoCommand holds that command.
		if pseudoCommand = PCGetQID; strings.HasPrefix(query, pseudoCommand+" ") {
			query = strings.Trim(query[len(pseudoCommand):], " ")
		} else if pseudoCommand = PCGetQIDStatus; strings.HasPrefix(query, pseudoCommand+" ") {
			query = strings.Trim(query[len(pseudoCommand):], " ")
		} else if pseudoCommand = PCStopQID; strings.HasPrefix(query, pseudoCommand+" ") {
			query = strings.Trim(query[len(pseudoCommand):], " ")
		} else if pseudoCommand = PCGetDriverVersion; strings.HasPrefix(query, pseudoCommand) {
			return c.getHeaderlessSingleRowResultPage(ctx, DriverVersion)
		} else {
			return nil, fmt.Errorf("pseudo command %q doesn't exist", query)
		}
	}
	if c.connector.config.IsReadOnly() && !isReadOnlyStatement(query) {
		obs.Scope().Counter(DriverName + ".failure.querycontext.writeviolation").Inc(1)
		obs.Log(WarnLevel, "write db violation", zap.String("query", query))
		return nil, errors.New("writing to Athena database is disallowed in read-only mode")
	}
	now := time.Now()
	args := namedValueToValue(namedArgs)
	// Athena receives the placeholder form together with ExecutionParameters.
	// The interpolated query below is only used locally.
	queryWithPlaceholders := query
	var err error
	if len(namedArgs) > 0 {
		query, err = c.interpolateParams(query, args)
		if err != nil {
			return nil, err
		}
		obs.Scope().Counter(DriverName + ".prepared.querycontext").Inc(1)
	}
	if !isQueryValid(query) {
		return nil, ErrInvalidQuery
	}

	wg := c.connector.config.GetWorkgroup()
	if wg.Name == "" {
		wg.Name = DefaultWGName
	} else if wg.Name != DefaultWGName {
		athenaWG, err := getWG(ctx, c.athenaAPI, wg.Name)
		if err != nil {
			obs.Scope().Counter(DriverName + ".failure.querycontext.getwg").Inc(1)
			obs.Log(WarnLevel, "Didn't find workgroup "+wg.Name+" due to: "+err.Error())
			// Only a "not found" response can be recovered from (by creating the
			// workgroup). Any other lookup failure is returned unchanged.
			if reqerr, ok := err.(awserr.RequestFailure); !ok || reqerr.Message() != "WorkGroup is not found." {
				return nil, err
			}
			if !c.connector.config.IsWGRemoteCreationAllowed() {
				obs.Log(WarnLevel, "workgroup "+wg.Name+" doesn't exist and workgroup remote creation is disabled.")
				return nil,
					fmt.Errorf("workgroup %q doesn't exist and workgroup remote creation is disabled, due to: %w", wg.Name, err)
			}
			if err = wg.CreateWGRemotely(c.athenaAPI); err != nil {
				obs.Scope().Counter(DriverName + ".failure.querycontext.createwgremotely").Inc(1)
				return nil, err
			}
			obs.Log(DebugLevel, "workgroup "+wg.Name+" is created successfully.")
		} else {
			if aws.StringValue(athenaWG.State) != athena.WorkGroupStateEnabled {
				obs.Log(WarnLevel, "workgroup "+wg.Name+" is disabled.")
				obs.Scope().Counter(DriverName + ".failure.querycontext.wgdisabled").Inc(1)
				return nil, fmt.Errorf("workgroup %q is disabled", wg.Name)
			}
			obs.Log(DebugLevel, "workgroup "+wg.Name+" is enabled.")
		}
	}

	timeWorkgroup := time.Since(now)
	startOfStartQueryExecution := time.Now()
	obs.Scope().Timer(DriverName + ".query.workgroup").Record(timeWorkgroup)

	// Case 1: query is an existing query execution ID.
	if IsQID(query) {
		switch pseudoCommand {
		case PCGetQIDStatus:
			statusResp, err := c.athenaAPI.GetQueryExecutionWithContext(ctx, &athena.GetQueryExecutionInput{
				QueryExecutionId: aws.String(query),
			})
			if err != nil {
				obs.Log(ErrorLevel, "GetQueryExecutionWithContext failed",
					zap.String("workgroup", wg.Name),
					zap.String("queryID", query),
					zap.String("error", err.Error()))
				obs.Scope().Counter(DriverName + ".failure.querycontext.getqueryexecutionwithcontext").Inc(1)
				return nil, err
			}
			return c.getHeaderlessSingleRowResultPage(ctx, aws.StringValue(statusResp.QueryExecution.Status.State))
		case PCStopQID:
			// The caller's ctx is still live here, so the stop request honours it
			// (the same as the status lookup above).
			_, err := c.athenaAPI.StopQueryExecutionWithContext(ctx, &athena.StopQueryExecutionInput{
				QueryExecutionId: aws.String(query),
			})
			if err != nil {
				obs.Log(ErrorLevel, "StopQueryExecution failed",
					zap.String("workgroup", wg.Name),
					zap.String("queryID", query),
					zap.String("error", err.Error()))
				obs.Scope().Counter(DriverName + ".failure.querycontext.stopqueryexecution.failed").Inc(1)
				return nil, err
			}
			return c.getHeaderlessSingleRowResultPage(ctx, "OK")
		}
		// Any other QID (including PCGetQID followed by a QID) goes to
		// cachedQuery. Presumably this reads the stored results of that
		// execution — confirm.
		return c.cachedQuery(ctx, query)
	}

	// Case 2: start a new query execution and wait for it to finish.
	executionParams, err := c.buildExecutionParams(args)
	if err != nil {
		return nil, err
	}
	// NOTE(review): StartQueryExecution does not take ctx, so the submission
	// itself cannot be cancelled. StartQueryExecutionWithContext would fix that,
	// provided every AthenaAPI implementation in use (including test doubles)
	// supports it.
	resp, err := c.athenaAPI.StartQueryExecution(&athena.StartQueryExecutionInput{
		QueryString:         aws.String(queryWithPlaceholders),
		ExecutionParameters: executionParams,
		QueryExecutionContext: &athena.QueryExecutionContext{
			Database: aws.String(c.connector.config.GetDB()),
		},
		ResultConfiguration: &athena.ResultConfiguration{
			OutputLocation: aws.String(c.connector.config.GetOutputBucket()),
		},
		WorkGroup: aws.String(wg.Name),
	})
	if err != nil {
		if pseudoCommand == PCGetQID {
			// NOTE(review): for PCGetQID, a request failure yields the AWS
			// request ID as the result row and the error is dropped. Confirm
			// this is intended.
			if reqerr, ok := err.(awserr.RequestFailure); ok {
				return c.getHeaderlessSingleRowResultPage(ctx, reqerr.RequestID())
			}
		}
		return nil, err
	}

	timeStartQueryExecution := time.Since(startOfStartQueryExecution)
	now = time.Now()
	obs.Scope().Timer(DriverName + ".query.startqueryexecution").Record(timeStartQueryExecution)

	queryID := aws.StringValue(resp.QueryExecutionId)
	if pseudoCommand == PCGetQID {
		return c.getHeaderlessSingleRowResultPage(ctx, queryID)
	}

	pollInterval := c.connector.config.GetResultPollIntervalSeconds()
pollLoop:
	for {
		statusResp, err := c.athenaAPI.GetQueryExecutionWithContext(ctx, &athena.GetQueryExecutionInput{
			QueryExecutionId: aws.String(queryID),
		})
		if err != nil {
			obs.Log(ErrorLevel, "GetQueryExecutionWithContext failed",
				zap.String("workgroup", wg.Name),
				zap.String("queryID", queryID),
				zap.String("error", err.Error()))
			obs.Scope().Counter(DriverName + ".failure.querycontext.getqueryexecutionwithcontext").Inc(1)
			return nil, err
		}
		switch aws.StringValue(statusResp.QueryExecution.Status.State) {
		case athena.QueryExecutionStateCancelled:
			timeCanceled := time.Since(now)
			obs.Log(ErrorLevel, "QueryExecutionStateCancelled",
				zap.String("workgroup", wg.Name),
				zap.String("queryID", queryID))
			obs.Scope().Timer(DriverName + ".query.canceled").Record(timeCanceled)
			if c.connector.config.IsMoneyWise() {
				printCost(statusResp)
			}
			return nil, context.Canceled
		case athena.QueryExecutionStateFailed:
			// StateChangeReason is optional in the API response. Never
			// dereference it blindly, and never return an empty error message.
			reason := aws.StringValue(statusResp.QueryExecution.Status.StateChangeReason)
			if reason == "" {
				reason = "query " + queryID + " failed"
			}
			timeQueryExecutionStateFailed := time.Since(now)
			obs.Log(ErrorLevel, "QueryExecutionStateFailed",
				zap.String("workgroup", wg.Name),
				zap.String("queryID", queryID),
				zap.String("reason", reason))
			obs.Scope().Timer(DriverName + ".query.queryexecutionstatefailed").Record(timeQueryExecutionStateFailed)
			return nil, errors.New(reason)
		case athena.QueryExecutionStateSucceeded:
			if c.connector.config.IsMoneyWise() {
				printCost(statusResp)
			}
			timeQueryExecutionStateSucceeded := time.Since(now)
			obs.Scope().Timer(DriverName + ".query.queryexecutionstatesucceeded").Record(timeQueryExecutionStateSucceeded)
			break pollLoop
		default:
			// Queued or running (or no state reported yet): keep polling.
		}

		select {
		case <-ctx.Done():
			// ctx is already done, so the stop request (and the final cost
			// lookup) are issued on a fresh background context to make sure
			// they are actually sent.
			if _, err := c.athenaAPI.StopQueryExecutionWithContext(context.Background(), &athena.StopQueryExecutionInput{
				QueryExecutionId: aws.String(queryID),
			}); err != nil {
				obs.Log(ErrorLevel, "StopQueryExecution failed",
					zap.String("workgroup", wg.Name),
					zap.String("queryID", queryID),
					zap.String("query", query),
					zap.String("error", err.Error()))
				obs.Scope().Counter(DriverName + ".failure.querycontext.stopqueryexecution.failed").Inc(1)
				return nil, err
			}
			if c.connector.config.IsMoneyWise() {
				// Best effort: the cost report is skipped if the final lookup fails,
				// rather than handing printCost a nil response.
				statusRespFinal, err := c.athenaAPI.GetQueryExecutionWithContext(context.Background(), &athena.GetQueryExecutionInput{
					QueryExecutionId: aws.String(queryID),
				})
				if err == nil {
					printCost(statusRespFinal)
				}
			}
			obs.Scope().Counter(DriverName + ".failure.querycontext.stopqueryexecution.succeeded").Inc(1)
			timeStopQueryExecution := time.Since(now)
			obs.Scope().Timer(DriverName + ".query.StopQueryExecution").Record(timeStopQueryExecution)
			obs.Log(ErrorLevel, "query canceled", zap.String("queryID", queryID))
			return nil, ctx.Err()
		case <-time.After(pollInterval):
			if isQueryTimeOut(startOfStartQueryExecution, aws.StringValue(statusResp.QueryExecution.StatementType), c.connector.config.GetServiceLimitOverride()) {
				obs.Log(ErrorLevel, "Query timeout failure",
					zap.String("workgroup", wg.Name),
					zap.String("queryID", queryID),
					zap.String("query", query))
				obs.Scope().Counter(DriverName + ".failure.querycontext.timeout").Inc(1)
				return nil, ErrQueryTimeout
			}
		}
	}

	return NewRows(ctx, c.athenaAPI, queryID, c.connector.config, obs)
}