def modin_repartition()

in awswrangler/distributed/ray/modin/_core.py [0:0]


def modin_repartition(function: FunctionType) -> FunctionType:
    """
    Decorate callable to repartition Modin data frame.

    By default, repartition along row (axis=0) axis.
    This avoids a situation where columns are split along multiple blocks.

    Parameters
    ----------
    function : Callable[..., Any]
        Callable as input to ray.remote

    Returns
    -------
    Callable[..., Any]
    """
    # Access the source function if it exists
    function = getattr(function, "_source_func", function)

    @wraps(function)
    def wrapper(
        df: pd.DataFrame,
        *args: Any,
        axis: int | None = None,
        row_lengths: int | None = None,
        validate_partitions: bool = True,
        **kwargs: Any,
    ) -> Any:
        # Validate partitions and repartition Modin data frame along row (axis=0) axis
        # to avoid a situation where columns are split along multiple blocks
        if isinstance(df, ModinDataFrame):
            if validate_partitions and not _validate_partition_shape(df):
                _logger.warning(
                    "Partitions of this data frame are detected to be split along column axis. "
                    "The DataFrame will be automatically repartitioned along row axis to ensure "
                    "each partition can be processed independently."
                )
                axis = 0
            if axis is not None:
                df = from_partitions(unwrap_partitions(df, axis=axis), axis=axis, row_lengths=row_lengths)
        return function(df, *args, **kwargs)

    return wrapper  # type: ignore[return-value]