awswrangler/athena/_executions.py (108 lines of code) (raw):
"""Query executions Module for Amazon Athena."""
from __future__ import annotations
import logging
import time
from typing import (
Any,
Dict,
cast,
)
import boto3
import botocore
from typing_extensions import Literal
from awswrangler import _utils, exceptions, typing
from awswrangler._config import apply_configs
from ._cache import _CacheInfo, _check_for_cached_results
from ._utils import (
_QUERY_FINAL_STATES,
_QUERY_WAIT_POLLING_DELAY,
_apply_formatter,
_get_workgroup_config,
_start_query_execution,
_WorkGroupConfig,
)
# Module-level logger, named after this module, used for the debug traces below.
_logger: logging.Logger = logging.getLogger(__name__)
@apply_configs
def start_query_execution(
    sql: str,
    database: str | None = None,
    s3_output: str | None = None,
    workgroup: str = "primary",
    encryption: str | None = None,
    kms_key: str | None = None,
    params: dict[str, Any] | list[str] | None = None,
    paramstyle: Literal["qmark", "named"] = "named",
    boto3_session: boto3.Session | None = None,
    client_request_token: str | None = None,
    athena_cache_settings: typing.AthenaCacheSettings | None = None,
    athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
    data_source: str | None = None,
    wait: bool = False,
) -> str | dict[str, Any]:
    """Start a SQL Query against AWS Athena.

    Note
    ----
    Create the default Athena bucket if it doesn't exist and s3_output is None.
    (E.g. s3://aws-athena-query-results-ACCOUNT-REGION/)

    Parameters
    ----------
    sql
        SQL query.
    database
        AWS Glue/Athena database name.
    s3_output
        AWS S3 path.
    workgroup
        Athena workgroup. Primary by default.
    encryption
        None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'.
    kms_key
        For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
    params
        Parameters that will be used for constructing the SQL query.
        Only named or question mark parameters are supported.
        The parameter style needs to be specified in the ``paramstyle`` parameter.

        For ``paramstyle="named"``, this value needs to be a dictionary.
        The dict needs to contain the information in the form ``{'name': 'value'}`` and the SQL query needs to contain
        ``:name``.
        The formatter will be applied client-side in this scenario.

        For ``paramstyle="qmark"``, this value needs to be a list of strings.
        The formatter will be applied server-side.
        The values are applied sequentially to the parameters in the query in the order in which the parameters occur.
    paramstyle
        Determines the style of ``params``.
        Possible values are:

        - ``named``
        - ``qmark``
    boto3_session
        The default boto3 session will be used if **boto3_session** receive ``None``.
    client_request_token
        A unique case-sensitive string used to ensure the request to create the query is idempotent (executes only once).
        If another StartQueryExecution request is received, the same response is returned and another query is not created.
        If a parameter has changed, for example, the QueryString, an error is returned.
        If you pass the same client_request_token value with different parameters the query fails with error
        message "Idempotent parameters do not match".
        When a token is supplied, the cache lookup is skipped entirely and the query is always submitted.
    athena_cache_settings
        Parameters of the Athena cache settings such as max_cache_seconds, max_cache_query_inspections,
        max_remote_cache_entries, and max_local_cache_entries.
        AthenaCacheSettings is a `TypedDict`, meaning the passed parameter can be instantiated either as an
        instance of AthenaCacheSettings or as a regular Python dict.
        If a valid cached execution is found, no new query is started, so `database`, `data_source`,
        `s3_output`, `encryption` and `kms_key` are not used.
        Only consulted when ``client_request_token`` is not provided.
    athena_query_wait_polling_delay
        Interval in seconds for how often the function will check if the Athena query has completed.
        Only relevant when ``wait`` is ``True``.
    data_source
        Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
    wait
        Indicates whether to wait for the query to finish and return a dictionary with the query execution response.

    Returns
    -------
        Query execution ID if `wait` is set to `False`, dictionary with the get_query_execution response otherwise.

    Raises
    ------
    exceptions.QueryFailed
        If ``wait`` is ``True`` and the query ends in the ``FAILED`` state.
    exceptions.QueryCancelled
        If ``wait`` is ``True`` and the query ends in the ``CANCELLED`` state.

    Examples
    --------
    Querying into the default data source (Amazon s3 - 'AwsDataCatalog')

    >>> import awswrangler as wr
    >>> query_exec_id = wr.athena.start_query_execution(sql='...', database='...')

    Querying into another data source (PostgreSQL, Redshift, etc)

    >>> import awswrangler as wr
    >>> query_exec_id = wr.athena.start_query_execution(sql='...', database='...', data_source='...')

    """
    # Substitute query parameters if applicable
    sql, execution_params = _apply_formatter(sql, params, paramstyle)
    _logger.debug("Executing query:\n%s", sql)

    # Resolve a reusable execution ID from the cache. The lookup is skipped
    # when the caller supplies an idempotency token, so the query is always
    # submitted in that case.
    cached_query_execution_id: str | None = None
    if not client_request_token:
        cache_info: _CacheInfo = _check_for_cached_results(
            sql=sql,
            boto3_session=boto3_session,
            workgroup=workgroup,
            athena_cache_settings=athena_cache_settings,
        )
        _logger.debug("Cache info:\n%s", cache_info)
        if cache_info.has_valid_cache:
            cached_query_execution_id = cache_info.query_execution_id

    if cached_query_execution_id is not None:
        query_execution_id = cached_query_execution_id
        _logger.debug("Valid cache found. Retrieving...")
    else:
        wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
        query_execution_id = _start_query_execution(
            sql=sql,
            wg_config=wg_config,
            database=database,
            data_source=data_source,
            s3_output=s3_output,
            workgroup=workgroup,
            encryption=encryption,
            kms_key=kms_key,
            execution_params=execution_params,
            client_request_token=client_request_token,
            boto3_session=boto3_session,
        )

    if wait:
        return wait_query(
            query_execution_id=query_execution_id,
            boto3_session=boto3_session,
            athena_query_wait_polling_delay=athena_query_wait_polling_delay,
        )

    return query_execution_id
