# awswrangler/s3/_write.py
"""Amazon S3 Write Module (PRIVATE)."""
from __future__ import annotations
import logging
import uuid
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, NamedTuple
import boto3
import pandas as pd
import pyarrow as pa
from awswrangler import _data_types, _utils, catalog, exceptions, typing
from awswrangler._distributed import EngineEnum
from awswrangler._utils import copy_df_shallow
from awswrangler.s3._delete import delete_objects
from awswrangler.s3._write_dataset import _to_dataset
if TYPE_CHECKING:
from mypy_boto3_s3 import S3Client
_logger: logging.Logger = logging.getLogger(__name__)
_COMPRESSION_2_EXT: dict[str | None, str] = {
None: "",
"gzip": ".gz",
"snappy": ".snappy",
"bz2": ".bz2",
"xz": ".xz",
"zip": ".zip",
"zstd": ".zstd",
}
def _extract_dtypes_from_table_input(table_input: dict[str, Any]) -> dict[str, str]:
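    """Map column and partition-key names to their Glue/Athena types from a Glue table input."""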
dtypes: dict[str, str] = {}
for col in table_input["StorageDescriptor"]["Columns"]:
dtypes[col["Name"]] = col["Type"]
if "PartitionKeys" in table_input:
for par in table_input["PartitionKeys"]:
dtypes[par["Name"]] = par["Type"]
return dtypes
def _apply_dtype(
df: pd.DataFrame, dtype: dict[str, str], catalog_table_input: dict[str, Any] | None, mode: str
) -> pd.DataFrame:
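    """Cast the DataFrame using Athena types.

    On ``append``/``overwrite_partitions``, types already registered in the Glue table take
    precedence over the user-supplied ``dtype``.
    """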
if mode in ("append", "overwrite_partitions"):
if catalog_table_input is not None:
catalog_types: dict[str, str] | None = _extract_dtypes_from_table_input(table_input=catalog_table_input)
if catalog_types is not None:
for k, v in catalog_types.items():
dtype[k] = v
df = _data_types.cast_pandas_with_athena_types(df=df, dtype=dtype)
return df
def _validate_args(
df: pd.DataFrame,
table: str | None,
database: str | None,
dataset: bool,
path: str | None,
partition_cols: list[str] | None,
bucketing_info: typing.BucketingInfoTuple | None,
mode: str | None,
description: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
execution_engine: Enum,
) -> None:
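    """Validate that the given combination of arguments is consistent for single-file and dataset writes."""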
if df.empty is True:
_logger.warning("Empty DataFrame will be written.")
if dataset is False:
if path is None:
raise exceptions.InvalidArgumentValue("If dataset is False, the `path` argument must be passed.")
if execution_engine == EngineEnum.PYTHON and path.endswith("/"):
            raise exceptions.InvalidArgumentValue(
                "If `dataset=False`, the argument `path` should be a key, not a prefix."
            )
if partition_cols:
            raise exceptions.InvalidArgumentCombination("Please pass dataset=True to be able to use partition_cols.")
if bucketing_info:
            raise exceptions.InvalidArgumentCombination("Please pass dataset=True to be able to use bucketing_info.")
if mode is not None:
raise exceptions.InvalidArgumentCombination("Please pass dataset=True to be able to use mode.")
if any(arg is not None for arg in (table, description, parameters, columns_comments, columns_parameters)):
raise exceptions.InvalidArgumentCombination(
"Please pass dataset=True to be able to use any one of these "
"arguments: database, table, description, parameters, "
"columns_comments, columns_parameters."
)
elif (database is None) != (table is None):
raise exceptions.InvalidArgumentCombination(
"Arguments database and table must be passed together. If you want to store your dataset metadata in "
"the Glue Catalog, please ensure you are passing both."
)
elif all(x is None for x in [path, database, table]):
        raise exceptions.InvalidArgumentCombination(
            "You must specify a `path` if dataset is True and database/table are not provided."
        )
elif bucketing_info and bucketing_info[1] <= 0:
        raise exceptions.InvalidArgumentValue(
            "Please pass a value greater than 0 for the number of buckets for bucketing."
        )
class _SanitizeResult(NamedTuple):
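    """Result of `_sanitize`: the DataFrame plus the sanitized dtype, partition columns and bucketing info."""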
frame: pd.DataFrame
dtype: dict[str, str]
partition_cols: list[str]
bucketing_info: typing.BucketingInfoTuple | None
def _sanitize(
df: pd.DataFrame,
dtype: dict[str, str],
partition_cols: list[str],
bucketing_info: typing.BucketingInfoTuple | None = None,
) -> _SanitizeResult:
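    """Sanitize column names to Athena/Glue-compatible names.

    Applies the same sanitization to the DataFrame columns, the ``dtype`` keys, the partition columns
    and the bucketing columns, then checks for duplicated column names.
    """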
df = catalog.sanitize_dataframe_columns_names(df=df)
partition_cols = [catalog.sanitize_column_name(p) for p in partition_cols]
if bucketing_info:
bucketing_info = (
[catalog.sanitize_column_name(bucketing_col) for bucketing_col in bucketing_info[0]],
bucketing_info[1],
)
dtype = {catalog.sanitize_column_name(k): v for k, v in dtype.items()}
_utils.check_duplicated_columns(df=df)
return _SanitizeResult(df, dtype, partition_cols, bucketing_info)
def _get_chunk_file_path(file_counter: int, file_path: str) -> str:
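    """Append a chunk counter to the file name, inserting it before the file extension(s) if any."""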
slash_index: int = file_path.rfind("/")
dot_index: int = file_path.find(".", slash_index)
file_index: str = "_" + str(file_counter)
if dot_index == -1:
file_path = file_path + file_index
else:
file_path = file_path[:dot_index] + file_index + file_path[dot_index:]
return file_path
def _get_write_table_args(pyarrow_additional_kwargs: dict[str, Any] | None = None) -> dict[str, Any]:
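    """Pop and return the nested ``write_table_args`` from the PyArrow kwargs, or an empty dict."""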
write_table_args: dict[str, Any] = {}
if pyarrow_additional_kwargs and "write_table_args" in pyarrow_additional_kwargs:
write_table_args = pyarrow_additional_kwargs.pop("write_table_args")
return write_table_args
def _get_file_path(
path_root: str | None = None,
path: str | None = None,
filename_prefix: str | None = None,
compression_ext: str = "",
bucket_id: int | None = None,
extension: str = ".parquet",
) -> str:
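    """Build the target S3 object path from either an explicit ``path`` or ``path_root`` plus a generated file name."""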
if bucket_id is not None:
filename_prefix = f"{filename_prefix}_bucket-{bucket_id:05d}"
if path is None and path_root is not None:
file_path: str = f"{path_root}{filename_prefix}{compression_ext}{extension}"
elif path is not None and path_root is None:
file_path = path
else:
        raise RuntimeError("Exactly one of `path` or `path_root` must be provided.")
return file_path
class _S3WriteStrategy(ABC):
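    """Base strategy for writing DataFrames to S3 and registering the result in the Glue Catalog.

    Concrete subclasses supply the format-specific behaviour (e.g. Parquet or ORC) by implementing
    `_write_to_s3_func`, `_write_to_s3`, `_create_glue_table` and `_add_glue_partitions`.
    """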
@property
@abstractmethod
def _write_to_s3_func(self) -> Callable[..., list[str]]:
        """Return the write function handed to `_to_dataset` to write each partition/bucket to S3."""
@abstractmethod
def _write_to_s3(
self,
df: pd.DataFrame,
schema: pa.Schema,
index: bool,
compression: str | None,
compression_ext: str,
pyarrow_additional_kwargs: dict[str, Any],
cpus: int,
dtype: dict[str, str],
s3_client: "S3Client" | None,
s3_additional_kwargs: dict[str, str] | None,
use_threads: bool | int,
path: str | None = None,
path_root: str | None = None,
filename_prefix: str | None = None,
max_rows_by_file: int | None = 0,
bucketing: bool = False,
encryption_configuration: typing.ArrowEncryptionConfiguration | None = None,
) -> list[str]:
        """Write the DataFrame as one or more objects under ``path``/``path_root`` and return the written S3 paths."""
@abstractmethod
def _create_glue_table(
self,
database: str,
table: str,
path: str,
columns_types: dict[str, str],
table_type: str | None,
partitions_types: dict[str, str] | None,
bucketing_info: typing.BucketingInfoTuple | None,
catalog_id: str | None,
compression: str | None,
description: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
mode: str,
catalog_versioning: bool,
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
boto3_session: boto3.Session | None,
catalog_table_input: dict[str, Any] | None,
) -> None:
        """Create or update the Glue Catalog table describing the written dataset."""
@abstractmethod
def _add_glue_partitions(
self,
database: str,
table: str,
partitions_values: dict[str, list[str]],
bucketing_info: typing.BucketingInfoTuple | None = None,
catalog_id: str | None = None,
compression: str | None = None,
boto3_session: boto3.Session | None = None,
columns_types: dict[str, str] | None = None,
partitions_parameters: dict[str, str] | None = None,
) -> None:
        """Register the written partition values in the Glue Catalog."""
def write( # noqa: PLR0913
self,
df: pd.DataFrame,
path: str | None,
index: bool,
compression: str | None,
pyarrow_additional_kwargs: dict[str, Any],
max_rows_by_file: int | None,
use_threads: bool | int,
boto3_session: boto3.Session | None,
s3_additional_kwargs: dict[str, Any] | None,
sanitize_columns: bool,
dataset: bool,
filename_prefix: str | None,
partition_cols: list[str] | None,
bucketing_info: typing.BucketingInfoTuple | None,
concurrent_partitioning: bool,
mode: str | None,
catalog_versioning: bool,
schema_evolution: bool,
database: str | None,
table: str | None,
description: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
regular_partitions: bool,
table_type: str | None,
dtype: dict[str, str] | None,
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
catalog_id: str | None,
compression_ext: str,
encryption_configuration: typing.ArrowEncryptionConfiguration | None,
) -> typing._S3WriteDataReturnValue:
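        """Write the DataFrame to S3 and, when database/table are given, register it in the Glue Catalog.

        Orchestrates the full flow: argument defaults, column sanitization, dtype resolution against an
        existing catalog table, pyarrow schema inference, the single-file or dataset write, and finally
        the Glue table/partition registration (rolling back the written objects if the catalog update fails).
        """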
# Initializing defaults
partition_cols = partition_cols if partition_cols else []
dtype = dtype if dtype else {}
partitions_values: dict[str, list[str]] = {}
mode = "append" if mode is None else mode
filename_prefix = filename_prefix + uuid.uuid4().hex if filename_prefix else uuid.uuid4().hex
cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
s3_client = _utils.client(service_name="s3", session=boto3_session)
# Sanitize table to respect Athena's standards
if (sanitize_columns is True) or (database is not None and table is not None):
df, dtype, partition_cols, bucketing_info = _sanitize(
df=copy_df_shallow(df),
dtype=dtype,
partition_cols=partition_cols,
bucketing_info=bucketing_info,
)
# Evaluating dtype
catalog_table_input: dict[str, Any] | None = None
if database is not None and table is not None:
catalog_table_input = catalog._get_table_input(
database=database,
table=table,
boto3_session=boto3_session,
catalog_id=catalog_id,
)
catalog_path: str | None = None
if catalog_table_input:
table_type = catalog_table_input["TableType"]
catalog_path = catalog_table_input["StorageDescriptor"]["Location"]
if path is None:
if catalog_path:
path = catalog_path
else:
raise exceptions.InvalidArgumentValue(
"Glue table does not exist in the catalog. Please pass the `path` argument to create it."
)
elif path and catalog_path:
if path.rstrip("/") != catalog_path.rstrip("/"):
raise exceptions.InvalidArgumentValue(
f"The specified path: {path}, does not match the existing Glue catalog table path: {catalog_path}"
)
df = _apply_dtype(df=df, dtype=dtype, catalog_table_input=catalog_table_input, mode=mode)
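        # Infer the pyarrow schema for the data files; partition columns are excluded from the file schema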
schema: pa.Schema = _data_types.pyarrow_schema_from_pandas(
df=df, index=index, ignore_cols=partition_cols, dtype=dtype
)
_logger.debug("Resolved pyarrow schema: \n%s", schema)
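        # dataset=False writes straight to the target key; dataset=True writes a (possibly partitioned/bucketed) dataset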
if dataset is False:
paths = self._write_to_s3(
df,
path=path,
filename_prefix=filename_prefix,
schema=schema,
index=index,
cpus=cpus,
compression=compression,
compression_ext=compression_ext,
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
s3_client=s3_client,
s3_additional_kwargs=s3_additional_kwargs,
dtype=dtype,
max_rows_by_file=max_rows_by_file,
use_threads=use_threads,
encryption_configuration=encryption_configuration,
)
else:
columns_types: dict[str, str] = {}
partitions_types: dict[str, str] = {}
if (database is not None) and (table is not None):
columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
df=df, index=index, partition_cols=partition_cols, dtype=dtype
)
if schema_evolution is False:
_utils.check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode)
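                # The Glue table arguments are assembled now but only applied after the data write succeeds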
create_table_args: dict[str, Any] = {
"database": database,
"table": table,
"path": path,
"columns_types": columns_types,
"table_type": table_type,
"partitions_types": partitions_types,
"bucketing_info": bucketing_info,
"compression": compression,
"description": description,
"parameters": parameters,
"columns_comments": columns_comments,
"columns_parameters": columns_parameters,
"boto3_session": boto3_session,
"mode": mode,
"catalog_versioning": catalog_versioning,
"athena_partition_projection_settings": athena_partition_projection_settings,
"catalog_id": catalog_id,
"catalog_table_input": catalog_table_input,
}
paths, partitions_values = _to_dataset(
func=self._write_to_s3_func,
concurrent_partitioning=concurrent_partitioning,
df=df,
path_root=path, # type: ignore[arg-type]
filename_prefix=filename_prefix,
index=index,
compression=compression,
compression_ext=compression_ext,
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
cpus=cpus,
use_threads=use_threads,
partition_cols=partition_cols,
bucketing_info=bucketing_info,
dtype=dtype,
mode=mode,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
schema=schema,
max_rows_by_file=max_rows_by_file,
encryption_configuration=encryption_configuration,
)
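            # Register the table (and, for regular partitions, the partition values) in the Glue Catalog,
            # deleting the objects that were just written if the catalog update fails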
if database and table:
try:
self._create_glue_table(**create_table_args)
if partitions_values and (regular_partitions is True):
self._add_glue_partitions(
database=database,
table=table,
partitions_values=partitions_values,
bucketing_info=bucketing_info,
compression=compression,
boto3_session=boto3_session,
catalog_id=catalog_id,
columns_types=columns_types,
)
except Exception:
_logger.debug("Catalog write failed, cleaning up S3 objects (len(paths): %s).", len(paths))
delete_objects(
path=paths,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
)
raise
return {"paths": paths, "partitions_values": partitions_values}