awswrangler/athena/_executions.py (108 lines of code) (raw):

"""Query executions Module for Amazon Athena."""

from __future__ import annotations

import logging
import time
from typing import (
    Any,
    Dict,
    cast,
)

import boto3
import botocore
from typing_extensions import Literal

from awswrangler import _utils, exceptions, typing
from awswrangler._config import apply_configs

from ._cache import _CacheInfo, _check_for_cached_results
from ._utils import (
    _QUERY_FINAL_STATES,
    _QUERY_WAIT_POLLING_DELAY,
    _apply_formatter,
    _get_workgroup_config,
    _start_query_execution,
    _WorkGroupConfig,
)

_logger: logging.Logger = logging.getLogger(__name__)


@apply_configs
def start_query_execution(
    sql: str,
    database: str | None = None,
    s3_output: str | None = None,
    workgroup: str = "primary",
    encryption: str | None = None,
    kms_key: str | None = None,
    params: dict[str, Any] | list[str] | None = None,
    paramstyle: Literal["qmark", "named"] = "named",
    boto3_session: boto3.Session | None = None,
    client_request_token: str | None = None,
    athena_cache_settings: typing.AthenaCacheSettings | None = None,
    athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
    data_source: str | None = None,
    wait: bool = False,
) -> str | dict[str, Any]:
    """Start a SQL Query against AWS Athena.

    Note
    ----
    Create the default Athena bucket if it doesn't exist and s3_output is None.
    (E.g. s3://aws-athena-query-results-ACCOUNT-REGION/)

    Parameters
    ----------
    sql
        SQL query.
    database
        AWS Glue/Athena database name.
    s3_output
        AWS S3 path.
    workgroup
        Athena workgroup. Primary by default.
    encryption
        None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'.
    kms_key
        For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
    params
        Parameters that will be used for constructing the SQL query.
        Only named or question mark parameters are supported.
        The parameter style needs to be specified in the ``paramstyle`` parameter.

        For ``paramstyle="named"``, this value needs to be a dictionary.
        The dict needs to contain the information in the form ``{'name': 'value'}``
        and the SQL query needs to contain ``:name``.
        The formatter will be applied client-side in this scenario.

        For ``paramstyle="qmark"``, this value needs to be a list of strings.
        The formatter will be applied server-side.
        The values are applied sequentially to the parameters in the query
        in the order in which the parameters occur.
    paramstyle
        Determines the style of ``params``.
        Possible values are:

        - ``named``
        - ``qmark``
    boto3_session
        The default boto3 session will be used if **boto3_session** receives ``None``.
    client_request_token
        A unique case-sensitive string used to ensure the request to create the query
        is idempotent (executes only once).
        If another StartQueryExecution request is received, the same response is
        returned and another query is not created.
        If a parameter has changed, for example, the QueryString, an error is returned.
        If you pass the same client_request_token value with different parameters
        the query fails with error message "Idempotent parameters do not match".
        Use this only with ctas_approach=False and unload_approach=False and disabled cache.
        When a token is supplied, the cached-results lookup is skipped entirely.
    athena_cache_settings
        Parameters of the Athena cache settings such as max_cache_seconds,
        max_cache_query_inspections, max_remote_cache_entries, and max_local_cache_entries.
        AthenaCacheSettings is a `TypedDict`, meaning the passed parameter can be
        instantiated either as an instance of AthenaCacheSettings or as a regular Python dict.
        If cached results are valid, awswrangler ignores the `ctas_approach`, `s3_output`,
        `encryption`, `kms_key`, `keep_files` and `ctas_temp_table_name` params.
        If reading cached data fails for any reason, execution falls back to the
        usual query run path.
    athena_query_wait_polling_delay
        Interval in seconds for how often the function will check if the Athena query
        has completed.
    data_source
        Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
    wait
        Indicates whether to wait for the query to finish and return a dictionary
        with the query execution response.

    Returns
    -------
        Query execution ID if `wait` is set to `False`, dictionary with the
        get_query_execution response otherwise.

    Examples
    --------
    Querying into the default data source (Amazon s3 - 'AwsDataCatalog')

    >>> import awswrangler as wr
    >>> query_exec_id = wr.athena.start_query_execution(sql='...', database='...')

    Querying into another data source (PostgreSQL, Redshift, etc)

    >>> import awswrangler as wr
    >>> query_exec_id = wr.athena.start_query_execution(sql='...', database='...', data_source='...')

    """
    # Resolve query parameters first: "named" style is rendered into the SQL text,
    # "qmark" style is handed back as execution parameters.
    sql, execution_params = _apply_formatter(sql, params, paramstyle)
    _logger.debug("Executing query:\n%s", sql)

    # The cache is consulted only when the caller did not supply an idempotency token.
    cached_execution_id: str | None = None
    if not client_request_token:
        cache_info: _CacheInfo = _check_for_cached_results(
            sql=sql,
            boto3_session=boto3_session,
            workgroup=workgroup,
            athena_cache_settings=athena_cache_settings,
        )
        _logger.debug("Cache info:\n%s", cache_info)
        if cache_info.has_valid_cache and cache_info.query_execution_id is not None:
            cached_execution_id = cache_info.query_execution_id
            _logger.debug("Valid cache found. Retrieving...")

    if cached_execution_id is not None:
        query_execution_id = cached_execution_id
    else:
        # No reusable execution: submit a new query using the workgroup configuration.
        workgroup_settings: _WorkGroupConfig = _get_workgroup_config(
            session=boto3_session,
            workgroup=workgroup,
        )
        query_execution_id = _start_query_execution(
            sql=sql,
            wg_config=workgroup_settings,
            database=database,
            data_source=data_source,
            s3_output=s3_output,
            workgroup=workgroup,
            encryption=encryption,
            kms_key=kms_key,
            execution_params=execution_params,
            client_request_token=client_request_token,
            boto3_session=boto3_session,
        )

    if not wait:
        return query_execution_id

    return wait_query(
        query_execution_id=query_execution_id,
        boto3_session=boto3_session,
        athena_query_wait_polling_delay=athena_query_wait_polling_delay,
    )