def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = None) -> None:
    """Ask Athena to stop a query execution.

    Requires you to have access to the workgroup in which the query ran.

    Parameters
    ----------
    query_execution_id
        Athena query execution ID.
    boto3_session
        The default boto3 session will be used if **boto3_session** receive ``None``.

    Examples
    --------
    >>> import awswrangler as wr
    >>> wr.athena.stop_query_execution(query_execution_id='query-execution-id')

    """
    # Build the Athena client and issue the stop request in one step; nothing is returned.
    _utils.client(service_name="athena", session=boto3_session).stop_query_execution(
        QueryExecutionId=query_execution_id
    )
@apply_configs
def wait_query(
    query_execution_id: str,
    boto3_session: boto3.Session | None = None,
    athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
) -> dict[str, Any]:
    """Block until the query reaches a final state.

    The execution is fetched immediately; while its state is not one of the
    final states, the function sleeps for ``athena_query_wait_polling_delay``
    seconds and fetches it again.

    Parameters
    ----------
    query_execution_id
        Athena query execution ID.
    boto3_session
        The default boto3 session will be used if **boto3_session** receive ``None``.
    athena_query_wait_polling_delay
        Interval in seconds for how often the function will check if the Athena query has completed.

    Returns
    -------
        Dictionary with the get_query_execution response.

    Raises
    ------
    exceptions.QueryFailed
        If the final state is ``FAILED``; carries the state change reason.
    exceptions.QueryCancelled
        If the final state is ``CANCELLED``; carries the state change reason.

    Examples
    --------
    >>> import awswrangler as wr
    >>> res = wr.athena.wait_query(query_execution_id='query-execution-id')

    """
    while True:
        execution: dict[str, Any] = get_query_execution(
            query_execution_id=query_execution_id, boto3_session=boto3_session
        )
        final_state: str = execution["Status"]["State"]
        if final_state in _QUERY_FINAL_STATES:
            break
        # Not finished yet: pause before polling again.
        time.sleep(athena_query_wait_polling_delay)

    reason = execution["Status"].get("StateChangeReason")
    _logger.debug("Query state: %s", final_state)
    _logger.debug("Query state change reason: %s", reason)

    if final_state == "FAILED":
        raise exceptions.QueryFailed(reason)
    if final_state == "CANCELLED":
        raise exceptions.QueryCancelled(reason)
    return execution
def get_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = None) -> dict[str, Any]:
    """Retrieve the details of a query execution.

    https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/athena.html#Athena.Client.get_query_execution

    The Athena call is routed through ``_utils.try_it``, configured for up to
    5 tries on ``ThrottlingException`` client errors.

    Parameters
    ----------
    query_execution_id
        Athena query execution ID.
    boto3_session
        The default boto3 session will be used if **boto3_session** receive ``None``.

    Returns
    -------
        Dictionary with the get_query_execution response.

    Examples
    --------
    >>> import awswrangler as wr
    >>> res = wr.athena.get_query_execution(query_execution_id='query-execution-id')

    """
    athena = _utils.client(service_name="athena", session=boto3_session)
    raw_response = _utils.try_it(
        f=athena.get_query_execution,
        ex=botocore.exceptions.ClientError,
        ex_code="ThrottlingException",
        max_num_tries=5,
        QueryExecutionId=query_execution_id,
    )
    _logger.debug("Get query execution response:\n%s", raw_response)
    # Only the "QueryExecution" payload of the response is handed back to callers.
    query_execution = cast(Dict[str, Any], raw_response["QueryExecution"])
    return query_execution