awswrangler/_arrow.py
"""Arrow Utilities Module (PRIVATE)."""
from __future__ import annotations
import datetime
import json
import logging
from typing import Any, Tuple, cast
import pyarrow as pa
import pytz
import awswrangler.pandas as pd
from awswrangler._data_types import athena2pyarrow
_logger: logging.Logger = logging.getLogger(__name__)
def _extract_partitions_from_path(path_root: str, path: str) -> dict[str, str]:
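    """Extract Hive-style partition key/value pairs from an object path.

    Illustrative sketch (the bucket and keys below are hypothetical):

    >>> _extract_partitions_from_path(
    ...     path_root="s3://bucket/dataset/",
    ...     path="s3://bucket/dataset/year=2024/month=05/file.parquet",
    ... )
    {'year': '2024', 'month': '05'}
    """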
    # Normalize the root so it always ends with exactly one trailing slash.
    path_root = path_root if path_root.endswith("/") else f"{path_root}/"
    if path_root not in path:
        raise ValueError(f"Object {path} is not under the root path ({path_root}).")
    path_wo_filename: str = path.rpartition("/")[0] + "/"
    path_wo_prefix: str = path_wo_filename.replace(path_root, "")
    # Keep only the "key=value" directory components.
    dirs: tuple[str, ...] = tuple(x for x in path_wo_prefix.split("/") if x and ("=" in x))
    if not dirs:
        return {}
    values_tups = cast(Tuple[Tuple[str, str], ...], tuple(tuple(x.split("=", maxsplit=1)[:2]) for x in dirs))
    values_dics: dict[str, str] = dict(values_tups)
    return values_dics
def _add_table_partitions(
table: pa.Table,
path: str,
path_root: str | None,
) -> pa.Table:
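    """Add partition columns, parsed from ``path``, to the table.

    Partition values are dictionary-encoded since each one is a single
    value repeated for every row. Illustrative sketch (hypothetical paths):

    >>> table = pa.table({"c0": [1, 2]})
    >>> table = _add_table_partitions(
    ...     table,
    ...     path="s3://bucket/dataset/year=2024/part-0.parquet",
    ...     path_root="s3://bucket/dataset/",
    ... )
    >>> table.column_names
    ['c0', 'year']
    """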
part = _extract_partitions_from_path(path_root, path) if path_root else None
if part:
for col, value in part.items():
part_value = pa.array([value] * len(table)).dictionary_encode()
if col not in table.schema.names:
table = table.append_column(col, part_value)
else:
table = table.set_column(
table.schema.get_field_index(col),
col,
part_value,
)
return table
def ensure_df_is_mutable(df: pd.DataFrame) -> pd.DataFrame:
    """Ensure that all columns have their writeable flag set to True."""
    for column in df.columns.to_list():
        if hasattr(df[column].values, "flags") and not df[column].values.flags.writeable:
            # Reassigning the column forces pandas to allocate a fresh,
            # writeable block instead of the read-only Arrow-backed one.
            s: pd.Series = df[column]
            df[column] = None
            df[column] = s
    return df
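# A minimal sketch of when mutability matters (assuming a zero-copy
# Arrow -> pandas conversion, which backs columns with immutable Arrow memory):
#
#     df = table.to_pandas(split_blocks=True, self_destruct=True)
#     df["c0"] += 1                 # may raise: assignment destination is read-only
#     df = ensure_df_is_mutable(df)
#     df["c0"] += 1                 # now safe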
def _apply_timezone(df: pd.DataFrame, metadata: dict[str, Any]) -> pd.DataFrame:
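    """Localize and convert datetime columns using the ``pandas`` schema metadata.

    Naive columns are localized to UTC first (the Parquet storage convention)
    and then converted to the recorded timezone. A sketch of the metadata
    entry this function consumes (as written by pyarrow):

    >>> metadata = {
    ...     "columns": [
    ...         {
    ...             "name": "ts",
    ...             "field_name": "ts",
    ...             "pandas_type": "datetimetz",
    ...             "metadata": {"timezone": "America/New_York"},
    ...         }
    ...     ]
    ... }
    """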
for c in metadata["columns"]:
if "field_name" in c and c["field_name"] is not None:
col_name = str(c["field_name"])
elif "name" in c and c["name"] is not None:
col_name = str(c["name"])
else:
continue
if col_name in df.columns and c["pandas_type"] == "datetimetz":
column_metadata: dict[str, Any] = c["metadata"] if c.get("metadata") else {}
timezone_str: str | None = column_metadata.get("timezone")
if timezone_str:
timezone: datetime.tzinfo = pa.lib.string_to_tzinfo(timezone_str)
_logger.debug("applying timezone (%s) on column %s", timezone, col_name)
                # Localize naive timestamps to UTC before converting.
                if not hasattr(df[col_name].dt, "tz") or df[col_name].dt.tz is None:
df[col_name] = df[col_name].dt.tz_localize(tz="UTC")
if timezone is not None and timezone != pytz.UTC and hasattr(df[col_name].dt, "tz_convert"):
df[col_name] = df[col_name].dt.tz_convert(tz=timezone)
return df
def _table_to_df(
table: pa.Table,
kwargs: dict[str, Any],
) -> pd.DataFrame:
"""Convert a PyArrow table to a Pandas DataFrame and apply metadata.
This method should be used across to codebase to ensure this conversion is consistent.
"""
metadata: dict[str, Any] = {}
if table.schema.metadata is not None and b"pandas" in table.schema.metadata:
metadata = json.loads(table.schema.metadata[b"pandas"])
df = table.to_pandas(**kwargs)
df = ensure_df_is_mutable(df=df)
if metadata:
_logger.debug("metadata: %s", metadata)
df = _apply_timezone(df=df, metadata=metadata)
return df
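# Typical call (a sketch; ``kwargs`` is forwarded verbatim to pyarrow's
# Table.to_pandas, e.g. {"use_threads": True, "ignore_metadata": False}):
#
#     df = _table_to_df(table, kwargs={"use_threads": True})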
def _df_to_table(
df: pd.DataFrame,
schema: pa.Schema | None = None,
index: bool | None = None,
dtype: dict[str, str] | None = None,
cpus: int | None = None,
) -> pa.Table:
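    """Convert a Pandas DataFrame into a PyArrow Table.

    ``dtype`` maps column names to Athena type strings (e.g. ``{"c0": "bigint"}``);
    matching columns are cast to the corresponding PyArrow type after conversion.
    """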
table: pa.Table = pa.Table.from_pandas(df=df, schema=schema, nthreads=cpus, preserve_index=index, safe=True)
if dtype:
for col_name, col_type in dtype.items():
if col_name in table.column_names:
col_index = table.column_names.index(col_name)
pyarrow_dtype = athena2pyarrow(col_type, df.dtypes.get(col_name))
field = pa.field(name=col_name, type=pyarrow_dtype)
table = table.set_column(col_index, field, table.column(col_name).cast(pyarrow_dtype))
_logger.debug("Casting column %s (%s) to %s (%s)", col_name, col_index, col_type, pyarrow_dtype)
return table
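# End-to-end sketch (hypothetical paths; assumes ``import pyarrow.parquet as pq``):
# read one object of a partitioned dataset and recover its partition column.
#
#     path = "s3://bucket/dataset/year=2024/part-0.parquet"
#     table = _add_table_partitions(
#         table=pq.read_table(path),
#         path=path,
#         path_root="s3://bucket/dataset/",
#     )
#     df = _table_to_df(table, kwargs={})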