def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = None) -> None:
    """Stop a query execution.

    Requires you to have access to the workgroup in which the query ran.

    Parameters
    ----------
    query_execution_id
        Athena query execution ID.
    boto3_session
        The default boto3 session will be used if **boto3_session** receives ``None``.

    Examples
    --------
    >>> import awswrangler as wr
    >>> wr.athena.stop_query_execution(query_execution_id='query-execution-id')

    """
    _utils.client(service_name="athena", session=boto3_session).stop_query_execution(
        QueryExecutionId=query_execution_id
    )


@apply_configs
def wait_query(
    query_execution_id: str,
    boto3_session: boto3.Session | None = None,
    athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
) -> dict[str, Any]:
    """Wait for the query to end.

    Parameters
    ----------
    query_execution_id
        Athena query execution ID.
    boto3_session
        The default boto3 session will be used if **boto3_session** receives ``None``.
    athena_query_wait_polling_delay
        Interval in seconds for how often the function will check if the Athena query
        has completed.

    Returns
    -------
        Dictionary with the get_query_execution response.

    Raises
    ------
    exceptions.QueryFailed
        If the query ends in the ``FAILED`` state.
    exceptions.QueryCancelled
        If the query ends in the ``CANCELLED`` state.

    Examples
    --------
    >>> import awswrangler as wr
    >>> res = wr.athena.wait_query(query_execution_id='query-execution-id')

    """
    # Poll until a final state is reported; the first check happens without sleeping.
    while True:
        response: dict[str, Any] = get_query_execution(
            query_execution_id=query_execution_id,
            boto3_session=boto3_session,
        )
        state: str = response["Status"]["State"]
        if state in _QUERY_FINAL_STATES:
            break
        time.sleep(athena_query_wait_polling_delay)

    state_change_reason = response["Status"].get("StateChangeReason")
    _logger.debug("Query state: %s", state)
    _logger.debug("Query state change reason: %s", state_change_reason)

    if state == "FAILED":
        raise exceptions.QueryFailed(state_change_reason)
    if state == "CANCELLED":
        raise exceptions.QueryCancelled(state_change_reason)
    return response


def get_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = None) -> dict[str, Any]:
    """Fetch query execution details.

    https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/athena.html#Athena.Client.get_query_execution

    Parameters
    ----------
    query_execution_id
        Athena query execution ID.
    boto3_session
        The default boto3 session will be used if **boto3_session** receives ``None``.

    Returns
    -------
        Dictionary with the get_query_execution response.

    Examples
    --------
    >>> import awswrangler as wr
    >>> res = wr.athena.get_query_execution(query_execution_id='query-execution-id')

    """
    athena = _utils.client(service_name="athena", session=boto3_session)
    # The call goes through _utils.try_it, configured for "ThrottlingException"
    # client errors with max_num_tries=5.
    raw_response = _utils.try_it(
        f=athena.get_query_execution,
        ex=botocore.exceptions.ClientError,
        ex_code="ThrottlingException",
        max_num_tries=5,
        QueryExecutionId=query_execution_id,
    )
    _logger.debug("Get query execution response:\n%s", raw_response)
    execution_details = raw_response["QueryExecution"]
    return cast(Dict[str, Any], execution_details)