# awswrangler/s3/_write_dataset.py
"""Amazon S3 Write Dataset (PRIVATE)."""
from __future__ import annotations
import logging
from typing import Any, Callable
import boto3
import numpy as np
import pandas as pd
from awswrangler import exceptions, typing
from awswrangler._distributed import engine
from awswrangler._utils import client
from awswrangler.s3._delete import delete_objects
from awswrangler.s3._write_concurrent import _WriteProxy
_logger: logging.Logger = logging.getLogger(__name__)
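
# Map each row of the bucketing columns (bucketing_info[0]) to a bucket number in
# [0, bucketing_info[1]). The result is returned as a categorical series whose
# categories span every possible bucket number.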
def _get_bucketing_series(df: pd.DataFrame, bucketing_info: typing.BucketingInfoTuple) -> pd.Series:
bucket_number_series = (
df[bucketing_info[0]]
# Prevent "upcasting" mixed types by casting to object
.astype("O")
.apply(
lambda row: _get_bucket_number(bucketing_info[1], [row[col_name] for col_name in bucketing_info[0]]),
axis="columns",
)
)
    return bucket_number_series.astype(pd.CategoricalDtype(range(bucketing_info[1])))
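
# Emulate fixed-width integer overflow (Java-style wrap-around) on an arbitrary
# Python int, e.g. _simulate_overflow(2**31 + 5) == 5 with the default 31 bits.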
def _simulate_overflow(value: int, bits: int = 31, signed: bool = False) -> int:
base = 1 << bits
value %= base
return value - base if signed and value.bit_length() == bits else value
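
# Combine the hashes of all bucketing values Java-style (hash = 31 * hash + h(value)),
# wrapping after each step, then map the result onto the configured bucket count.
# For example, _get_bucket_number(8, [True]) == 1 since _get_value_hash(True) == 1.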
def _get_bucket_number(number_of_buckets: int, values: list[str | int | bool]) -> int:
hash_code = 0
for value in values:
hash_code = 31 * hash_code + _get_value_hash(value)
hash_code = _simulate_overflow(hash_code)
return hash_code % number_of_buckets
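
# Per-type hash used for bucketing: ints outside the 32-bit range are folded with
# (value >> 32) ^ value, strings use a 31-multiplier rolling hash over their UTF-8
# bytes, and booleans hash to 0/1. This is intended to mirror the Java-style hashing
# used by Hive/Athena bucketing.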
def _get_value_hash(value: str | int | bool) -> int:
if isinstance(value, (int, np.int_)):
value = int(value)
bigint_min, bigint_max = -(2**63), 2**63 - 1
int_min, int_max = -(2**31), 2**31 - 1
if not bigint_min <= value <= bigint_max:
raise ValueError(f"{value} exceeds the range that Athena cannot handle as bigint.")
if not int_min <= value <= int_max:
value = (value >> 32) ^ value
if value < 0:
return -value - 1
return int(value)
if isinstance(value, (str, np.str_)):
value_hash = 0
for byte in value.encode():
value_hash = value_hash * 31 + byte
value_hash = _simulate_overflow(value_hash)
return value_hash
if isinstance(value, (bool, np.bool_)):
return int(value)
raise exceptions.InvalidDataFrame(
"Column specified for bucketing contains invalid data type. Only string, int and bool are supported."
)
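
# Build the Hive-style prefix for one partition, e.g.
# _get_subgroup_prefix(("2021", "1"), ["year", "month"], "s3://bucket/table/")
# returns "s3://bucket/table/year=2021/month=1/".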
def _get_subgroup_prefix(keys: tuple[Any, ...], partition_cols: list[str], path_root: str) -> str:
subdir = "/".join([f"{name}={val}" for name, val in zip(partition_cols, keys)])
return f"{path_root}{subdir}/"
def _delete_objects(
    keys: tuple[Any, ...],
path_root: str,
use_threads: bool | int,
mode: str,
partition_cols: list[str],
boto3_session: boto3.Session | None = None,
**func_kwargs: Any,
) -> str:
# Keys are either a primitive type or a tuple if partitioning by multiple cols
keys = (keys,) if not isinstance(keys, tuple) else keys
prefix = _get_subgroup_prefix(keys, partition_cols, path_root)
if mode == "overwrite_partitions":
delete_objects(
path=prefix,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=func_kwargs.get("s3_additional_kwargs"),
)
return prefix
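
# Write one group of rows per combination of partition values: partition columns are
# dropped from each subgroup, the subgroup is written (optionally bucketed) under its
# Hive-style prefix, and a mapping of prefix -> partition values is returned together
# with the written paths. Writes can run concurrently through _WriteProxy.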
@engine.dispatch_on_engine
def _to_partitions(
df: pd.DataFrame,
func: Callable[..., list[str]],
concurrent_partitioning: bool,
path_root: str,
use_threads: bool | int,
mode: str,
partition_cols: list[str],
bucketing_info: typing.BucketingInfoTuple | None,
filename_prefix: str,
boto3_session: boto3.Session | None,
**func_kwargs: Any,
) -> tuple[list[str], dict[str, list[str]]]:
partitions_values: dict[str, list[str]] = {}
proxy: _WriteProxy = _WriteProxy(use_threads=concurrent_partitioning)
s3_client = client(service_name="s3", session=boto3_session)
for keys, subgroup in df.groupby(by=partition_cols, observed=True):
# Keys are either a primitive type or a tuple if partitioning by multiple cols
keys = (keys,) if not isinstance(keys, tuple) else keys # noqa: PLW2901
# Drop partition columns from df
subgroup.drop(
columns=[col for col in partition_cols if col in subgroup.columns],
inplace=True,
)
# Drop index levels if partitioning by index columns
subgroup.reset_index(
level=[col for col in partition_cols if col in subgroup.index.names],
drop=True,
inplace=True,
)
prefix = _delete_objects(
keys=keys,
path_root=path_root,
use_threads=use_threads,
mode=mode,
partition_cols=partition_cols,
boto3_session=boto3_session,
**func_kwargs,
)
if bucketing_info:
_to_buckets(
subgroup,
func=func,
path_root=prefix,
bucketing_info=bucketing_info,
boto3_session=boto3_session,
use_threads=use_threads,
proxy=proxy,
filename_prefix=filename_prefix,
**func_kwargs,
)
else:
proxy.write(
func,
subgroup,
path_root=prefix,
filename_prefix=filename_prefix,
s3_client=s3_client,
use_threads=use_threads,
**func_kwargs,
)
partitions_values[prefix] = [str(k) for k in keys]
paths: list[str] = proxy.close() # blocking
return paths, partitions_values
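
# Split the frame by computed bucket number and write each bucket with a
# "_bucket-NNNNN" suffix in the file name. When an external proxy is supplied
# (the partitioned code path), the caller collects the written paths instead.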
@engine.dispatch_on_engine
def _to_buckets(
df: pd.DataFrame,
func: Callable[..., list[str]],
path_root: str,
bucketing_info: typing.BucketingInfoTuple,
filename_prefix: str,
boto3_session: boto3.Session | None,
use_threads: bool | int,
proxy: _WriteProxy | None = None,
**func_kwargs: Any,
) -> list[str]:
_proxy: _WriteProxy = proxy if proxy else _WriteProxy(use_threads=False)
s3_client = client(service_name="s3", session=boto3_session)
for bucket_number, subgroup in df.groupby(by=_get_bucketing_series(df=df, bucketing_info=bucketing_info)):
_proxy.write(
func,
subgroup,
path_root=path_root,
filename_prefix=f"{filename_prefix}_bucket-{bucket_number:05d}",
use_threads=use_threads,
s3_client=s3_client,
**func_kwargs,
)
if proxy:
return []
paths: list[str] = _proxy.close() # blocking
return paths
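
# Dataset entry point: normalize the root path, validate the write mode, delete
# existing data for "overwrite" (or "overwrite_partitions" without partition columns),
# then delegate to _to_partitions, _to_buckets, or a single direct write.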
def _to_dataset(
func: Callable[..., list[str]],
concurrent_partitioning: bool,
df: pd.DataFrame,
path_root: str,
filename_prefix: str,
index: bool,
use_threads: bool | int,
mode: str,
partition_cols: list[str] | None,
bucketing_info: typing.BucketingInfoTuple | None,
boto3_session: boto3.Session | None,
**func_kwargs: Any,
) -> tuple[list[str], dict[str, list[str]]]:
path_root = path_root if path_root.endswith("/") else f"{path_root}/"
# Evaluate mode
if mode not in ["append", "overwrite", "overwrite_partitions"]:
raise exceptions.InvalidArgumentValue(
f"{mode} is a invalid mode, please use append, overwrite or overwrite_partitions."
)
if (mode == "overwrite") or ((mode == "overwrite_partitions") and (not partition_cols)):
delete_objects(path=path_root, use_threads=use_threads, boto3_session=boto3_session)
# Writing
partitions_values: dict[str, list[str]] = {}
paths: list[str]
if partition_cols:
paths, partitions_values = _to_partitions(
df,
func=func,
concurrent_partitioning=concurrent_partitioning,
path_root=path_root,
use_threads=use_threads,
mode=mode,
bucketing_info=bucketing_info,
filename_prefix=filename_prefix,
partition_cols=partition_cols,
boto3_session=boto3_session,
index=index,
**func_kwargs,
)
elif bucketing_info:
paths = _to_buckets(
df,
func=func,
path_root=path_root,
use_threads=use_threads,
bucketing_info=bucketing_info,
filename_prefix=filename_prefix,
boto3_session=boto3_session,
index=index,
**func_kwargs,
)
else:
s3_client = client(service_name="s3", session=boto3_session)
paths = func(
df,
path_root=path_root,
filename_prefix=filename_prefix,
use_threads=use_threads,
index=index,
s3_client=s3_client,
**func_kwargs,
)
_logger.debug("Wrote %s paths", len(paths))
_logger.debug("Created partitions_values: %s", partitions_values)
return paths, partitions_values
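
# Illustrative call, assuming a writer callback with the signature the write proxy
# expects (the real callbacks live in the parquet/text writers; "_write_chunk" here
# is a hypothetical name):
#
#   paths, partitions_values = _to_dataset(
#       func=_write_chunk,
#       concurrent_partitioning=True,
#       df=df,
#       path_root="s3://bucket/table/",
#       filename_prefix="part",
#       index=False,
#       use_threads=True,
#       mode="overwrite_partitions",
#       partition_cols=["year", "month"],
#       bucketing_info=None,
#       boto3_session=None,
#   )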