awswrangler/distributed/ray/datasources/file_datasink.py
"""Ray FileDatasink Module."""
from __future__ import annotations
import logging
import posixpath
from typing import Any, Iterable
import pandas as pd
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block, BlockAccessor
from ray.data.datasource.datasink import Datasink
from ray.data.datasource.filename_provider import FilenameProvider
from ray.types import ObjectRef
from awswrangler.distributed.ray import ray_get
from awswrangler.distributed.ray.datasources.filename_provider import _DefaultFilenameProvider
from awswrangler.s3._fs import open_s3_object
_logger: logging.Logger = logging.getLogger(__name__)
class _BlockFileDatasink(Datasink):
    """Write Ray Dataset blocks to files in S3, one file per write task."""

    def __init__(
        self,
        path: str,
        file_format: str,
        *,
        filename_provider: FilenameProvider | None = None,
        dataset_uuid: str | None = None,
        open_s3_object_args: dict[str, Any] | None = None,
        pandas_kwargs: dict[str, Any] | None = None,
        **write_args: Any,
    ):
        self.path = path
        self.file_format = file_format
        self.dataset_uuid = dataset_uuid
        self.open_s3_object_args = open_s3_object_args or {}
        self.pandas_kwargs = pandas_kwargs or {}
        self.write_args = write_args or {}

        # Fall back to the default filename provider unless the caller supplies one
        if filename_provider is None:
            compression = self.pandas_kwargs.get("compression", None)
            bucket_id = self.write_args.get("bucket_id", None)
            filename_provider = _DefaultFilenameProvider(
                dataset_uuid=dataset_uuid,
                file_format=file_format,
                compression=compression,
                bucket_id=bucket_id,
            )
        self.filename_provider = filename_provider

        self._write_paths: list[str] = []
    def write(
        self,
        blocks: Iterable[Block | ObjectRef[pd.DataFrame]],
        ctx: TaskContext,
    ) -> Any:
        """Coalesce this task's blocks into one block and write it to a single file."""
        _write_block_to_file = self.write_block

        def _write_block(write_path: str, block: pd.DataFrame) -> str:
            with open_s3_object(
                path=write_path,
                **self.open_s3_object_args,
            ) as f:
                _write_block_to_file(f, BlockAccessor.for_block(block))
            return write_path

        builder = DelegatingBlockBuilder()  # type: ignore[no-untyped-call]
        for block in blocks:
            # Dereference the block if an ObjectRef is passed
            builder.add_block(ray_get(block) if isinstance(block, ObjectRef) else block)  # type: ignore[arg-type]
        block = builder.build()

        write_path = self.path
        # If the path is a prefix (ends with "/"), derive a file name for this task
        if write_path.endswith("/"):
            filename = self.filename_provider.get_filename_for_block(block, ctx.task_idx, 0)
            write_path = posixpath.join(self.path, filename)

        return _write_block(write_path, block)
    def write_block(self, file: Any, block: BlockAccessor) -> None:
        """Serialize a block into an open file. Must be implemented by subclasses (see the sketch below)."""
        raise NotImplementedError
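    # A concrete subclass implements `write_block` to serialize the block into the
    # open file object. A minimal sketch (illustrative only; `_CSVDatasink` is a
    # hypothetical name, not part of this module), assuming pandas handles the
    # serialization:
    #
    #     class _CSVDatasink(_BlockFileDatasink):
    #         def __init__(self, path: str, **write_args: Any):
    #             super().__init__(path, file_format="csv", **write_args)
    #
    #         def write_block(self, file: Any, block: BlockAccessor) -> None:
    #             block.to_pandas().to_csv(file, **self.pandas_kwargs)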
    # Note: this callback is invoked once, by the main thread, after
    # [all write tasks complete](https://github.com/ray-project/ray/blob/ray-2.3.0/python/ray/data/dataset.py#L2716)
    # and is meant for singular actions like
    # [committing a transaction](https://docs.ray.io/en/latest/data/api/doc/ray.data.Datasource.html).
    # Despite appearances, there is no race condition here: only the driver
    # mutates `self._write_paths`.
    def on_write_complete(self, write_results: list[Any], **_: Any) -> None:
        """Execute callback after all write tasks complete."""
        _logger.debug("Write complete %s.", write_results)
        # Collect the paths returned by the write tasks
        self._write_paths.extend(write_results)

    def get_write_paths(self) -> list[str]:
        """Return the S3 paths where the results have been written."""
        return self._write_paths
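
# Usage sketch (illustrative only; assumes Ray's `Dataset.write_datasink` API,
# available in recent Ray releases, and the hypothetical `_CSVDatasink` subclass
# sketched above):
#
#     import pandas as pd
#     import ray
#
#     datasink = _CSVDatasink("s3://bucket/prefix/")
#     ray.data.from_pandas(pd.DataFrame({"a": [1, 2]})).write_datasink(datasink)
#     print(datasink.get_write_paths())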