# awswrangler/s3/_read.py
"""Amazon S3 Read Module (PRIVATE)."""
from __future__ import annotations
import itertools
import logging
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
NamedTuple,
Tuple,
cast,
)
import boto3
import numpy as np
import pandas as pd
import pyarrow as pa
from pandas.api.types import union_categoricals
from awswrangler import _data_types, _utils, exceptions
from awswrangler._arrow import _extract_partitions_from_path
from awswrangler._executor import _BaseExecutor, _get_executor
from awswrangler.catalog._get import _get_partitions
from awswrangler.catalog._utils import _catalog_id
from awswrangler.distributed.ray import ray_get
from awswrangler.s3._list import _path2list, _prefix_cleanup
from awswrangler.typing import RaySettings
if TYPE_CHECKING:
from mypy_boto3_glue.type_defs import GetTableResponseTypeDef
from mypy_boto3_s3 import S3Client
_logger: logging.Logger = logging.getLogger(__name__)
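
# Resolve the dataset root prefix: dataset=True requires a single string path (an Amazon S3 prefix).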
def _get_path_root(path: str | list[str], dataset: bool) -> str | None:
if (dataset is True) and (not isinstance(path, str)):
raise exceptions.InvalidArgument("The path argument must be a string if dataset=True (Amazon S3 prefix).")
return _prefix_cleanup(str(path)) if dataset is True else None
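
# Normalize the ignore-suffix argument to a list, always skipping "/_SUCCESS" marker files.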
def _get_path_ignore_suffix(path_ignore_suffix: str | list[str] | None) -> list[str] | None:
if isinstance(path_ignore_suffix, str):
path_ignore_suffix = [path_ignore_suffix, "/_SUCCESS"]
elif path_ignore_suffix is None:
path_ignore_suffix = ["/_SUCCESS"]
else:
path_ignore_suffix = path_ignore_suffix + ["/_SUCCESS"]
return path_ignore_suffix
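
# Derive Hive-style partition types/values from the object keys. For an illustrative
# root "s3://bucket/root/" and key "s3://bucket/root/year=2024/month=1/file0.parquet"
# (names made up), this yields {"year": "string", "month": "string"} and
# {"s3://bucket/root/year=2024/month=1/": ["2024", "1"]}.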
def _extract_partitions_metadata_from_paths(
path: str, paths: list[str]
) -> tuple[dict[str, str] | None, dict[str, list[str]] | None]:
"""Extract partitions metadata from Amazon S3 paths."""
path = path if path.endswith("/") else f"{path}/"
partitions_types: dict[str, str] = {}
partitions_values: dict[str, list[str]] = {}
for p in paths:
if path not in p:
raise exceptions.InvalidArgumentValue(f"Object {p} is not under the root path ({path}).")
path_wo_filename: str = p.rpartition("/")[0] + "/"
if path_wo_filename not in partitions_values:
path_wo_prefix: str = path_wo_filename.replace(f"{path}", "")
dirs: tuple[str, ...] = tuple(x for x in path_wo_prefix.split("/") if x and (x.count("=") > 0))
if dirs:
values_tups = cast(Tuple[Tuple[str, str]], tuple(tuple(x.split("=", maxsplit=1)[:2]) for x in dirs))
values_dics: dict[str, str] = dict(values_tups)
p_values: list[str] = list(values_dics.values())
p_types: dict[str, str] = {x: "string" for x in values_dics.keys()}
if not partitions_types:
partitions_types = p_types
if p_values:
partitions_types = p_types
partitions_values[path_wo_filename] = p_values
elif p_types != partitions_types:
raise exceptions.InvalidSchemaConvergence(
f"At least two different partitions schema detected: {partitions_types} and {p_types}"
)
if not partitions_types:
return None, None
return partitions_types, partitions_values
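
# Keep only the paths whose partition values pass the user-supplied predicate
# (the predicate receives a dict of partition name -> value).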
def _apply_partition_filter(
path_root: str, paths: list[str], filter_func: Callable[[dict[str, str]], bool] | None
) -> list[str]:
if filter_func is None:
return paths
return [p for p in paths if filter_func(_extract_partitions_from_path(path_root=path_root, path=p)) is True]
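
# Add the partition columns encoded in the object key to the DataFrame as single-value categoricals.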
def _apply_partitions(df: pd.DataFrame, dataset: bool, path: str, path_root: str | None) -> pd.DataFrame:
if dataset is False:
return df
if dataset is True and path_root is None:
raise exceptions.InvalidArgument("A path_root is required when dataset=True.")
partitions: dict[str, str] = _extract_partitions_from_path(path_root=path_root, path=path)
_logger.debug("partitions: %s", partitions)
count: int = len(df.index)
_logger.debug("count: %s", count)
for name, value in partitions.items():
df[name] = pd.Categorical.from_codes(np.repeat([0], count), categories=[value])
return df
def _extract_partitions_dtypes_from_table_details(response: "GetTableResponseTypeDef") -> dict[str, str]:
dtypes: dict[str, str] = {}
for par in response["Table"].get("PartitionKeys", []):
dtypes[par["Name"]] = par["Type"]
return dtypes
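
# pd.concat turns categorical columns into object dtype when categories differ between
# frames, so align the categories with union_categoricals before concatenating.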
def _concat_union_categoricals(dfs: list[pd.DataFrame], ignore_index: bool) -> pd.DataFrame:
"""Concatenate dataframes with union of categorical columns."""
cats: tuple[set[str], ...] = tuple(set(df.select_dtypes(include="category").columns) for df in dfs)
for col in set.intersection(*cats):
cat = union_categoricals([df[col] for df in dfs])
for df in dfs:
df[col] = pd.Categorical(df[col].values, categories=cat.categories)
return pd.concat(objs=dfs, sort=False, copy=False, ignore_index=ignore_index)
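
# Normalize version_id into a {path: version_id} mapping (or None); a single version
# string is only valid when exactly one path is given.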
def _check_version_id(paths: list[str], version_id: str | dict[str, str] | None = None) -> dict[str, str] | None:
if len(paths) > 1 and version_id is not None and not isinstance(version_id, dict):
raise exceptions.InvalidArgumentCombination(
"If multiple paths are provided along with a file version ID, the version ID parameter must be a dict."
)
if isinstance(version_id, dict) and not all(version_id.values()):
raise exceptions.InvalidArgumentValue("Values in version ID dict cannot be None.")
return (
version_id if isinstance(version_id, dict) else {paths[0]: version_id} if isinstance(version_id, str) else None
)
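
# Container for the Athena types resolved from the files (columns) and from the S3 keys (partitions).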
class _InternalReadTableMetadataReturnValue(NamedTuple):
columns_types: dict[str, str]
partitions_types: dict[str, str] | None
partitions_values: dict[str, list[str]] | None
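
# Base class for format-specific metadata readers; subclasses implement
# _read_metadata_file to fetch a single file's pyarrow schema.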
class _TableMetadataReader(ABC):
@abstractmethod
def _read_metadata_file(
self,
s3_client: "S3Client" | None,
path: str,
s3_additional_kwargs: dict[str, str] | None,
use_threads: bool | int,
version_id: str | None = None,
coerce_int96_timestamp_unit: str | None = None,
) -> pa.schema:
pass
def _read_schemas_from_files(
self,
paths: list[str],
sampling: float,
use_threads: bool | int,
s3_client: "S3Client",
s3_additional_kwargs: dict[str, str] | None,
version_ids: dict[str, str] | None,
coerce_int96_timestamp_unit: str | None = None,
) -> list[pa.schema]:
paths = _utils.list_sampling(lst=paths, sampling=sampling)
executor: _BaseExecutor = _get_executor(use_threads=use_threads)
schemas = ray_get(
executor.map(
self._read_metadata_file,
s3_client,
paths,
itertools.repeat(s3_additional_kwargs),
itertools.repeat(use_threads),
[version_ids.get(p) if isinstance(version_ids, dict) else None for p in paths],
itertools.repeat(coerce_int96_timestamp_unit),
)
)
return [schema for schema in schemas if schema is not None]
def _validate_schemas_from_files(
self,
validate_schema: bool,
paths: list[str],
sampling: float,
use_threads: bool | int,
s3_client: "S3Client",
s3_additional_kwargs: dict[str, str] | None,
version_ids: dict[str, str] | None,
coerce_int96_timestamp_unit: str | None = None,
) -> pa.schema:
schemas: list[pa.schema] = self._read_schemas_from_files(
paths=paths,
sampling=sampling,
use_threads=use_threads,
s3_client=s3_client,
s3_additional_kwargs=s3_additional_kwargs,
version_ids=version_ids,
coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
)
return _validate_schemas(schemas, validate_schema)
def validate_schemas(
self,
paths: list[str],
path_root: str | None,
columns: list[str] | None,
validate_schema: bool,
s3_client: "S3Client",
version_ids: dict[str, str] | None = None,
use_threads: bool | int = True,
coerce_int96_timestamp_unit: str | None = None,
s3_additional_kwargs: dict[str, Any] | None = None,
) -> pa.schema:
schema = self._validate_schemas_from_files(
validate_schema=validate_schema,
paths=paths,
sampling=1.0,
use_threads=use_threads,
s3_client=s3_client,
s3_additional_kwargs=s3_additional_kwargs,
coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
version_ids=version_ids,
)
if path_root:
partition_types, _ = _extract_partitions_metadata_from_paths(path=path_root, paths=paths)
if partition_types:
partition_schema = pa.schema(
fields={k: _data_types.athena2pyarrow(dtype=v) for k, v in partition_types.items()}
)
schema = pa.unify_schemas([schema, partition_schema])
if columns:
schema = pa.schema([schema.field(column) for column in columns], schema.metadata)
_logger.debug("Resolved pyarrow schema:\n%s", schema)
return schema
def read_table_metadata(
self,
path: str | list[str],
path_suffix: str | None,
path_ignore_suffix: str | list[str] | None,
ignore_empty: bool,
ignore_null: bool,
dtype: dict[str, str] | None,
sampling: float,
dataset: bool,
use_threads: bool | int,
boto3_session: boto3.Session | None,
s3_additional_kwargs: dict[str, str] | None,
version_id: str | dict[str, str] | None = None,
coerce_int96_timestamp_unit: str | None = None,
) -> _InternalReadTableMetadataReturnValue:
"""Handle table metadata internally."""
s3_client = _utils.client(service_name="s3", session=boto3_session)
path_root: str | None = _get_path_root(path=path, dataset=dataset)
paths: list[str] = _path2list(
path=path,
s3_client=s3_client,
suffix=path_suffix,
ignore_suffix=_get_path_ignore_suffix(path_ignore_suffix=path_ignore_suffix),
ignore_empty=ignore_empty,
s3_additional_kwargs=s3_additional_kwargs,
)
if len(paths) < 1:
raise exceptions.NoFilesFound(f"No files Found: {path}.")
version_ids = _check_version_id(paths=paths, version_id=version_id)
# Files
schemas: list[pa.schema] = self._read_schemas_from_files(
paths=paths,
sampling=sampling,
use_threads=use_threads,
s3_client=s3_client,
s3_additional_kwargs=s3_additional_kwargs,
version_ids=version_ids,
coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
)
merged_schemas = _validate_schemas(schemas=schemas, validate_schema=False)
columns_types: dict[str, str] = _data_types.athena_types_from_pyarrow_schema(
schema=merged_schemas, ignore_null=ignore_null
)
# Partitions
partitions_types: dict[str, str] | None = None
partitions_values: dict[str, list[str]] | None = None
if (dataset is True) and (path_root is not None):
partitions_types, partitions_values = _extract_partitions_metadata_from_paths(path=path_root, paths=paths)
# Casting
if dtype:
for k, v in dtype.items():
if columns_types and k in columns_types:
columns_types[k] = v
if partitions_types and k in partitions_types:
partitions_types[k] = v
return _InternalReadTableMetadataReturnValue(columns_types, partitions_types, partitions_values)
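
# Merge the per-file schemas; with validate_schema=True any schema that diverges
# from the first one raises InvalidSchemaConvergence.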
def _validate_schemas(schemas: list[pa.schema], validate_schema: bool) -> pa.schema:
first: pa.schema = schemas[0]
if len(schemas) == 1:
return first
first_dict = {s.name: s.type for s in first}
if validate_schema:
for schema in schemas[1:]:
if first_dict != {s.name: s.type for s in schema}:
raise exceptions.InvalidSchemaConvergence(
f"At least 2 different schemas were detected:\n 1 - {first}\n 2 - {schema}."
)
return pa.unify_schemas(schemas)
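
# Glue partition locations may lack a trailing slash; append one when the last key
# component looks like a Hive partition (contains a single "=").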
def _ensure_locations_are_valid(paths: Iterable[str]) -> Iterator[str]:
for path in paths:
suffix: str = path.rpartition("/")[2]
# If the suffix looks like a partition,
if suffix and (suffix.count("=") == 1):
# the path should end in a '/' character.
path = f"{path}/" # noqa: PLW2901
yield path
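
# Resolve the S3 paths backing a Glue table: the StorageDescriptor location by default,
# or the individually listed partition locations when a partition_filter is given.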
def _get_paths_for_glue_table(
table: str,
database: str,
filename_suffix: str | list[str] | None = None,
filename_ignore_suffix: str | list[str] | None = None,
catalog_id: str | None = None,
partition_filter: Callable[[dict[str, str]], bool] | None = None,
boto3_session: boto3.Session | None = None,
s3_additional_kwargs: dict[str, Any] | None = None,
) -> tuple[str | list[str], str | None, "GetTableResponseTypeDef"]:
client_glue = _utils.client(service_name="glue", session=boto3_session)
s3_client = _utils.client(service_name="s3", session=boto3_session)
res = client_glue.get_table(**_catalog_id(catalog_id=catalog_id, DatabaseName=database, Name=table))
try:
location: str = res["Table"]["StorageDescriptor"]["Location"]
path: str = location if location.endswith("/") else f"{location}/"
except KeyError as ex:
raise exceptions.InvalidTable(f"Missing s3 location for {database}.{table}.") from ex
path_root: str | None = None
paths: str | list[str] = path
# If filter is available, fetch & filter out partitions
# Then list objects & process individual object keys under path_root
if partition_filter:
available_partitions_dict = _get_partitions(
database=database,
table=table,
catalog_id=catalog_id,
boto3_session=boto3_session,
)
available_partitions = list(_ensure_locations_are_valid(available_partitions_dict.keys()))
if available_partitions:
paths = []
path_root = path
partitions: str | list[str] = _apply_partition_filter(
path_root=path_root, paths=available_partitions, filter_func=partition_filter
)
for partition in partitions:
paths += _path2list(
path=partition,
s3_client=s3_client,
suffix=filename_suffix,
ignore_suffix=_get_path_ignore_suffix(path_ignore_suffix=filename_ignore_suffix),
s3_additional_kwargs=s3_additional_kwargs,
)
return paths, path_root, res
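
# Translate Ray settings into a block count: an explicit ``parallelism`` is still honored
# (with a deprecation warning); otherwise fall back to ``override_num_blocks``, defaulting to -1.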
def _get_num_output_blocks(
ray_args: RaySettings | None = None,
) -> int:
ray_args = ray_args or {}
parallelism = ray_args.get("parallelism", -1)
override_num_blocks = ray_args.get("override_num_blocks")
if parallelism != -1:
_logger.warning(
"The argument ``parallelism`` is deprecated and will be removed in the next major release. "
"Please specify ``override_num_blocks`` instead."
)
elif override_num_blocks is not None:
parallelism = override_num_blocks
return parallelism