awswrangler/_arrow.py

"""Arrow Utilities Module (PRIVATE).""" from __future__ import annotations import datetime import json import logging from typing import Any, Tuple, cast import pyarrow as pa import pytz import awswrangler.pandas as pd from awswrangler._data_types import athena2pyarrow _logger: logging.Logger = logging.getLogger(__name__) def _extract_partitions_from_path(path_root: str, path: str) -> dict[str, str]: path_root = path_root if path_root.endswith("/") else f"{path_root}/" if path_root not in path: raise Exception(f"Object {path} is not under the root path ({path_root}).") path_wo_filename: str = path.rpartition("/")[0] + "/" path_wo_prefix: str = path_wo_filename.replace(f"{path_root}/", "") dirs: tuple[str, ...] = tuple(x for x in path_wo_prefix.split("/") if x and (x.count("=") > 0)) if not dirs: return {} values_tups = cast(Tuple[Tuple[str, str]], tuple(tuple(x.split("=", maxsplit=1)[:2]) for x in dirs)) values_dics: dict[str, str] = dict(values_tups) return values_dics def _add_table_partitions( table: pa.Table, path: str, path_root: str | None, ) -> pa.Table: part = _extract_partitions_from_path(path_root, path) if path_root else None if part: for col, value in part.items(): part_value = pa.array([value] * len(table)).dictionary_encode() if col not in table.schema.names: table = table.append_column(col, part_value) else: table = table.set_column( table.schema.get_field_index(col), col, part_value, ) return table def ensure_df_is_mutable(df: pd.DataFrame) -> pd.DataFrame: """Ensure that all columns has the writeable flag True.""" for column in df.columns.to_list(): if hasattr(df[column].values, "flags") is True: if df[column].values.flags.writeable is False: s: pd.Series = df[column] df[column] = None df[column] = s return df def _apply_timezone(df: pd.DataFrame, metadata: dict[str, Any]) -> pd.DataFrame: for c in metadata["columns"]: if "field_name" in c and c["field_name"] is not None: col_name = str(c["field_name"]) elif "name" in c and c["name"] is not None: col_name = str(c["name"]) else: continue if col_name in df.columns and c["pandas_type"] == "datetimetz": column_metadata: dict[str, Any] = c["metadata"] if c.get("metadata") else {} timezone_str: str | None = column_metadata.get("timezone") if timezone_str: timezone: datetime.tzinfo = pa.lib.string_to_tzinfo(timezone_str) _logger.debug("applying timezone (%s) on column %s", timezone, col_name) if hasattr(df[col_name].dt, "tz") is False or df[col_name].dt.tz is None: df[col_name] = df[col_name].dt.tz_localize(tz="UTC") if timezone is not None and timezone != pytz.UTC and hasattr(df[col_name].dt, "tz_convert"): df[col_name] = df[col_name].dt.tz_convert(tz=timezone) return df def _table_to_df( table: pa.Table, kwargs: dict[str, Any], ) -> pd.DataFrame: """Convert a PyArrow table to a Pandas DataFrame and apply metadata. This method should be used across to codebase to ensure this conversion is consistent. 
""" metadata: dict[str, Any] = {} if table.schema.metadata is not None and b"pandas" in table.schema.metadata: metadata = json.loads(table.schema.metadata[b"pandas"]) df = table.to_pandas(**kwargs) df = ensure_df_is_mutable(df=df) if metadata: _logger.debug("metadata: %s", metadata) df = _apply_timezone(df=df, metadata=metadata) return df def _df_to_table( df: pd.DataFrame, schema: pa.Schema | None = None, index: bool | None = None, dtype: dict[str, str] | None = None, cpus: int | None = None, ) -> pa.Table: table: pa.Table = pa.Table.from_pandas(df=df, schema=schema, nthreads=cpus, preserve_index=index, safe=True) if dtype: for col_name, col_type in dtype.items(): if col_name in table.column_names: col_index = table.column_names.index(col_name) pyarrow_dtype = athena2pyarrow(col_type, df.dtypes.get(col_name)) field = pa.field(name=col_name, type=pyarrow_dtype) table = table.set_column(col_index, field, table.column(col_name).cast(pyarrow_dtype)) _logger.debug("Casting column %s (%s) to %s (%s)", col_name, col_index, col_type, pyarrow_dtype) return table