awswrangler/distributed/ray/datasources/filename_provider.py (37 lines of code) (raw):

"""Ray DefaultFilenameProvider Module."""

from __future__ import annotations

from typing import Any

from ray.data.block import Block
from ray.data.datasource.filename_provider import FilenameProvider

from awswrangler.s3._write import _COMPRESSION_2_EXT


class _DefaultFilenameProvider(FilenameProvider):
    """Build output file names for blocks and rows written through Ray.

    Generated names have the shape::

        [<dataset_uuid>_]<file_id>[_bucket-<bucket_id>].<file_format><compression_ext>

    where ``file_id`` is made of zero-padded task/block (and optionally row)
    indices, and ``compression_ext`` is looked up in ``_COMPRESSION_2_EXT``.
    """

    def __init__(
        self,
        file_format: str,
        dataset_uuid: str | None = None,
        compression: str | None = None,
        bucket_id: int | None = None,
    ):
        """Store the components used to assemble file names.

        Parameters
        ----------
        file_format
            Extension placed after the dot (without a leading dot).
        dataset_uuid
            Optional prefix; when given it is followed by an underscore.
        compression
            Key looked up in ``_COMPRESSION_2_EXT`` to obtain the trailing
            compression extension.
        bucket_id
            Optional bucket number, rendered as ``_bucket-NNNNN``.
        """
        self._dataset_uuid = dataset_uuid
        self._file_format = file_format
        self._compression = compression
        self._bucket_id = bucket_id

    def get_filename_for_block(
        self,
        block: Block,
        task_index: int,
        block_index: int,
    ) -> str:
        """Return the file name for a block.

        The block contents are not inspected; the name depends only on the
        task and block indices, each zero-padded to six digits.
        """
        file_id = f"{task_index:06}_{block_index:06}"
        return self._generate_filename(file_id)

    def get_filename_for_row(self, row: dict[str, Any], task_index: int, block_index: int, row_index: int) -> str:
        """Return the file name for a single row.

        The row contents are not inspected; the name depends only on the
        task, block and row indices, each zero-padded to six digits.
        """
        file_id = f"{task_index:06}_{block_index:06}_{row_index:06}"
        return self._generate_filename(file_id)

    def _generate_filename(self, file_id: str) -> str:
        """Assemble the full file name around ``file_id``."""
        prefix = f"{self._dataset_uuid}_" if self._dataset_uuid is not None else ""
        bucket_suffix = f"_bucket-{self._bucket_id:05d}" if self._bucket_id is not None else ""
        # Default to "" so that a compression value missing from the mapping
        # does not get rendered as the literal text "None" in the file name.
        # NOTE(review): presumably the mapping already has a ``None: ""`` entry
        # for the uncompressed case - confirm in awswrangler.s3._write.
        compression_ext = _COMPRESSION_2_EXT.get(self._compression, "")
        return f"{prefix}{file_id}{bucket_suffix}.{self._file_format}{compression_ext}"