awswrangler/distributed/ray/datasources/arrow_parquet_datasink.py (66 lines of code) (raw):

"""Ray ArrowParquetDatasink Module."""

from __future__ import annotations

import logging
from typing import Any

import pyarrow as pa

# ``import pyarrow`` alone does not guarantee that the ``parquet`` submodule is
# loaded; import it explicitly so ``pa.parquet.write_table`` below does not
# depend on some other module having imported it first.
import pyarrow.parquet
from ray.data.block import BlockAccessor
from ray.data.datasource.filename_provider import FilenameProvider

from awswrangler._arrow import _df_to_table
from awswrangler.distributed.ray.datasources.file_datasink import _BlockFileDatasink
from awswrangler.distributed.ray.datasources.filename_provider import _DefaultFilenameProvider
from awswrangler.s3._write import _COMPRESSION_2_EXT

_logger: logging.Logger = logging.getLogger(__name__)


class _ParquetFilenameProvider(_DefaultFilenameProvider):
    """Parquet filename provider where compression comes before file format."""

    def _generate_filename(self, file_id: str) -> str:
        """
        Build the object name for a single written file.

        The name is assembled as
        ``[<dataset_uuid>_]<file_id>[_bucket-<bucket_id>]<compression_ext>.<file_format>``,
        where the bracketed parts are only present when the corresponding
        attribute is not ``None``. The bucket id is zero-padded to 5 digits.

        Parameters
        ----------
        file_id : str
            Identifier of the file being written.

        Returns
        -------
        str
            The generated file name.
        """
        # NOTE(review): ``_dataset_uuid``, ``_bucket_id``, ``_compression`` and
        # ``_file_format`` are presumably set by ``_DefaultFilenameProvider`` - confirm.
        prefix = f"{self._dataset_uuid}_" if self._dataset_uuid is not None else ""
        bucket_suffix = f"_bucket-{self._bucket_id:05d}" if self._bucket_id is not None else ""
        # ``dict.get`` yields ``None`` for a codec that is absent from the mapping;
        # formatting that directly would embed the literal text "None" in the
        # file name, so fall back to no extension instead.
        compression_ext = _COMPRESSION_2_EXT.get(self._compression) or ""
        return f"{prefix}{file_id}{bucket_suffix}{compression_ext}.{self._file_format}"


class ArrowParquetDatasink(_BlockFileDatasink):
    """A datasink that writes Parquet files."""

    def __init__(
        self,
        path: str,
        *,
        filename_provider: FilenameProvider | None = None,
        dataset_uuid: str | None = None,
        open_s3_object_args: dict[str, Any] | None = None,
        pandas_kwargs: dict[str, Any] | None = None,
        schema: pa.Schema | None = None,
        index: bool = False,
        dtype: dict[str, str] | None = None,
        pyarrow_additional_kwargs: dict[str, Any] | None = None,
        compression: str | None = None,
        **write_args: Any,
    ):
        """
        Configure the datasink.

        Parameters
        ----------
        path : str
            Destination path, forwarded to the base datasink.
        filename_provider : FilenameProvider | None
            Provider used to name output files. When ``None``, a
            ``_ParquetFilenameProvider`` is created from ``dataset_uuid``,
            ``compression`` and the optional ``bucket_id`` write argument.
        dataset_uuid : str | None
            Forwarded to the base datasink and to the default filename provider.
        open_s3_object_args : dict[str, Any] | None
            Forwarded to the base datasink.
        pandas_kwargs : dict[str, Any] | None
            Forwarded to the base datasink.
        schema : pa.Schema | None
            Passed to ``_df_to_table`` when a block is written.
        index : bool
            Passed to ``_df_to_table`` when a block is written.
        dtype : dict[str, str] | None
            Passed to ``_df_to_table`` when a block is written.
        pyarrow_additional_kwargs : dict[str, Any] | None
            Extra keyword arguments for ``pa.parquet.write_table``.
        compression : str | None
            Only used here to build the default filename provider; it is
            ignored when ``filename_provider`` is supplied.
        **write_args : Any
            Remaining keyword arguments, forwarded to the base datasink.
        """
        file_format = "parquet"

        if filename_provider is None:
            filename_provider = _ParquetFilenameProvider(
                dataset_uuid=dataset_uuid,
                file_format=file_format,
                compression=compression,
                # ``bucket_id`` is optional and stays in ``write_args`` so the
                # base class still receives it below.
                bucket_id=write_args.get("bucket_id"),
            )

        super().__init__(
            path,
            file_format=file_format,
            filename_provider=filename_provider,
            dataset_uuid=dataset_uuid,
            open_s3_object_args=open_s3_object_args,
            pandas_kwargs=pandas_kwargs,
            **write_args,
        )

        self.pyarrow_additional_kwargs = pyarrow_additional_kwargs or {}
        self.schema = schema
        self.index = index
        self.dtype = dtype

    def write_block(self, file: pa.NativeFile, block: BlockAccessor) -> None:
        """
        Write a block of data to a file.

        The block is converted to a pandas DataFrame, turned into an Arrow
        table with the configured ``schema``, ``index`` and ``dtype``, and
        written as Parquet.

        Parameters
        ----------
        file : pa.NativeFile
            Open file object the Parquet data is written to.
        block : BlockAccessor
            Accessor for the block to write.
        """
        table = _df_to_table(
            block.to_pandas(),
            schema=self.schema,
            index=self.index,
            dtype=self.dtype,
        )
        pa.parquet.write_table(table, file, **self.pyarrow_additional_kwargs)