awswrangler/s3/_write_dataset.py (231 lines of code) (raw):

"""Amazon S3 Write Dataset (PRIVATE).""" from __future__ import annotations import logging from typing import Any, Callable import boto3 import numpy as np import pandas as pd from awswrangler import exceptions, typing from awswrangler._distributed import engine from awswrangler._utils import client from awswrangler.s3._delete import delete_objects from awswrangler.s3._write_concurrent import _WriteProxy _logger: logging.Logger = logging.getLogger(__name__) def _get_bucketing_series(df: pd.DataFrame, bucketing_info: typing.BucketingInfoTuple) -> pd.Series: bucket_number_series = ( df[bucketing_info[0]] # Prevent "upcasting" mixed types by casting to object .astype("O") .apply( lambda row: _get_bucket_number(bucketing_info[1], [row[col_name] for col_name in bucketing_info[0]]), axis="columns", ) ) return bucket_number_series.astype(np.array([pd.CategoricalDtype(range(bucketing_info[1]))])) def _simulate_overflow(value: int, bits: int = 31, signed: bool = False) -> int: base = 1 << bits value %= base return value - base if signed and value.bit_length() == bits else value def _get_bucket_number(number_of_buckets: int, values: list[str | int | bool]) -> int: hash_code = 0 for value in values: hash_code = 31 * hash_code + _get_value_hash(value) hash_code = _simulate_overflow(hash_code) return hash_code % number_of_buckets def _get_value_hash(value: str | int | bool) -> int: if isinstance(value, (int, np.int_)): value = int(value) bigint_min, bigint_max = -(2**63), 2**63 - 1 int_min, int_max = -(2**31), 2**31 - 1 if not bigint_min <= value <= bigint_max: raise ValueError(f"{value} exceeds the range that Athena cannot handle as bigint.") if not int_min <= value <= int_max: value = (value >> 32) ^ value if value < 0: return -value - 1 return int(value) if isinstance(value, (str, np.str_)): value_hash = 0 for byte in value.encode(): value_hash = value_hash * 31 + byte value_hash = _simulate_overflow(value_hash) return value_hash if isinstance(value, (bool, np.bool_)): return int(value) raise exceptions.InvalidDataFrame( "Column specified for bucketing contains invalid data type. Only string, int and bool are supported." 


def _get_subgroup_prefix(keys: tuple[str, None], partition_cols: list[str], path_root: str) -> str:
    subdir = "/".join([f"{name}={val}" for name, val in zip(partition_cols, keys)])
    return f"{path_root}{subdir}/"


def _delete_objects(
    keys: tuple[str, None],
    path_root: str,
    use_threads: bool | int,
    mode: str,
    partition_cols: list[str],
    boto3_session: boto3.Session | None = None,
    **func_kwargs: Any,
) -> str:
    # Keys are either a primitive type or a tuple if partitioning by multiple cols
    keys = (keys,) if not isinstance(keys, tuple) else keys
    prefix = _get_subgroup_prefix(keys, partition_cols, path_root)
    if mode == "overwrite_partitions":
        delete_objects(
            path=prefix,
            use_threads=use_threads,
            boto3_session=boto3_session,
            s3_additional_kwargs=func_kwargs.get("s3_additional_kwargs"),
        )
    return prefix


@engine.dispatch_on_engine
def _to_partitions(
    df: pd.DataFrame,
    func: Callable[..., list[str]],
    concurrent_partitioning: bool,
    path_root: str,
    use_threads: bool | int,
    mode: str,
    partition_cols: list[str],
    bucketing_info: typing.BucketingInfoTuple | None,
    filename_prefix: str,
    boto3_session: boto3.Session | None,
    **func_kwargs: Any,
) -> tuple[list[str], dict[str, list[str]]]:
    partitions_values: dict[str, list[str]] = {}
    proxy: _WriteProxy = _WriteProxy(use_threads=concurrent_partitioning)
    s3_client = client(service_name="s3", session=boto3_session)

    for keys, subgroup in df.groupby(by=partition_cols, observed=True):
        # Keys are either a primitive type or a tuple if partitioning by multiple cols
        keys = (keys,) if not isinstance(keys, tuple) else keys  # noqa: PLW2901
        # Drop partition columns from df
        subgroup.drop(
            columns=[col for col in partition_cols if col in subgroup.columns],
            inplace=True,
        )
        # Drop index levels if partitioning by index columns
        subgroup.reset_index(
            level=[col for col in partition_cols if col in subgroup.index.names],
            drop=True,
            inplace=True,
        )
        prefix = _delete_objects(
            keys=keys,
            path_root=path_root,
            use_threads=use_threads,
            mode=mode,
            partition_cols=partition_cols,
            boto3_session=boto3_session,
            **func_kwargs,
        )
        if bucketing_info:
            _to_buckets(
                subgroup,
                func=func,
                path_root=prefix,
                bucketing_info=bucketing_info,
                boto3_session=boto3_session,
                use_threads=use_threads,
                proxy=proxy,
                filename_prefix=filename_prefix,
                **func_kwargs,
            )
        else:
            proxy.write(
                func,
                subgroup,
                path_root=prefix,
                filename_prefix=filename_prefix,
                s3_client=s3_client,
                use_threads=use_threads,
                **func_kwargs,
            )
        partitions_values[prefix] = [str(k) for k in keys]
    paths: list[str] = proxy.close()  # blocking
    return paths, partitions_values


@engine.dispatch_on_engine
def _to_buckets(
    df: pd.DataFrame,
    func: Callable[..., list[str]],
    path_root: str,
    bucketing_info: typing.BucketingInfoTuple,
    filename_prefix: str,
    boto3_session: boto3.Session | None,
    use_threads: bool | int,
    proxy: _WriteProxy | None = None,
    **func_kwargs: Any,
) -> list[str]:
    _proxy: _WriteProxy = proxy if proxy else _WriteProxy(use_threads=False)
    s3_client = client(service_name="s3", session=boto3_session)
    for bucket_number, subgroup in df.groupby(by=_get_bucketing_series(df=df, bucketing_info=bucketing_info)):
        _proxy.write(
            func,
            subgroup,
            path_root=path_root,
            filename_prefix=f"{filename_prefix}_bucket-{bucket_number:05d}",
            use_threads=use_threads,
            s3_client=s3_client,
            **func_kwargs,
        )
    if proxy:
        return []

    paths: list[str] = _proxy.close()  # blocking
    return paths
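

# --- Illustrative example (not part of the upstream module) -----------------
# A minimal sketch, assuming only the helpers above, of the Hive-style prefix
# that _to_partitions obtains from _get_subgroup_prefix for one partition key.
# The bucket name and partition columns are hypothetical.
def _example_partition_prefix() -> str:
    # Returns "s3://bucket/dataset/year=2024/month=01/": one "col=value"
    # subdirectory per partition column, in declaration order.
    return _get_subgroup_prefix(("2024", "01"), ["year", "month"], "s3://bucket/dataset/")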


def _to_dataset(
    func: Callable[..., list[str]],
    concurrent_partitioning: bool,
    df: pd.DataFrame,
    path_root: str,
    filename_prefix: str,
    index: bool,
    use_threads: bool | int,
    mode: str,
    partition_cols: list[str] | None,
    bucketing_info: typing.BucketingInfoTuple | None,
    boto3_session: boto3.Session | None,
    **func_kwargs: Any,
) -> tuple[list[str], dict[str, list[str]]]:
    """Validate ``mode``, clean up existing objects if required and dispatch the dataset write."""
    path_root = path_root if path_root.endswith("/") else f"{path_root}/"

    # Evaluate mode
    if mode not in ["append", "overwrite", "overwrite_partitions"]:
        raise exceptions.InvalidArgumentValue(
            f"{mode} is an invalid mode, please use append, overwrite or overwrite_partitions."
        )
    if (mode == "overwrite") or ((mode == "overwrite_partitions") and (not partition_cols)):
        delete_objects(path=path_root, use_threads=use_threads, boto3_session=boto3_session)

    # Writing
    partitions_values: dict[str, list[str]] = {}
    paths: list[str]
    if partition_cols:
        paths, partitions_values = _to_partitions(
            df,
            func=func,
            concurrent_partitioning=concurrent_partitioning,
            path_root=path_root,
            use_threads=use_threads,
            mode=mode,
            bucketing_info=bucketing_info,
            filename_prefix=filename_prefix,
            partition_cols=partition_cols,
            boto3_session=boto3_session,
            index=index,
            **func_kwargs,
        )
    elif bucketing_info:
        paths = _to_buckets(
            df,
            func=func,
            path_root=path_root,
            use_threads=use_threads,
            bucketing_info=bucketing_info,
            filename_prefix=filename_prefix,
            boto3_session=boto3_session,
            index=index,
            **func_kwargs,
        )
    else:
        s3_client = client(service_name="s3", session=boto3_session)
        paths = func(
            df,
            path_root=path_root,
            filename_prefix=filename_prefix,
            use_threads=use_threads,
            index=index,
            s3_client=s3_client,
            **func_kwargs,
        )
    _logger.debug("Wrote %s paths", len(paths))
    _logger.debug("Created partitions_values: %s", partitions_values)
    return paths, partitions_values
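

# --- Illustrative example (not part of the upstream module) -----------------
# A minimal sketch of how a format-specific writer could drive _to_dataset.
# The no-op callback and every literal below (bucket, filename prefix,
# partition column) are hypothetical; real callers pass a func that serializes
# each subgroup to S3 and returns the written paths. Building the boto3 client
# assumes AWS credentials and a region are configured.
def _example_to_dataset(df: pd.DataFrame) -> tuple[list[str], dict[str, list[str]]]:
    def _noop_writer(frame: pd.DataFrame, **kwargs: Any) -> list[str]:
        # A real writer would upload `frame` under kwargs["path_root"] using
        # kwargs["s3_client"] and return the S3 paths it created.
        return []

    # With partition_cols set and mode="append", _to_dataset groups `df` by
    # "year" and invokes the callback once per partition value under
    # s3://bucket/dataset/year=<value>/.
    return _to_dataset(
        func=_noop_writer,
        concurrent_partitioning=False,
        df=df,
        path_root="s3://bucket/dataset",
        filename_prefix="part",
        index=False,
        use_threads=False,
        mode="append",
        partition_cols=["year"],
        bucketing_info=None,
        boto3_session=None,
    )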