awswrangler/distributed/ray/modin/_utils.py (93 lines of code) (raw):

"""Modin on Ray utilities (PRIVATE)."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, cast

import modin.pandas as modin_pd
import pandas as pd
import pyarrow as pa
import ray
from modin.distributed.dataframe.pandas import from_partitions
from ray.data import Dataset, from_modin, from_pandas
from ray.data.block import BlockAccessor
from ray.types import ObjectRef

from awswrangler import engine, exceptions
from awswrangler._arrow import _table_to_df
from awswrangler._utils import copy_df_shallow
from awswrangler.distributed.ray import ray_get, ray_remote


@ray_remote()
def _block_to_df(
    block: Any,
    to_pandas_kwargs: dict[str, Any],
) -> pd.DataFrame:
    """Convert one Ray Dataset block into a pandas DataFrame (executed as a Ray task).

    Pandas blocks are returned unchanged; any other block is converted to an Arrow
    table through ``BlockAccessor`` and then to pandas with ``_table_to_df``.
    """
    if isinstance(block, pd.DataFrame):
        return block

    # Public accessor API instead of the private ``_table`` attribute.
    table = BlockAccessor.for_block(block).to_arrow()
    return _table_to_df(table=table, kwargs=to_pandas_kwargs)


def _ray_dataset_from_df(df: pd.DataFrame | modin_pd.DataFrame) -> Dataset:
    """Create Ray dataset from supported types of data frames.

    Raises
    ------
    ValueError
        If ``df`` is neither a Modin nor a pandas DataFrame.
    """
    if isinstance(df, modin_pd.DataFrame):
        return from_modin(df)  # type: ignore[no-any-return]
    if isinstance(df, pd.DataFrame):
        return from_pandas(df)  # type: ignore[no-any-return]
    raise ValueError(f"Unknown DataFrame type: {type(df)}")


def _dataset_block_refs(dataset: Dataset) -> list[ObjectRef[Any]]:
    """Flatten the internal ref bundles of ``dataset`` into a list of block refs, in iteration order."""
    return [block_ref for ref_bundle in dataset.iter_internal_ref_bundles() for block_ref in ref_bundle.block_refs]


def _to_modin(
    dataset: Dataset,
    to_pandas_kwargs: dict[str, Any] | None = None,
    ignore_index: bool | None = True,
) -> modin_pd.DataFrame:
    """Build a Modin DataFrame with one row-partition per block of ``dataset``.

    When ``ignore_index`` is truthy, a fresh ``RangeIndex`` spanning ``dataset.count()``
    rows is used; otherwise ``index=None`` is passed to ``from_partitions``.
    ``to_pandas_kwargs`` (default: empty) is forwarded to each block's conversion.
    """
    # Computed before the partitions, matching the original evaluation order.
    index = modin_pd.RangeIndex(start=0, stop=dataset.count()) if ignore_index else None
    _to_pandas_kwargs = {} if to_pandas_kwargs is None else to_pandas_kwargs
    return from_partitions(
        partitions=[
            _block_to_df(block=block_ref, to_pandas_kwargs=_to_pandas_kwargs)
            for block_ref in _dataset_block_refs(dataset)
        ],
        axis=0,
        index=index,
    )


def _split_modin_frame(df: modin_pd.DataFrame, splits: int) -> list[ObjectRef[Any]]:
    """Return object refs to the blocks of ``df`` once converted to a Ray Dataset.

    NOTE(review): ``splits`` is accepted but not used. The number of refs returned is
    whatever block count the Ray Dataset conversion produces. The parameter is kept
    for caller compatibility.
    """
    return _dataset_block_refs(_ray_dataset_from_df(df))


def _has_rows_or_columns(table: pa.Table) -> bool:
    """Return True unless ``table`` has zero rows AND zero columns."""
    return table.num_rows > 0 or table.num_columns > 0


@ray_remote()
def _is_not_empty(table: pa.Table) -> Any:
    """Remote variant of ``_has_rows_or_columns``, evaluated where the table ref lives."""
    return _has_rows_or_columns(table)


def _empty_dataset() -> Dataset:
    """Build a Ray Dataset holding a single Arrow table with no rows and no columns."""
    return ray.data.from_arrow([pa.Table.from_arrays([])])


def _arrow_refs_to_df(
    arrow_refs: list[Callable[..., Any] | pa.Table], kwargs: dict[str, Any] | None
) -> modin_pd.DataFrame:
    """Convert Arrow tables (or object refs to them) into a Modin DataFrame.

    Tables with neither rows nor columns are dropped. If nothing remains (including
    when ``arrow_refs`` is empty), an empty frame is produced. The first element
    decides whether the list is treated as in-memory tables or as refs.
    ``kwargs`` is forwarded as the pandas conversion kwargs.
    """
    # An empty input is handled on the in-memory path; it previously raised IndexError.
    if not arrow_refs or isinstance(arrow_refs[0], pa.Table):
        tables = [table for table in cast(list[pa.Table], arrow_refs) if _has_rows_or_columns(table)]
        dataset = ray.data.from_arrow(tables) if tables else _empty_dataset()
        return _to_modin(dataset=dataset, to_pandas_kwargs=kwargs)

    # Refs: check emptiness remotely so the tables are not pulled to the driver.
    keep_flags: list[bool] = ray_get([_is_not_empty(arrow_ref) for arrow_ref in arrow_refs])
    refs: list[Callable[..., Any]] = [ref for ref, keep in zip(arrow_refs, keep_flags) if keep]
    dataset = ray.data.from_arrow_refs(refs) if refs else _empty_dataset()
    return _to_modin(dataset=dataset, to_pandas_kwargs=kwargs)


def _is_pandas_or_modin_frame(obj: Any) -> bool:
    """Return True if ``obj`` is a pandas or a Modin DataFrame."""
    return isinstance(obj, (pd.DataFrame, modin_pd.DataFrame))


def _copy_modin_df_shallow(frame: pd.DataFrame | modin_pd.DataFrame) -> pd.DataFrame | modin_pd.DataFrame:
    """Wrap ``frame`` in a Modin DataFrame without requesting a copy (``copy=False``).

    A pandas input is first passed through the ``"python"``-engine dispatch of
    ``copy_df_shallow``, and that result is what gets wrapped.
    """
    if isinstance(frame, pd.DataFrame):
        # The result used to be discarded, making this call a no-op.
        frame = engine.dispatch_func(copy_df_shallow, "python")(frame)
    return modin_pd.DataFrame(frame, copy=False)


@dataclass
class ParamConfig:
    """
    Configuration for a Pandas argument that is supported in PyArrow.

    Contains a default value and, optionally, a set of supported values.
    """

    # Default for the argument; how it is consumed is not visible in this module.
    default: Any
    # Allowed values; ``None`` means any value is accepted by ``_check_parameters``.
    supported_values: set[Any] | None = None


def _check_parameters(pandas_kwargs: dict[str, Any], supported_params: dict[str, ParamConfig]) -> None:
    """Validate pandas keyword arguments against the parameters the PyArrow loader supports.

    Raises
    ------
    exceptions.InvalidArgument
        If a key is missing from ``supported_params``, or its value is not in that
        parameter's ``supported_values`` (when that set is not ``None``).
    """
    for pandas_arg_key, pandas_args_value in pandas_kwargs.items():
        if pandas_arg_key not in supported_params:
            raise exceptions.InvalidArgument(f"Unsupported Pandas parameter for PyArrow loader: {pandas_arg_key}")

        param_config = supported_params[pandas_arg_key]
        if param_config.supported_values is None:
            continue

        if pandas_args_value not in param_config.supported_values:
            raise exceptions.InvalidArgument(
                f"Unsupported Pandas parameter value for PyArrow loader: {pandas_arg_key}={pandas_args_value}",
            )