awswrangler/cleanrooms/_read.py (79 lines of code) (raw):

"""Amazon Clean Rooms Module hosting read_* functions."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Iterator

import boto3

import awswrangler.pandas as pd
from awswrangler import _utils, exceptions, s3
from awswrangler._sql_formatter import _process_sql_params
from awswrangler.cleanrooms._utils import wait_query

if TYPE_CHECKING:
    from mypy_boto3_cleanrooms.type_defs import ProtectedQuerySQLParametersTypeDef

_logger: logging.Logger = logging.getLogger(__name__)


def _delete_after_iterate(
    dfs: Iterator[pd.DataFrame], keep_files: bool, kwargs: dict[str, Any]
) -> Iterator[pd.DataFrame]:
    """Yield every DataFrame from ``dfs``, then delete the query output unless it must be kept.

    The deletion sits in a ``finally`` clause so that it also runs when the consumer
    stops iterating early (``break``, explicit ``close()``, or garbage collection of a
    partially consumed generator) or when producing a chunk raises. Note that a
    generator that is never started does not execute its ``finally`` clause.

    Parameters
    ----------
    dfs
        Iterator of DataFrames returned by ``s3.read_parquet`` in chunked mode.
    keep_files
        The S3 files are deleted only when this is exactly ``False``.
    kwargs
        Keyword arguments forwarded verbatim to ``s3.delete_objects``.
    """
    try:
        yield from dfs
    finally:
        if keep_files is False:
            s3.delete_objects(**kwargs)


def read_sql_query(
    sql: str | None = None,
    analysis_template_arn: str | None = None,
    membership_id: str = "",
    output_bucket: str = "",
    output_prefix: str = "",
    keep_files: bool = True,
    params: dict[str, Any] | None = None,
    chunksize: int | bool | None = None,
    use_threads: bool | int = True,
    boto3_session: boto3.Session | None = None,
    pyarrow_additional_kwargs: dict[str, Any] | None = None,
) -> Iterator[pd.DataFrame] | pd.DataFrame:
    """Execute Clean Rooms Protected SQL query and return the results as a Pandas DataFrame.

    Note
    ----
    One of `sql` or `analysis_template_arn` must be supplied, not both.

    Parameters
    ----------
    sql
        SQL query
    analysis_template_arn
        ARN of the analysis template
    membership_id
        Membership ID
    output_bucket
        S3 output bucket name
    output_prefix
        S3 output prefix
    keep_files
        Whether files in S3 output bucket/prefix are retained. 'True' by default.
        When ``False`` and an iterator is returned (``chunksize`` set), the files are
        removed once the iterator is exhausted, closed, or garbage-collected after
        being started.
    params
        (Client-side) If used in combination with the `sql` parameter, it's the Dict of
        parameters used for constructing the SQL query. Only named parameters are supported.
        The dict must be in the form {'name': 'value'} and the SQL query must contain
        `:name`. Note that for varchar columns and similar, you must surround the value
        in single quotes.

        (Server-side) If used in combination with the `analysis_template_arn` parameter,
        it's the Dict of parameters supplied with the analysis template. It must be a
        string to string dict in the form {'name': 'value'}.
    chunksize
        If passed, the data is split into an iterable of DataFrames (Memory friendly).
        If `True` an iterable of DataFrames is returned without guarantee of chunksize.
        If an `INTEGER` is passed, an iterable of DataFrames is returned with maximum rows
        equal to the received INTEGER
    use_threads
        True to enable concurrent requests, False to disable multiple threads.
        If enabled os.cpu_count() is used as the maximum number of threads.
        If integer is provided, specified number is used
    boto3_session
        The default boto3 session will be used if **boto3_session** is ``None``.
    pyarrow_additional_kwargs
        Forwarded to `to_pandas` method converting from PyArrow tables to Pandas DataFrame.
        Valid values include "split_blocks", "self_destruct", "ignore_metadata".
        e.g. pyarrow_additional_kwargs={'split_blocks': True}

    Returns
    -------
        Pandas DataFrame or Generator of Pandas DataFrames if chunksize is provided.

    Raises
    ------
    exceptions.InvalidArgumentCombination
        If neither or both of `sql` and `analysis_template_arn` are supplied
        (empty strings count as not supplied).

    Examples
    --------
    >>> import awswrangler as wr
    >>> df = wr.cleanrooms.read_sql_query(
    ...     sql='SELECT DISTINCT...',
    ...     membership_id='membership-id',
    ...     output_bucket='output-bucket',
    ...     output_prefix='output-prefix',
    ... )

    >>> import awswrangler as wr
    >>> df = wr.cleanrooms.read_sql_query(
    ...     analysis_template_arn='arn:aws:cleanrooms:...',
    ...     params={'param1': 'value1'},
    ...     membership_id='membership-id',
    ...     output_bucket='output-bucket',
    ...     output_prefix='output-prefix',
    ... )
    """
    # Validate and build the query parameters before creating any AWS client so that
    # argument errors are reported without needing AWS credentials/configuration.
    if sql and analysis_template_arn:
        raise exceptions.InvalidArgumentCombination(
            "Only one of `sql` or `analysis_template_arn` must be supplied, not both"
        )
    if sql:
        # Client-side parameter substitution into the SQL text.
        sql_parameters: "ProtectedQuerySQLParametersTypeDef" = {
            "queryString": _process_sql_params(sql, params, engine_type="partiql")
        }
    elif analysis_template_arn:
        # Server-side parameters are passed through to the analysis template unchanged.
        sql_parameters = {"analysisTemplateArn": analysis_template_arn}
        if params:
            sql_parameters["parameters"] = params
    else:
        raise exceptions.InvalidArgumentCombination("One of `sql` or `analysis_template_arn` must be supplied")

    client_cleanrooms = _utils.client(service_name="cleanrooms", session=boto3_session)

    # Results are always written as Parquet so they can be read back with s3.read_parquet.
    query_id: str = client_cleanrooms.start_protected_query(
        type="SQL",
        membershipIdentifier=membership_id,
        sqlParameters=sql_parameters,
        resultConfiguration={
            "outputConfiguration": {
                "s3": {
                    "bucket": output_bucket,
                    "keyPrefix": output_prefix,
                    "resultFormat": "PARQUET",
                }
            }
        },
    )["protectedQuery"]["id"]

    _logger.debug("query_id: %s", query_id)

    # NOTE(review): wait_query is defined in cleanrooms._utils; presumably it blocks until
    # the query reaches a terminal state and raises on failure — confirm there.
    path: str = wait_query(membership_id=membership_id, query_id=query_id, boto3_session=boto3_session)[
        "protectedQuery"
    ]["result"]["output"]["s3"]["location"]

    _logger.debug("path: %s", path)

    chunked: bool | int = False if chunksize is None else chunksize
    ret = s3.read_parquet(
        path=path,
        use_threads=use_threads,
        chunked=chunked,
        boto3_session=boto3_session,
        pyarrow_additional_kwargs=pyarrow_additional_kwargs,
    )

    _logger.debug("type(ret): %s", type(ret))

    kwargs: dict[str, Any] = {
        "path": path,
        "use_threads": use_threads,
        "boto3_session": boto3_session,
    }
    if chunked is False:
        # Data is fully materialized in memory, so the S3 output can be removed right away.
        if keep_files is False:
            s3.delete_objects(**kwargs)
        return ret
    # Chunked mode: defer cleanup until the caller is done with the iterator.
    return _delete_after_iterate(ret, keep_files, kwargs)