awswrangler/s3/_write.py (380 lines of code) (raw):

"""Amazon CSV S3 Write Module (PRIVATE).""" from __future__ import annotations import logging import uuid from abc import ABC, abstractmethod from enum import Enum from typing import TYPE_CHECKING, Any, Callable, NamedTuple import boto3 import pandas as pd import pyarrow as pa from awswrangler import _data_types, _utils, catalog, exceptions, typing from awswrangler._distributed import EngineEnum from awswrangler._utils import copy_df_shallow from awswrangler.s3._delete import delete_objects from awswrangler.s3._write_dataset import _to_dataset if TYPE_CHECKING: from mypy_boto3_s3 import S3Client _logger: logging.Logger = logging.getLogger(__name__) _COMPRESSION_2_EXT: dict[str | None, str] = { None: "", "gzip": ".gz", "snappy": ".snappy", "bz2": ".bz2", "xz": ".xz", "zip": ".zip", "zstd": ".zstd", } def _extract_dtypes_from_table_input(table_input: dict[str, Any]) -> dict[str, str]: dtypes: dict[str, str] = {} for col in table_input["StorageDescriptor"]["Columns"]: dtypes[col["Name"]] = col["Type"] if "PartitionKeys" in table_input: for par in table_input["PartitionKeys"]: dtypes[par["Name"]] = par["Type"] return dtypes def _apply_dtype( df: pd.DataFrame, dtype: dict[str, str], catalog_table_input: dict[str, Any] | None, mode: str ) -> pd.DataFrame: if mode in ("append", "overwrite_partitions"): if catalog_table_input is not None: catalog_types: dict[str, str] | None = _extract_dtypes_from_table_input(table_input=catalog_table_input) if catalog_types is not None: for k, v in catalog_types.items(): dtype[k] = v df = _data_types.cast_pandas_with_athena_types(df=df, dtype=dtype) return df def _validate_args( df: pd.DataFrame, table: str | None, database: str | None, dataset: bool, path: str | None, partition_cols: list[str] | None, bucketing_info: typing.BucketingInfoTuple | None, mode: str | None, description: str | None, parameters: dict[str, str] | None, columns_comments: dict[str, str] | None, columns_parameters: dict[str, dict[str, str]] | None, execution_engine: Enum, ) -> None: if df.empty is True: _logger.warning("Empty DataFrame will be written.") if dataset is False: if path is None: raise exceptions.InvalidArgumentValue("If dataset is False, the `path` argument must be passed.") if execution_engine == EngineEnum.PYTHON and path.endswith("/"): raise exceptions.InvalidArgumentValue( "If <dataset=False>, the argument <path> should be a key, not a prefix." ) if partition_cols: raise exceptions.InvalidArgumentCombination("Please, pass dataset=True to be able to use partition_cols.") if bucketing_info: raise exceptions.InvalidArgumentCombination("Please, pass dataset=True to be able to use bucketing_info.") if mode is not None: raise exceptions.InvalidArgumentCombination("Please pass dataset=True to be able to use mode.") if any(arg is not None for arg in (table, description, parameters, columns_comments, columns_parameters)): raise exceptions.InvalidArgumentCombination( "Please pass dataset=True to be able to use any one of these " "arguments: database, table, description, parameters, " "columns_comments, columns_parameters." ) elif (database is None) != (table is None): raise exceptions.InvalidArgumentCombination( "Arguments database and table must be passed together. If you want to store your dataset metadata in " "the Glue Catalog, please ensure you are passing both." ) elif all(x is None for x in [path, database, table]): raise exceptions.InvalidArgumentCombination( "You must specify a `path` if dataset is True and database/table are not enabled." 

def _validate_args(
    df: pd.DataFrame,
    table: str | None,
    database: str | None,
    dataset: bool,
    path: str | None,
    partition_cols: list[str] | None,
    bucketing_info: typing.BucketingInfoTuple | None,
    mode: str | None,
    description: str | None,
    parameters: dict[str, str] | None,
    columns_comments: dict[str, str] | None,
    columns_parameters: dict[str, dict[str, str]] | None,
    execution_engine: Enum,
) -> None:
    if df.empty is True:
        _logger.warning("Empty DataFrame will be written.")
    if dataset is False:
        if path is None:
            raise exceptions.InvalidArgumentValue("If dataset is False, the `path` argument must be passed.")
        if execution_engine == EngineEnum.PYTHON and path.endswith("/"):
            raise exceptions.InvalidArgumentValue(
                "If <dataset=False>, the argument <path> should be a key, not a prefix."
            )
        if partition_cols:
            raise exceptions.InvalidArgumentCombination("Please, pass dataset=True to be able to use partition_cols.")
        if bucketing_info:
            raise exceptions.InvalidArgumentCombination("Please, pass dataset=True to be able to use bucketing_info.")
        if mode is not None:
            raise exceptions.InvalidArgumentCombination("Please pass dataset=True to be able to use mode.")
        if any(arg is not None for arg in (table, description, parameters, columns_comments, columns_parameters)):
            raise exceptions.InvalidArgumentCombination(
                "Please pass dataset=True to be able to use any one of these "
                "arguments: database, table, description, parameters, "
                "columns_comments, columns_parameters."
            )
    elif (database is None) != (table is None):
        raise exceptions.InvalidArgumentCombination(
            "Arguments database and table must be passed together. If you want to store your dataset metadata in "
            "the Glue Catalog, please ensure you are passing both."
        )
    elif all(x is None for x in [path, database, table]):
        raise exceptions.InvalidArgumentCombination(
            "You must specify a `path` if dataset is True and database/table are not enabled."
        )
    elif bucketing_info and bucketing_info[1] <= 0:
        raise exceptions.InvalidArgumentValue(
            "Please pass a value greater than 1 for the number of buckets for bucketing."
        )


class _SanitizeResult(NamedTuple):
    frame: pd.DataFrame
    dtype: dict[str, str]
    partition_cols: list[str]
    bucketing_info: typing.BucketingInfoTuple | None


def _sanitize(
    df: pd.DataFrame,
    dtype: dict[str, str],
    partition_cols: list[str],
    bucketing_info: typing.BucketingInfoTuple | None = None,
) -> _SanitizeResult:
    df = catalog.sanitize_dataframe_columns_names(df=df)
    partition_cols = [catalog.sanitize_column_name(p) for p in partition_cols]
    if bucketing_info:
        bucketing_info = (
            [catalog.sanitize_column_name(bucketing_col) for bucketing_col in bucketing_info[0]],
            bucketing_info[1],
        )
    dtype = {catalog.sanitize_column_name(k): v for k, v in dtype.items()}
    _utils.check_duplicated_columns(df=df)
    return _SanitizeResult(df, dtype, partition_cols, bucketing_info)


def _get_chunk_file_path(file_counter: int, file_path: str) -> str:
    slash_index: int = file_path.rfind("/")
    dot_index: int = file_path.find(".", slash_index)
    file_index: str = "_" + str(file_counter)
    if dot_index == -1:
        file_path = file_path + file_index
    else:
        file_path = file_path[:dot_index] + file_index + file_path[dot_index:]
    return file_path


def _get_write_table_args(pyarrow_additional_kwargs: dict[str, Any] | None = None) -> dict[str, Any]:
    write_table_args: dict[str, Any] = {}
    if pyarrow_additional_kwargs and "write_table_args" in pyarrow_additional_kwargs:
        write_table_args = pyarrow_additional_kwargs.pop("write_table_args")
    return write_table_args


def _get_file_path(
    path_root: str | None = None,
    path: str | None = None,
    filename_prefix: str | None = None,
    compression_ext: str = "",
    bucket_id: int | None = None,
    extension: str = ".parquet",
) -> str:
    if bucket_id is not None:
        filename_prefix = f"{filename_prefix}_bucket-{bucket_id:05d}"
    if path is None and path_root is not None:
        file_path: str = f"{path_root}{filename_prefix}{compression_ext}{extension}"
    elif path is not None and path_root is None:
        file_path = path
    else:
        raise RuntimeError("path and path_root received at the same time.")
    return file_path
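
# Illustrative sketch (hypothetical paths, not taken from this module), showing how the
# path helpers above name output objects:
#
#   _get_chunk_file_path(2, "s3://bucket/prefix/file.parquet")
#   # -> "s3://bucket/prefix/file_2.parquet"  (the chunk counter lands before the extension)
#
#   _get_file_path(path_root="s3://bucket/prefix/", filename_prefix="abc123",
#                  compression_ext=".gz", extension=".parquet")
#   # -> "s3://bucket/prefix/abc123.gz.parquet"
#
#   _get_file_path(path_root="s3://bucket/prefix/", filename_prefix="abc123",
#                  bucket_id=3, extension=".parquet")
#   # -> "s3://bucket/prefix/abc123_bucket-00003.parquet"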

class _S3WriteStrategy(ABC):
    @property
    @abstractmethod
    def _write_to_s3_func(self) -> Callable[..., list[str]]:
        pass

    @abstractmethod
    def _write_to_s3(
        self,
        df: pd.DataFrame,
        schema: pa.Schema,
        index: bool,
        compression: str | None,
        compression_ext: str,
        pyarrow_additional_kwargs: dict[str, Any],
        cpus: int,
        dtype: dict[str, str],
        s3_client: "S3Client" | None,
        s3_additional_kwargs: dict[str, str] | None,
        use_threads: bool | int,
        path: str | None = None,
        path_root: str | None = None,
        filename_prefix: str | None = None,
        max_rows_by_file: int | None = 0,
        bucketing: bool = False,
        encryption_configuration: typing.ArrowEncryptionConfiguration | None = None,
    ) -> list[str]:
        pass

    @abstractmethod
    def _create_glue_table(
        self,
        database: str,
        table: str,
        path: str,
        columns_types: dict[str, str],
        table_type: str | None,
        partitions_types: dict[str, str] | None,
        bucketing_info: typing.BucketingInfoTuple | None,
        catalog_id: str | None,
        compression: str | None,
        description: str | None,
        parameters: dict[str, str] | None,
        columns_comments: dict[str, str] | None,
        columns_parameters: dict[str, dict[str, str]] | None,
        mode: str,
        catalog_versioning: bool,
        athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
        boto3_session: boto3.Session | None,
        catalog_table_input: dict[str, Any] | None,
    ) -> None:
        pass

    @abstractmethod
    def _add_glue_partitions(
        self,
        database: str,
        table: str,
        partitions_values: dict[str, list[str]],
        bucketing_info: typing.BucketingInfoTuple | None = None,
        catalog_id: str | None = None,
        compression: str | None = None,
        boto3_session: boto3.Session | None = None,
        columns_types: dict[str, str] | None = None,
        partitions_parameters: dict[str, str] | None = None,
    ) -> None:
        pass
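
    # The abstract members above are the strategy's extension points: `_write_to_s3_func`
    # returns the callable that `_to_dataset` fans out over partitions/buckets,
    # `_write_to_s3` handles the single-object (dataset=False) path, and the two Glue
    # hooks create/extend the catalog table and register partitions. `write()` below
    # orchestrates them; concrete format writers (e.g. Parquet, as the default
    # ".parquet" extension suggests) supply the implementations.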
    def write(  # noqa: PLR0913
        self,
        df: pd.DataFrame,
        path: str | None,
        index: bool,
        compression: str | None,
        pyarrow_additional_kwargs: dict[str, Any],
        max_rows_by_file: int | None,
        use_threads: bool | int,
        boto3_session: boto3.Session | None,
        s3_additional_kwargs: dict[str, Any] | None,
        sanitize_columns: bool,
        dataset: bool,
        filename_prefix: str | None,
        partition_cols: list[str] | None,
        bucketing_info: typing.BucketingInfoTuple | None,
        concurrent_partitioning: bool,
        mode: str | None,
        catalog_versioning: bool,
        schema_evolution: bool,
        database: str | None,
        table: str | None,
        description: str | None,
        parameters: dict[str, str] | None,
        columns_comments: dict[str, str] | None,
        columns_parameters: dict[str, dict[str, str]] | None,
        regular_partitions: bool,
        table_type: str | None,
        dtype: dict[str, str] | None,
        athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
        catalog_id: str | None,
        compression_ext: str,
        encryption_configuration: typing.ArrowEncryptionConfiguration | None,
    ) -> typing._S3WriteDataReturnValue:
        # Initializing defaults
        partition_cols = partition_cols if partition_cols else []
        dtype = dtype if dtype else {}
        partitions_values: dict[str, list[str]] = {}
        mode = "append" if mode is None else mode
        filename_prefix = filename_prefix + uuid.uuid4().hex if filename_prefix else uuid.uuid4().hex
        cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
        s3_client = _utils.client(service_name="s3", session=boto3_session)

        # Sanitize table to respect Athena's standards
        if (sanitize_columns is True) or (database is not None and table is not None):
            df, dtype, partition_cols, bucketing_info = _sanitize(
                df=copy_df_shallow(df),
                dtype=dtype,
                partition_cols=partition_cols,
                bucketing_info=bucketing_info,
            )

        # Evaluating dtype
        catalog_table_input: dict[str, Any] | None = None
        if database is not None and table is not None:
            catalog_table_input = catalog._get_table_input(
                database=database,
                table=table,
                boto3_session=boto3_session,
                catalog_id=catalog_id,
            )

        catalog_path: str | None = None
        if catalog_table_input:
            table_type = catalog_table_input["TableType"]
            catalog_path = catalog_table_input["StorageDescriptor"]["Location"]
        if path is None:
            if catalog_path:
                path = catalog_path
            else:
                raise exceptions.InvalidArgumentValue(
                    "Glue table does not exist in the catalog. Please pass the `path` argument to create it."
                )
        elif path and catalog_path:
            if path.rstrip("/") != catalog_path.rstrip("/"):
                raise exceptions.InvalidArgumentValue(
                    f"The specified path: {path}, does not match the existing Glue catalog table path: {catalog_path}"
                )

        df = _apply_dtype(df=df, dtype=dtype, catalog_table_input=catalog_table_input, mode=mode)
        schema: pa.Schema = _data_types.pyarrow_schema_from_pandas(
            df=df, index=index, ignore_cols=partition_cols, dtype=dtype
        )
        _logger.debug("Resolved pyarrow schema: \n%s", schema)

        if dataset is False:
            paths = self._write_to_s3(
                df,
                path=path,
                filename_prefix=filename_prefix,
                schema=schema,
                index=index,
                cpus=cpus,
                compression=compression,
                compression_ext=compression_ext,
                pyarrow_additional_kwargs=pyarrow_additional_kwargs,
                s3_client=s3_client,
                s3_additional_kwargs=s3_additional_kwargs,
                dtype=dtype,
                max_rows_by_file=max_rows_by_file,
                use_threads=use_threads,
                encryption_configuration=encryption_configuration,
            )
        else:
            columns_types: dict[str, str] = {}
            partitions_types: dict[str, str] = {}
            if (database is not None) and (table is not None):
                columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
                    df=df, index=index, partition_cols=partition_cols, dtype=dtype
                )
                if schema_evolution is False:
                    _utils.check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode)

                create_table_args: dict[str, Any] = {
                    "database": database,
                    "table": table,
                    "path": path,
                    "columns_types": columns_types,
                    "table_type": table_type,
                    "partitions_types": partitions_types,
                    "bucketing_info": bucketing_info,
                    "compression": compression,
                    "description": description,
                    "parameters": parameters,
                    "columns_comments": columns_comments,
                    "columns_parameters": columns_parameters,
                    "boto3_session": boto3_session,
                    "mode": mode,
                    "catalog_versioning": catalog_versioning,
                    "athena_partition_projection_settings": athena_partition_projection_settings,
                    "catalog_id": catalog_id,
                    "catalog_table_input": catalog_table_input,
                }

            paths, partitions_values = _to_dataset(
                func=self._write_to_s3_func,
                concurrent_partitioning=concurrent_partitioning,
                df=df,
                path_root=path,  # type: ignore[arg-type]
                filename_prefix=filename_prefix,
                index=index,
                compression=compression,
                compression_ext=compression_ext,
                pyarrow_additional_kwargs=pyarrow_additional_kwargs,
                cpus=cpus,
                use_threads=use_threads,
                partition_cols=partition_cols,
                bucketing_info=bucketing_info,
                dtype=dtype,
                mode=mode,
                boto3_session=boto3_session,
                s3_additional_kwargs=s3_additional_kwargs,
                schema=schema,
                max_rows_by_file=max_rows_by_file,
                encryption_configuration=encryption_configuration,
            )
            if database and table:
                try:
                    self._create_glue_table(**create_table_args)
                    if partitions_values and (regular_partitions is True):
                        self._add_glue_partitions(
                            database=database,
                            table=table,
                            partitions_values=partitions_values,
                            bucketing_info=bucketing_info,
                            compression=compression,
                            boto3_session=boto3_session,
                            catalog_id=catalog_id,
                            columns_types=columns_types,
                        )
                except Exception:
                    _logger.debug("Catalog write failed, cleaning up S3 objects (len(paths): %s).", len(paths))
                    delete_objects(
                        path=paths,
                        use_threads=use_threads,
                        boto3_session=boto3_session,
                        s3_additional_kwargs=s3_additional_kwargs,
                    )
                    raise
        return {"paths": paths, "partitions_values": partitions_values}
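
# Minimal sketch of how a concrete strategy plugs in (hypothetical names, for
# illustration only; real subclasses live elsewhere in awswrangler.s3):
#
#   class _MyFormatWriteStrategy(_S3WriteStrategy):
#       @property
#       def _write_to_s3_func(self):
#           return _to_myformat_chunk  # hypothetical low-level writer
#
#       def _write_to_s3(self, df, **kwargs):
#           return _to_myformat_chunk(df, **kwargs)
#
#       def _create_glue_table(self, **kwargs):
#           ...  # delegate to awswrangler.catalog
#
#       def _add_glue_partitions(self, **kwargs):
#           ...  # delegate to awswrangler.catalog
#
# A successful write() returns a dictionary shaped like
#   {"paths": ["s3://bucket/prefix/<uuid>.parquet", ...], "partitions_values": {...}}
# and, if the Glue catalog update fails, the objects already written to S3 are
# deleted before the exception is re-raised.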