awswrangler/athena/_cache.py
"""Cache Module for Amazon Athena."""
from __future__ import annotations
import datetime
import logging
import re
import threading
from heapq import heappop, heappush
from typing import TYPE_CHECKING, Match, NamedTuple
import boto3
from awswrangler import _utils, typing
if TYPE_CHECKING:
from mypy_boto3_athena.type_defs import QueryExecutionTypeDef
_logger: logging.Logger = logging.getLogger(__name__)


class _CacheInfo(NamedTuple):
    has_valid_cache: bool
    file_format: str | None = None
    query_execution_id: str | None = None
    query_execution_payload: "QueryExecutionTypeDef" | None = None
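
# For example, a parquet CTAS cache hit found by _check_for_cached_results()
# below would look like this (illustrative values only):
#
#     _CacheInfo(
#         has_valid_cache=True,
#         file_format="parquet",
#         query_execution_id="abc-123",
#         query_execution_payload={...},  # raw boto3 QueryExecution dict
#     )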


class _LocalMetadataCacheManager:
    def __init__(self) -> None:
        self._lock: threading.Lock = threading.Lock()
        self._cache: dict[str, "QueryExecutionTypeDef"] = {}
        self._pqueue: list[tuple[datetime.datetime, str]] = []
        self._max_cache_size = 100

    def update_cache(self, items: list["QueryExecutionTypeDef"]) -> None:
        """
        Update the local metadata cache with new query metadata.

        Parameters
        ----------
        items
            List of query execution metadata, as returned by boto3 `batch_get_query_execution()`.
        """
        with self._lock:
            if self._pqueue:
                # Skip items older than the oldest entry already cached;
                # they would be evicted again immediately anyway.
                oldest_item = self._cache.get(self._pqueue[0][1])
                if oldest_item:
                    items = list(
                        filter(
                            lambda x: x["Status"]["SubmissionDateTime"] > oldest_item["Status"]["SubmissionDateTime"],
                            items,
                        )
                    )

            # Evict the entries with the earliest SubmissionDateTime to make room.
            cache_oversize = len(self._cache) + len(items) - self._max_cache_size
            for _ in range(cache_oversize):
                _, query_execution_id = heappop(self._pqueue)
                self._cache.pop(query_execution_id, None)

            for item in items[: self._max_cache_size]:
                heappush(self._pqueue, (item["Status"]["SubmissionDateTime"], item["QueryExecutionId"]))
                self._cache[item["QueryExecutionId"]] = item

    def sorted_successful_generator(self) -> list["QueryExecutionTypeDef"]:
        """
        Sort the entries in the local cache by query completion time, most recent first.

        This ordering is what allows the cache to honor LRU rules.

        Returns
        -------
        list[QueryExecutionTypeDef]
            Successful DDL and DML query executions, sorted by completion time in descending order.
        """
        filtered: list["QueryExecutionTypeDef"] = []
        with self._lock:
            for query in self._cache.values():
                if (query["Status"].get("State") == "SUCCEEDED") and (query.get("StatementType") in ["DDL", "DML"]):
                    filtered.append(query)
        return sorted(filtered, key=lambda e: str(e["Status"]["CompletionDateTime"]), reverse=True)

    def __contains__(self, key: str) -> bool:
        return key in self._cache

    @property
    def max_cache_size(self) -> int:
        """Property max_cache_size."""
        return self._max_cache_size

    @max_cache_size.setter
    def max_cache_size(self, value: int) -> None:
        self._max_cache_size = value
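
# Illustrative sketch of the eviction policy above (the dict payloads are
# hand-built stand-ins for boto3 `batch_get_query_execution()` responses,
# not real Athena data):
#
#     mgr = _LocalMetadataCacheManager()
#     mgr.max_cache_size = 2
#     now = datetime.datetime.now(datetime.timezone.utc)
#     one_sec = datetime.timedelta(seconds=1)
#     mgr.update_cache([
#         {"QueryExecutionId": "q1", "Status": {"SubmissionDateTime": now}},
#         {"QueryExecutionId": "q2", "Status": {"SubmissionDateTime": now + one_sec}},
#     ])
#     mgr.update_cache([{"QueryExecutionId": "q3", "Status": {"SubmissionDateTime": now + 2 * one_sec}}])
#     assert "q1" not in mgr  # oldest SubmissionDateTime was evicted
#     assert "q3" in mgr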


def _parse_select_query_from_possible_ctas(possible_ctas: str) -> str | None:
    """Check whether `possible_ctas` is a valid parquet-generating CTAS and, if so, return the full SELECT statement."""
    possible_ctas = possible_ctas.lower()
    parquet_format_regex: str = r"format\s*=\s*\'parquet\'\s*"
    is_parquet_format: Match[str] | None = re.search(pattern=parquet_format_regex, string=possible_ctas)
    if is_parquet_format is not None:
        unstripped_select_statement_regex: str = r"\s+as\s+\(*(select|with).*"
        unstripped_select_statement_match: Match[str] | None = re.search(
            unstripped_select_statement_regex, possible_ctas, re.DOTALL
        )
        if unstripped_select_statement_match is not None:
            stripped_select_statement_match: Match[str] | None = re.search(
                r"(select|with).*", unstripped_select_statement_match.group(0), re.DOTALL
            )
            if stripped_select_statement_match is not None:
                return stripped_select_statement_match.group(0)
    return None
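
# Illustrative sketch (hypothetical CTAS text; the function lowercases its
# input, so the returned statement is lowercase):
#
#     >>> _parse_select_query_from_possible_ctas(
#     ...     "CREATE TABLE db.tmp WITH (format='Parquet') AS (SELECT col FROM tbl)"
#     ... )
#     'select col from tbl)'
#
# The stray trailing ")" from the CTAS wrapper is harmless: it is stripped by
# _prepare_query_string_for_comparison() before queries are compared.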


def _compare_query_string(sql: str, other: str) -> bool:
    comparison_query = _prepare_query_string_for_comparison(query_string=other)
    _logger.debug("sql: %s", sql)
    _logger.debug("comparison_query: %s", comparison_query)
    return sql == comparison_query


def _prepare_query_string_for_comparison(query_string: str) -> str:
    """To use cached data, we need to compare queries. Returns a query string in canonical form."""
    # For now this is a simple complete strip, but it could grow into much more
    # sophisticated query comparison data structures. Note that strip("()") is
    # required (a bare strip() would be a no-op after the whitespace join) so
    # that wrapping parentheses, including the lone ")" left by
    # _parse_select_query_from_possible_ctas(), do not defeat the comparison.
    query_string = "".join(query_string.split()).strip("()").lower()
    query_string = query_string[:-1] if query_string.endswith(";") else query_string
    return query_string
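
# Both spellings below canonicalize to the same string, so either one can hit
# a cache entry created by the other:
#
#     >>> _prepare_query_string_for_comparison("SELECT col FROM tbl;")
#     'selectcolfromtbl'
#     >>> _prepare_query_string_for_comparison("(select col\n  FROM tbl)")
#     'selectcolfromtbl'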


def _get_last_query_infos(
    max_remote_cache_entries: int,
    boto3_session: boto3.Session | None = None,
    workgroup: str | None = None,
) -> list["QueryExecutionTypeDef"]:
    """Return a list of `query_execution_info`s for queries run by the workgroup in Athena."""
    client_athena = _utils.client(service_name="athena", session=boto3_session)
    page_size = 50  # BatchGetQueryExecution accepts at most 50 query execution IDs per call
    args: dict[str, str | dict[str, int]] = {
        "PaginationConfig": {"MaxItems": max_remote_cache_entries, "PageSize": page_size}
    }
    if workgroup is not None:
        args["WorkGroup"] = workgroup
    paginator = client_athena.get_paginator("list_query_executions")
    uncached_ids: list[str] = []
    for page in paginator.paginate(**args):  # type: ignore[arg-type]
        _logger.debug("paginating Athena's query history...")
        query_execution_id_list: list[str] = page["QueryExecutionIds"]
        for query_execution_id in query_execution_id_list:
            if query_execution_id not in _cache_manager:
                uncached_ids.append(query_execution_id)
    if uncached_ids:
        new_execution_data: list["QueryExecutionTypeDef"] = []
        for i in range(0, len(uncached_ids), page_size):
            new_execution_data.extend(
                client_athena.batch_get_query_execution(  # type: ignore[arg-type]
                    QueryExecutionIds=uncached_ids[i : i + page_size],
                ).get("QueryExecutions")
            )
        _cache_manager.update_cache(new_execution_data)
    return _cache_manager.sorted_successful_generator()
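
# Illustrative call (hypothetical workgroup name; requires AWS credentials and
# an Athena query history):
#
#     infos = _get_last_query_infos(max_remote_cache_entries=50, workgroup="primary")
#     for info in infos:
#         # Only SUCCEEDED DDL/DML executions come back, newest first.
#         print(info["QueryExecutionId"], info["Status"]["CompletionDateTime"])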


def _check_for_cached_results(
    sql: str,
    boto3_session: boto3.Session | None,
    workgroup: str | None,
    athena_cache_settings: typing.AthenaCacheSettings | None = None,
) -> _CacheInfo:
    """
    Check whether `sql` has been run before, within the `max_cache_seconds` window, by the `workgroup`.

    If so, returns a `_CacheInfo` with Athena's `query_execution_info` and the data format.
    """
    athena_cache_settings = athena_cache_settings or {}
    max_cache_seconds = athena_cache_settings.get("max_cache_seconds", 0)
    max_cache_query_inspections = athena_cache_settings.get("max_cache_query_inspections", 50)
    max_remote_cache_entries = athena_cache_settings.get("max_remote_cache_entries", 50)
    max_local_cache_entries = athena_cache_settings.get("max_local_cache_entries", 100)

    # The remote fetch must never exceed the local cache capacity,
    # otherwise freshly fetched entries would be evicted immediately.
    max_remote_cache_entries = min(max_remote_cache_entries, max_local_cache_entries)

    _cache_manager.max_cache_size = max_local_cache_entries

    if max_cache_seconds <= 0:
        return _CacheInfo(has_valid_cache=False)

    num_executions_inspected: int = 0
    comparable_sql: str = _prepare_query_string_for_comparison(sql)
    current_timestamp: datetime.datetime = datetime.datetime.now(datetime.timezone.utc)
    _logger.debug("current_timestamp: %s", current_timestamp)
    for query_info in _get_last_query_infos(
        max_remote_cache_entries=max_remote_cache_entries,
        boto3_session=boto3_session,
        workgroup=workgroup,
    ):
        query_execution_id: str = query_info["QueryExecutionId"]
        query_timestamp: datetime.datetime = query_info["Status"]["CompletionDateTime"]
        _logger.debug("query_timestamp: %s", query_timestamp)

        # Executions arrive sorted by completion time, newest first, so the
        # first expired entry means every remaining one is expired as well.
        if (current_timestamp - query_timestamp).total_seconds() > max_cache_seconds:
            return _CacheInfo(
                has_valid_cache=False, query_execution_id=query_execution_id, query_execution_payload=query_info
            )

        statement_type: str | None = query_info.get("StatementType")
        if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
            parsed_query: str | None = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"])
            if parsed_query is not None:
                if _compare_query_string(sql=comparable_sql, other=parsed_query):
                    return _CacheInfo(
                        has_valid_cache=True,
                        file_format="parquet",
                        query_execution_id=query_execution_id,
                        query_execution_payload=query_info,
                    )
        elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
            if _compare_query_string(sql=comparable_sql, other=query_info["Query"]):
                return _CacheInfo(
                    has_valid_cache=True,
                    file_format="csv",
                    query_execution_id=query_execution_id,
                    query_execution_payload=query_info,
                )

        num_executions_inspected += 1
        _logger.debug("num_executions_inspected: %s", num_executions_inspected)
        if num_executions_inspected >= max_cache_query_inspections:
            return _CacheInfo(has_valid_cache=False)

    return _CacheInfo(has_valid_cache=False)


_cache_manager = _LocalMetadataCacheManager()
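
# Illustrative end-to-end use (hypothetical SQL and workgroup; requires AWS
# credentials and a prior run of the same query within the cache window):
#
#     cache_info = _check_for_cached_results(
#         sql="SELECT col FROM tbl",
#         boto3_session=boto3.Session(),
#         workgroup="primary",
#         athena_cache_settings={"max_cache_seconds": 900},
#     )
#     if cache_info.has_valid_cache:
#         # Reuse cache_info.query_execution_id instead of re-running the query.
#         ...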