awswrangler/distributed/ray/modin/_utils.py (93 lines of code) (raw):
"""Modin on Ray utilities (PRIVATE)."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, cast
import modin.pandas as modin_pd
import pandas as pd
import pyarrow as pa
import ray
from modin.distributed.dataframe.pandas import from_partitions
from ray.data import Dataset, from_modin, from_pandas
from ray.data.block import BlockAccessor
from ray.types import ObjectRef
from awswrangler import engine, exceptions
from awswrangler._arrow import _table_to_df
from awswrangler._utils import copy_df_shallow
from awswrangler.distributed.ray import ray_get, ray_remote
@ray_remote()
def _block_to_df(
    block: Any,
    to_pandas_kwargs: dict[str, Any],
) -> pd.DataFrame:
    """Convert a single Ray dataset block into a pandas DataFrame (remote task)."""
    # A block may already be a pandas frame; pass it through untouched.
    if isinstance(block, pd.DataFrame):
        return block
    accessor = BlockAccessor.for_block(block)
    # NOTE(review): reaches into the accessor's private ``_table`` attribute —
    # assumed to be the underlying Arrow table; verify against the Ray version in use.
    return _table_to_df(table=accessor._table, kwargs=to_pandas_kwargs)
def _ray_dataset_from_df(df: pd.DataFrame | modin_pd.DataFrame) -> Dataset:
    """Create a Ray dataset from a supported data frame type.

    Raises
    ------
    ValueError
        If ``df`` is neither a Modin nor a pandas DataFrame.
    """
    # Modin is checked first, mirroring the original dispatch order.
    converters = (
        (modin_pd.DataFrame, from_modin),
        (pd.DataFrame, from_pandas),
    )
    for frame_type, converter in converters:
        if isinstance(df, frame_type):
            return converter(df)  # type: ignore[no-any-return]
    raise ValueError(f"Unknown DataFrame type: {type(df)}")
def _to_modin(
    dataset: Dataset,
    to_pandas_kwargs: dict[str, Any] | None = None,
    ignore_index: bool | None = True,
) -> modin_pd.DataFrame:
    """Assemble a Modin DataFrame from a Ray dataset.

    Each block of the dataset is converted to a pandas frame remotely via
    ``_block_to_df`` and the resulting partitions are stitched together
    row-wise. When ``ignore_index`` is truthy the index is replaced by a
    fresh RangeIndex covering all rows.
    """
    index = None
    if ignore_index:
        # NOTE(review): dataset.count() presumably triggers execution to
        # obtain the row count — confirm this is acceptable for callers.
        index = modin_pd.RangeIndex(start=0, stop=dataset.count())
    pandas_kwargs: dict[str, Any] = to_pandas_kwargs if to_pandas_kwargs is not None else {}

    partitions = []
    for ref_bundle in dataset.iter_internal_ref_bundles():
        for block_ref in ref_bundle.block_refs:
            partitions.append(_block_to_df(block=block_ref, to_pandas_kwargs=pandas_kwargs))

    return from_partitions(partitions=partitions, axis=0, index=index)
def _split_modin_frame(df: modin_pd.DataFrame, splits: int) -> list[ObjectRef[Any]]:
    """Flatten a Modin frame into the object references of its underlying Ray blocks.

    NOTE(review): ``splits`` is accepted but never used — the resulting number
    of partitions is whatever ``_ray_dataset_from_df`` produces. Kept for
    interface compatibility; confirm with callers whether it should drive a
    repartition.
    """
    block_refs: list[ObjectRef[Any]] = []
    for ref_bundle in _ray_dataset_from_df(df).iter_internal_ref_bundles():
        block_refs.extend(ref_bundle.block_refs)
    return block_refs
def _arrow_refs_to_df(
    arrow_refs: list[Callable[..., Any] | pa.Table], kwargs: dict[str, Any] | None
) -> modin_pd.DataFrame:
    """Build a Modin DataFrame from Arrow tables or Ray object references to Arrow tables.

    Empty tables are filtered out; when nothing remains (including when
    ``arrow_refs`` itself is empty), a single empty Arrow table is used so an
    empty frame is still returned.

    Parameters
    ----------
    arrow_refs
        Either in-process ``pa.Table`` objects or Ray object references to
        tables. The list is assumed homogeneous: the type of the first element
        selects the code path.
    kwargs
        Forwarded to the Arrow-to-pandas conversion of each block.
    """

    @ray_remote()
    def _is_not_empty(table: pa.Table) -> Any:
        return table.num_rows > 0 or table.num_columns > 0

    def _empty_dataset() -> Dataset:
        # Lazily built so no dataset is created unless actually needed.
        return ray.data.from_arrow([pa.Table.from_arrays([])])

    # BUG FIX: guard the empty list — ``arrow_refs[0]`` below raised IndexError.
    if not arrow_refs:
        return _to_modin(dataset=_empty_dataset(), to_pandas_kwargs=kwargs)

    if isinstance(arrow_refs[0], pa.Table):
        tables = cast(list[pa.Table], arrow_refs)
        tables = [table for table in tables if table.num_rows > 0 or table.num_columns > 0]
        return _to_modin(
            dataset=ray.data.from_arrow(tables) if len(tables) > 0 else _empty_dataset(),
            to_pandas_kwargs=kwargs,
        )

    # Evaluate emptiness remotely so tables are not pulled to the driver.
    # (Renamed the loop variable that previously shadowed the flags list.)
    non_empty_flags: list[bool] = ray_get([_is_not_empty(arrow_ref) for arrow_ref in arrow_refs])
    refs: list[Callable[..., Any]] = [ref for keep, ref in zip(non_empty_flags, arrow_refs) if keep]
    return _to_modin(
        dataset=ray.data.from_arrow_refs(refs) if len(refs) > 0 else _empty_dataset(),
        to_pandas_kwargs=kwargs,
    )
def _is_pandas_or_modin_frame(obj: Any) -> bool:
    """Return True when ``obj`` is a pandas or a Modin DataFrame."""
    frame_types = (pd.DataFrame, modin_pd.DataFrame)
    return isinstance(obj, frame_types)
def _copy_modin_df_shallow(frame: pd.DataFrame | modin_pd.DataFrame) -> pd.DataFrame | modin_pd.DataFrame:
    """Create a shallow copy of the given frame.

    A pandas DataFrame is copied via the Python-engine ``copy_df_shallow``
    implementation; a Modin frame is wrapped without copying its data.
    """
    if isinstance(frame, pd.DataFrame):
        # BUG FIX: the dispatched copy's result was previously discarded (no
        # ``return``), so a pandas frame fell through and was wrapped in a
        # Modin frame instead of being shallow-copied — which also made the
        # ``pd.DataFrame`` part of the return annotation unreachable.
        return engine.dispatch_func(copy_df_shallow, "python")(frame)
    return modin_pd.DataFrame(frame, copy=False)
@dataclass
class ParamConfig:
    """
    Configuration for a Pandas argument that is supported in PyArrow.

    Contains a default value and, optionally, a set of supported values.
    """

    # Value used when the caller does not supply the parameter.
    default: Any
    # When not None, the only accepted values for this parameter.
    supported_values: set[Any] | None = None
def _check_parameters(pandas_kwargs: dict[str, Any], supported_params: dict[str, ParamConfig]) -> None:
for pandas_arg_key, pandas_args_value in pandas_kwargs.items():
if pandas_arg_key not in supported_params:
raise exceptions.InvalidArgument(f"Unsupported Pandas parameter for PyArrow loader: {pandas_arg_key}")
param_config = supported_params[pandas_arg_key]
if param_config.supported_values is None:
continue
if pandas_args_value not in param_config.supported_values:
raise exceptions.InvalidArgument(
f"Unsupported Pandas parameter value for PyArrow loader: {pandas_arg_key}={pandas_args_value}",
)