awswrangler/s3/_select.py

"""Amazon S3 Select Module (PRIVATE)."""

from __future__ import annotations

import datetime
import itertools
import json
import logging
import pprint
from typing import TYPE_CHECKING, Any, Iterator

import boto3
import pandas as pd
import pyarrow as pa
from typing_extensions import Literal

from awswrangler import _data_types, _utils, exceptions
from awswrangler._distributed import engine
from awswrangler._executor import _BaseExecutor, _get_executor
from awswrangler.annotations import Deprecated
from awswrangler.distributed.ray import ray_get
from awswrangler.s3._describe import size_objects
from awswrangler.s3._list import _path2list
from awswrangler.s3._read import _get_path_ignore_suffix

if TYPE_CHECKING:
    from mypy_boto3_s3 import S3Client

_logger: logging.Logger = logging.getLogger(__name__)

_RANGE_CHUNK_SIZE: int = int(1024 * 1024)


def _gen_scan_range(obj_size: int, scan_range_chunk_size: int | None = None) -> Iterator[tuple[int, int]]:
    chunk_size = scan_range_chunk_size or _RANGE_CHUNK_SIZE
    for i in range(0, obj_size, chunk_size):
        yield (i, i + min(chunk_size, obj_size - i))


@engine.dispatch_on_engine
@_utils.retry(
    ex=exceptions.S3SelectRequestIncomplete,
)
def _select_object_content(
    s3_client: "S3Client" | None,
    args: dict[str, Any],
    scan_range: tuple[int, int] | None = None,
    schema: pa.Schema | None = None,
) -> pa.Table:
    client_s3: "S3Client" = s3_client if s3_client else _utils.client(service_name="s3")

    if scan_range:
        response = client_s3.select_object_content(**args, ScanRange={"Start": scan_range[0], "End": scan_range[1]})
    else:
        response = client_s3.select_object_content(**args)

    payload_records = []
    partial_record: str = ""
    request_complete: bool = False

    for event in response["Payload"]:
        if "Records" in event:
            records = (
                event["Records"]["Payload"]
                .decode(
                    encoding="utf-8",
                    errors="ignore",
                )
                .split("\n")
            )
            records[0] = partial_record + records[0]
            # Record end can either be a partial record or a return char
            partial_record = records.pop()
            payload_records.extend([json.loads(record) for record in records])
        elif "End" in event:
            # End Event signals the request was successful
            _logger.debug(
                "Received End Event. Result is complete for S3 key: %s, Scan Range: %s",
                args["Key"],
                scan_range if scan_range else 0,
            )
            request_complete = True

    # If the End Event is not received, the results may be incomplete
    if not request_complete:
        raise exceptions.S3SelectRequestIncomplete(
            f"S3 Select request for path {args['Key']} is incomplete as End Event was not received"
        )

    return _utils.list_to_arrow_table(mapping=payload_records, schema=schema)
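
# Illustrative sketch of the record stitching above (hypothetical payload bytes,
# not part of the module): a JSON Lines record may be split across two "Records"
# events, so the trailing fragment of one chunk is prepended to the first line of
# the next chunk before json.loads is applied:
#
#     chunk_1 = b'{"id": 1}\n{"id'                    # ends mid-record
#     chunk_2 = b'": 2}\n'                            # completes it
#     lines_1 = chunk_1.decode("utf-8").split("\n")   # ['{"id": 1}', '{"id']
#     partial = lines_1.pop()                         # keep '{"id' for the next chunk
#     lines_2 = chunk_2.decode("utf-8").split("\n")   # ['": 2}', '']
#     lines_2[0] = partial + lines_2[0]               # '{"id": 2}'; the trailing ''
#                                                     # becomes the new partial record
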
@engine.dispatch_on_engine
def _select_query(
    path: str,
    executor: _BaseExecutor,
    sql: str,
    input_serialization: str,
    input_serialization_params: dict[str, bool | str],
    schema: pa.Schema | None = None,
    compression: str | None = None,
    scan_range_chunk_size: int | None = None,
    boto3_session: boto3.Session | None = None,
    s3_additional_kwargs: dict[str, Any] | None = None,
) -> list[pa.Table]:
    bucket, key = _utils.parse_path(path)
    s3_client = _utils.client(service_name="s3", session=boto3_session)

    args: dict[str, Any] = {
        "Bucket": bucket,
        "Key": key,
        "Expression": sql,
        "ExpressionType": "SQL",
        "RequestProgress": {"Enabled": False},
        "InputSerialization": {
            input_serialization: input_serialization_params,
            "CompressionType": compression.upper() if compression else "NONE",
        },
        "OutputSerialization": {
            "JSON": {},
        },
    }
    if s3_additional_kwargs:
        args.update(s3_additional_kwargs)
    _logger.debug("args:\n%s", pprint.pformat(args))

    obj_size: int = size_objects(  # type: ignore[assignment]
        path=[path],
        use_threads=False,
        boto3_session=boto3_session,
    ).get(path)
    if obj_size is None:
        raise exceptions.InvalidArgumentValue(f"S3 object w/o defined size: {path}")

    scan_ranges: Iterator[tuple[int, int] | None] = _gen_scan_range(
        obj_size=obj_size, scan_range_chunk_size=scan_range_chunk_size
    )
    if any(
        [
            compression,
            input_serialization_params.get("AllowQuotedRecordDelimiter"),
            input_serialization_params.get("Type") == "Document",
        ]
    ):
        # Scan range is only supported for uncompressed CSV/JSON, CSV (without quoted delimiters)
        # and JSON objects (in LINES mode only)
        scan_ranges = [None]  # type: ignore[assignment]

    return executor.map(
        _select_object_content,
        s3_client,
        itertools.repeat(args),
        scan_ranges,
        itertools.repeat(schema),
    )
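
# Illustrative sketch (hypothetical object size, not executed here): for a
# 2.5 MiB object and the default 1 MiB chunk size, _gen_scan_range yields
# non-overlapping byte ranges that cover the whole object:
#
#     >>> list(_gen_scan_range(obj_size=2_621_440))
#     [(0, 1048576), (1048576, 2097152), (2097152, 2621440)]
#
# _select_query maps _select_object_content over these ranges via the executor,
# so each range becomes its own select_object_content call with
# ScanRange={"Start": start, "End": end}. When compression is enabled, quoted
# record delimiters are allowed, or JSON is read in Document mode, the range
# list collapses to [None] and a single request reads the entire object.
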
@Deprecated
@_utils.validate_distributed_kwargs(
    unsupported_kwargs=["boto3_session"],
)
def select_query(
    sql: str,
    path: str | list[str],
    input_serialization: str,
    input_serialization_params: dict[str, bool | str],
    compression: str | None = None,
    scan_range_chunk_size: int | None = None,
    path_suffix: str | list[str] | None = None,
    path_ignore_suffix: str | list[str] | None = None,
    ignore_empty: bool = True,
    use_threads: bool | int = True,
    last_modified_begin: datetime.datetime | None = None,
    last_modified_end: datetime.datetime | None = None,
    dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
    boto3_session: boto3.Session | None = None,
    s3_additional_kwargs: dict[str, Any] | None = None,
    pyarrow_additional_kwargs: dict[str, Any] | None = None,
) -> pd.DataFrame:
    r"""Filter contents of Amazon S3 objects based on SQL statement.

    Note: Scan ranges are only supported for uncompressed CSV/JSON, CSV (without quoted delimiters)
    and JSON objects (in LINES mode only). This means scanning cannot be split across threads if those
    conditions are not met, leading to lower performance.

    Parameters
    ----------
    sql
        SQL statement used to query the object.
    path
        S3 prefix (accepts Unix shell-style wildcards)
        (e.g. s3://bucket/prefix) or list of S3 object paths (e.g. ``[s3://bucket/key0, s3://bucket/key1]``).
    input_serialization
        Format of the S3 object queried.
        Valid values: "CSV", "JSON", or "Parquet". Case sensitive.
    input_serialization_params
        Dictionary describing the serialization of the S3 object.
    compression
        Compression type of the S3 object.
        Valid values: None, "gzip", or "bzip2". gzip and bzip2 are only valid for CSV and JSON objects.
    scan_range_chunk_size
        Chunk size used to split the S3 object into scan ranges. 1,048,576 by default.
    path_suffix
        Suffix or List of suffixes to be read (e.g. [".csv"]).
        If None, read all files. (default)
    path_ignore_suffix
        Suffix or List of suffixes for S3 keys to be ignored (e.g. ["_SUCCESS"]).
        If None, read all files. (default)
    ignore_empty
        Ignore files with 0 bytes.
    use_threads
        True (default) to enable concurrent requests, False to disable multiple threads.
        If enabled, os.cpu_count() is used as the max number of threads.
        If an integer is provided, the specified number is used.
    last_modified_begin
        Filter S3 objects by Last modified date.
        Filter is only applied after listing all objects.
    last_modified_end
        Filter S3 objects by Last modified date.
        Filter is only applied after listing all objects.
    dtype_backend
        Which dtype_backend to use, e.g. whether a DataFrame should have NumPy arrays:
        nullable dtypes are used for all dtypes that have a nullable implementation when
        "numpy_nullable" is set, and pyarrow is used for all dtypes if "pyarrow" is set.

        The dtype_backends are still experimental. The "pyarrow" backend is only supported with Pandas 2.0 or above.
    boto3_session
        The default boto3 session is used if none is provided.
    s3_additional_kwargs
        Forwarded to botocore requests.
        Valid values: "SSECustomerAlgorithm", "SSECustomerKey", "ExpectedBucketOwner".
        e.g. s3_additional_kwargs={'SSECustomerAlgorithm': 'md5'}.
    pyarrow_additional_kwargs
        Forwarded to the `to_pandas` method converting from PyArrow tables to Pandas DataFrame.
        Valid values include "split_blocks", "self_destruct", "ignore_metadata".
        e.g. pyarrow_additional_kwargs={'split_blocks': True}.

    Returns
    -------
        Pandas DataFrame with results from query.

    Examples
    --------
    Reading a gzip compressed JSON document

    >>> import awswrangler as wr
    >>> df = wr.s3.select_query(
    ...     sql='SELECT * FROM s3object[*][*]',
    ...     path='s3://bucket/key.json.gzip',
    ...     input_serialization='JSON',
    ...     input_serialization_params={
    ...         'Type': 'Document',
    ...     },
    ...     compression="gzip",
    ... )

    Reading multiple CSV objects from a prefix

    >>> import awswrangler as wr
    >>> df = wr.s3.select_query(
    ...     sql='SELECT * FROM s3object',
    ...     path='s3://bucket/prefix/',
    ...     input_serialization='CSV',
    ...     input_serialization_params={
    ...         'FileHeaderInfo': 'Use',
    ...         'RecordDelimiter': '\r\n'
    ...     },
    ... )

    Reading a single column from Parquet object with pushdown filter

    >>> import awswrangler as wr
    >>> df = wr.s3.select_query(
    ...     sql='SELECT s.\"id\" FROM s3object s where s.\"id\" = 1.0',
    ...     path='s3://bucket/key.snappy.parquet',
    ...     input_serialization='Parquet',
    ... )
    """
    if input_serialization not in ["CSV", "JSON", "Parquet"]:
        raise exceptions.InvalidArgumentValue("<input_serialization> argument must be 'CSV', 'JSON' or 'Parquet'")
    if compression not in [None, "gzip", "bzip2"]:
        raise exceptions.InvalidCompression(f"Invalid {compression} compression, please use None, 'gzip' or 'bzip2'.")
    if compression and (input_serialization not in ["CSV", "JSON"]):
        raise exceptions.InvalidArgumentCombination(
            "'gzip' or 'bzip2' are only valid for input 'CSV' or 'JSON' objects."
        )

    s3_client = _utils.client(service_name="s3", session=boto3_session)

    paths: list[str] = _path2list(
        path=path,
        s3_client=s3_client,
        suffix=path_suffix,
        ignore_suffix=_get_path_ignore_suffix(path_ignore_suffix=path_ignore_suffix),
        ignore_empty=ignore_empty,
        last_modified_begin=last_modified_begin,
        last_modified_end=last_modified_end,
        s3_additional_kwargs=s3_additional_kwargs,
    )
    if len(paths) < 1:
        raise exceptions.NoFilesFound(f"No files Found: {path}.")

    select_kwargs: dict[str, Any] = {
        "sql": sql,
        "input_serialization": input_serialization,
        "input_serialization_params": input_serialization_params,
        "compression": compression,
        "scan_range_chunk_size": scan_range_chunk_size,
        "boto3_session": boto3_session,
        "s3_additional_kwargs": s3_additional_kwargs,
    }
    if pyarrow_additional_kwargs and "schema" in pyarrow_additional_kwargs:
        select_kwargs["schema"] = pyarrow_additional_kwargs.pop("schema")

    arrow_kwargs = _data_types.pyarrow2pandas_defaults(
        use_threads=use_threads, kwargs=pyarrow_additional_kwargs, dtype_backend=dtype_backend
    )
    executor: _BaseExecutor = _get_executor(use_threads=use_threads)
    tables = list(
        itertools.chain(*ray_get([_select_query(path=path, executor=executor, **select_kwargs) for path in paths]))
    )
    return _utils.table_refs_to_df(tables, kwargs=arrow_kwargs)
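
# Usage sketch (hypothetical bucket/key, not executed here): an explicit PyArrow
# schema can be passed through pyarrow_additional_kwargs; select_query pops the
# "schema" entry above and uses it to build the Arrow tables instead of inferring
# column types from the JSON payload returned by S3 Select:
#
#     >>> import pyarrow as pa
#     >>> import awswrangler as wr
#     >>> df = wr.s3.select_query(
#     ...     sql="SELECT * FROM s3object",
#     ...     path="s3://bucket/key.csv",
#     ...     input_serialization="CSV",
#     ...     input_serialization_params={"FileHeaderInfo": "Use"},
#     ...     pyarrow_additional_kwargs={"schema": pa.schema([("id", pa.int64())])},
#     ... )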