awswrangler/distributed/ray/datasources/arrow_parquet_datasource.py (402 lines of code) (raw):

"""Ray ArrowParquetDatasource Module. This module is pulled from Ray's [ParquetDatasource] (https://github.com/ray-project/ray/blob/ray-2.9.0/python/ray/data/datasource/parquet_datasource.py) with a few changes and customized to ensure compatibility with AWS SDK for pandas behavior. Changes from the original implementation, are documented in the comments and marked with (AWS SDK for pandas) prefix. """ from __future__ import annotations import logging from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, Callable, Iterator, Literal, ) import numpy as np # fs required to implicitly trigger S3 subsystem initialization import pyarrow.fs import ray from ray import cloudpickle from ray.data._internal.progress_bar import ProgressBar from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.util import _is_local_scheme from ray.data.block import Block from ray.data.context import DataContext from ray.data.datasource import Datasource from ray.data.datasource.datasource import ReadTask from ray.data.datasource.file_meta_provider import ( DefaultFileMetadataProvider, _handle_read_os_error, ) from ray.data.datasource.parquet_meta_provider import ( ParquetMetadataProvider, ) from ray.data.datasource.partitioning import PathPartitionFilter from ray.data.datasource.path_util import ( _has_file_extension, _resolve_paths_and_filesystem, ) from ray.util.annotations import PublicAPI from awswrangler import exceptions from awswrangler._arrow import _add_table_partitions if TYPE_CHECKING: import pyarrow from pyarrow.dataset import ParquetFileFragment _logger: logging.Logger = logging.getLogger(__name__) # The number of rows to read per batch. This is sized to generate 10MiB batches # for rows about 1KiB in size. PARQUET_READER_ROW_BATCH_SIZE = 10_000 FILE_READING_RETRY = 8 # The default size multiplier for reading Parquet data source in Arrow. # Parquet data format is encoded with various encoding techniques (such as # dictionary, RLE, delta), so Arrow in-memory representation uses much more memory # compared to Parquet encoded representation. Parquet file statistics only record # encoded (i.e. uncompressed) data size information. # # To estimate real-time in-memory data size, Datasets will try to estimate the # correct inflation ratio from Parquet to Arrow, using this constant as the default # value for safety. See https://github.com/ray-project/ray/pull/26516 for more context. PARQUET_ENCODING_RATIO_ESTIMATE_DEFAULT = 5 # The lower bound size to estimate Parquet encoding ratio. PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND = 2 # The percentage of files (1% by default) to be sampled from the dataset to estimate # Parquet encoding ratio. PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO = 0.01 # The minimal and maximal number of file samples to take from the dataset to estimate # Parquet encoding ratio. # This is to restrict `PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO` within the # proper boundary. PARQUET_ENCODING_RATIO_ESTIMATE_MIN_NUM_SAMPLES = 2 PARQUET_ENCODING_RATIO_ESTIMATE_MAX_NUM_SAMPLES = 10 # The number of rows to read from each file for sampling. Try to keep it low to avoid # reading too much data into memory. PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS = 1024 @dataclass(frozen=True) class _SampleInfo: actual_bytes_per_row: int | None estimated_bytes_per_row: int | None # TODO(ekl) this is a workaround for a pyarrow serialization bug, where serializing a # raw pyarrow file fragment causes S3 network calls. class _SerializedFragment: def __init__(self, frag: "ParquetFileFragment"): self._data = cloudpickle.dumps( # type: ignore[no-untyped-call] (frag.format, frag.path, frag.filesystem, frag.partition_expression) ) def deserialize(self) -> "ParquetFileFragment": # Implicitly trigger S3 subsystem initialization by importing # pyarrow.fs. import pyarrow.fs # noqa: F401 (file_format, path, filesystem, partition_expression) = cloudpickle.loads(self._data) return file_format.make_fragment(path, filesystem, partition_expression) # Visible for test mocking. def _deserialize_fragments( serialized_fragments: list[_SerializedFragment], ) -> list["pyarrow._dataset.ParquetFileFragment"]: return [p.deserialize() for p in serialized_fragments] # This retry helps when the upstream datasource is not able to handle # overloaded read request or failed with some retriable failures. # For example when reading data from HA hdfs service, hdfs might # lose connection for some unknown reason expecially when # simutaneously running many hyper parameter tuning jobs # with ray.data parallelism setting at high value like the default 200 # Such connection failure can be restored with some waiting and retry. def _deserialize_fragments_with_retry( serialized_fragments: list[_SerializedFragment], ) -> list["pyarrow._dataset.ParquetFileFragment"]: min_interval: float = 0 final_exception: Exception | None = None for i in range(FILE_READING_RETRY): try: return _deserialize_fragments(serialized_fragments) except Exception as e: import random import time retry_timing = "" if i == FILE_READING_RETRY - 1 else (f"Retry after {min_interval} sec. ") log_only_show_in_1st_retry = ( "" if i else ( f"If earlier read attempt threw certain Exception" f", it may or may not be an issue depends on these retries " f"succeed or not. serialized_fragments:{serialized_fragments}" ) ) _logger.exception( f"{i + 1}th attempt to deserialize ParquetFileFragment failed. " f"{retry_timing}" f"{log_only_show_in_1st_retry}" ) if not min_interval: # to make retries of different process hit hdfs server # at slightly different time min_interval = 1 + random.random() # exponential backoff at # 1, 2, 4, 8, 16, 32, 64 time.sleep(min_interval) min_interval = min_interval * 2 final_exception = e raise final_exception # type: ignore[misc] @PublicAPI class ArrowParquetDatasource(Datasource): """Parquet datasource, for reading Parquet files. The primary difference from ParquetBaseDatasource is that this uses PyArrow's `ParquetDataset` abstraction for dataset reads, and thus offers automatic Arrow dataset schema inference and row count collection at the cost of some potential performance and/or compatibility penalties. """ def __init__( # noqa: PLR0912,PLR0915 self, paths: str | list[str], path_root: str, *, arrow_parquet_args: dict[str, Any] | None = None, _block_udf: Callable[[Block], Block] | None = None, filesystem: "pyarrow.fs.FileSystem" | None = None, meta_provider: ParquetMetadataProvider = ParquetMetadataProvider(), partition_filter: PathPartitionFilter | None = None, shuffle: Literal["files"] | None = None, include_paths: bool = False, file_extensions: list[str] | None = None, ): if arrow_parquet_args is None: arrow_parquet_args = {} import pyarrow as pa import pyarrow.parquet as pq self._supports_distributed_reads = not _is_local_scheme(paths) if not self._supports_distributed_reads and ray.util.client.ray.is_connected(): # type: ignore[no-untyped-call] raise ValueError( "Because you're using Ray Client, read tasks scheduled on the Ray " "cluster can't access your local files. To fix this issue, store " "files in cloud storage or a distributed filesystem like NFS." ) self._local_scheduling = None if not self._supports_distributed_reads: from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy self._local_scheduling = NodeAffinitySchedulingStrategy(ray.get_runtime_context().get_node_id(), soft=False) paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem) # HACK: PyArrow's `ParquetDataset` errors if input paths contain non-parquet # files. To avoid this, we expand the input paths with the default metadata # provider and then apply the partition filter or file extensions. if partition_filter is not None or file_extensions is not None: default_meta_provider = DefaultFileMetadataProvider() expanded_paths, _ = map(list, zip(*default_meta_provider.expand_paths(paths, filesystem))) paths = list(expanded_paths) if partition_filter is not None: paths = partition_filter(paths) if file_extensions is not None: paths = [path for path in paths if _has_file_extension(path, file_extensions)] filtered_paths = set(expanded_paths) - set(paths) if filtered_paths: _logger.info(f"Filtered out {len(filtered_paths)} paths") elif len(paths) == 1: paths = paths[0] schema = arrow_parquet_args.pop("schema") columns = arrow_parquet_args.pop("columns") dataset_kwargs = arrow_parquet_args.pop("dataset_kwargs", {}) try: pq_ds = pq.ParquetDataset( paths, **dataset_kwargs, filesystem=filesystem, ) except OSError as e: _handle_read_os_error(e, paths) if schema is None: schema = pq_ds.schema if columns: schema = pa.schema([schema.field(column) for column in columns], schema.metadata) if _block_udf is not None: # Try to infer dataset schema by passing dummy table through UDF. dummy_table = schema.empty_table() try: inferred_schema = _block_udf(dummy_table).schema inferred_schema = inferred_schema.with_metadata(schema.metadata) except Exception: _logger.debug( "Failed to infer schema of dataset by passing dummy table " "through UDF due to the following exception:", exc_info=True, ) inferred_schema = schema else: inferred_schema = schema try: prefetch_remote_args = {} if self._local_scheduling: prefetch_remote_args["scheduling_strategy"] = self._local_scheduling self._metadata = meta_provider.prefetch_file_metadata(pq_ds.fragments, **prefetch_remote_args) or [] except OSError as e: _handle_read_os_error(e, paths) except pa.ArrowInvalid as ex: if "Parquet file size is 0 bytes" in str(ex): raise exceptions.InvalidFile(f"Invalid Parquet file. {str(ex)}") raise # NOTE: Store the custom serialized `ParquetFileFragment` to avoid unexpected # network calls when `_ParquetDatasourceReader` is serialized. See # `_SerializedFragment()` implementation for more details. self._pq_fragments = [_SerializedFragment(p) for p in pq_ds.fragments] self._pq_paths = [p.path for p in pq_ds.fragments] self._meta_provider = meta_provider self._inferred_schema = inferred_schema self._block_udf = _block_udf self._columns = columns self._schema = schema self._arrow_parquet_args = arrow_parquet_args self._file_metadata_shuffler = None self._include_paths = include_paths self._path_root = path_root if shuffle == "files": self._file_metadata_shuffler = np.random.default_rng() sample_infos = self._sample_fragments() self._encoding_ratio = _estimate_files_encoding_ratio(sample_infos) self._default_read_batch_size_rows = _estimate_default_read_batch_size_rows(sample_infos) def estimate_inmemory_data_size(self) -> int | None: """Return an estimate of the Parquet files encoding ratio. To avoid OOMs, it is safer to return an over-estimate than an underestimate. """ total_size: int = 0 for file_metadata in self._metadata: total_size += file_metadata.total_byte_size return total_size * self._encoding_ratio # type: ignore[return-value] def get_read_tasks(self, parallelism: int) -> list[ReadTask]: """Override the base class FileBasedDatasource.get_read_tasks(). Required in order to leverage pyarrow's ParquetDataset abstraction, which simplifies partitioning logic. """ pq_metadata = self._metadata if len(pq_metadata) < len(self._pq_fragments): # Pad `pq_metadata` to be same length of `self._pq_fragments`. # This can happen when no file metadata being prefetched. pq_metadata += [None] * (len(self._pq_fragments) - len(pq_metadata)) # type: ignore[list-item] if self._file_metadata_shuffler is not None: files_metadata = list(zip(self._pq_fragments, self._pq_paths, pq_metadata)) shuffled_files_metadata = [ files_metadata[i] for i in self._file_metadata_shuffler.permutation(len(files_metadata)) ] pq_fragments, pq_paths, pq_metadata = list(map(list, zip(*shuffled_files_metadata))) else: pq_fragments, pq_paths = ( self._pq_fragments, self._pq_paths, ) read_tasks = [] for fragments, paths, metadata in zip( # type: ignore[var-annotated] np.array_split(pq_fragments, parallelism), np.array_split(pq_paths, parallelism), np.array_split(pq_metadata, parallelism), # type: ignore[arg-type] ): if len(fragments) <= 0: continue meta = self._meta_provider( paths, # type: ignore[arg-type] self._inferred_schema, num_fragments=len(fragments), prefetched_metadata=metadata, ) # If there is a filter operation, reset the calculated row count, # since the resulting row count is unknown. if self._arrow_parquet_args.get("filter") is not None: meta.num_rows = None if meta.size_bytes is not None: meta.size_bytes = int(meta.size_bytes * self._encoding_ratio) ( block_udf, arrow_parquet_args, default_read_batch_size_rows, columns, schema, path_root, include_paths, ) = ( self._block_udf, self._arrow_parquet_args, self._default_read_batch_size_rows, self._columns, self._schema, self._path_root, self._include_paths, ) read_tasks.append( ReadTask( lambda f=fragments: _read_fragments( # type: ignore[misc] block_udf, arrow_parquet_args, default_read_batch_size_rows, columns, schema, path_root, f, include_paths, ), meta, ) ) return read_tasks def _sample_fragments(self) -> list[_SampleInfo]: # Sample a few rows from Parquet files to estimate the encoding ratio. # Launch tasks to sample multiple files remotely in parallel. # Evenly distributed to sample N rows in i-th row group in i-th file. # TODO(ekl/cheng) take into account column pruning. num_files = len(self._pq_fragments) num_samples = int(num_files * PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO) min_num_samples = min(PARQUET_ENCODING_RATIO_ESTIMATE_MIN_NUM_SAMPLES, num_files) max_num_samples = min(PARQUET_ENCODING_RATIO_ESTIMATE_MAX_NUM_SAMPLES, num_files) num_samples = max(min(num_samples, max_num_samples), min_num_samples) # Evenly distributed to choose which file to sample, to avoid biased prediction # if data is skewed. file_samples = [ self._pq_fragments[idx] for idx in np.linspace(0, num_files - 1, num_samples).astype(int).tolist() ] sample_fragment = cached_remote_fn(_sample_fragment) futures = [] scheduling = self._local_scheduling or "SPREAD" for sample in file_samples: # Sample the first rows batch in i-th file. # Use SPREAD scheduling strategy to avoid packing many sampling tasks on # same machine to cause OOM issue, as sampling can be memory-intensive. futures.append( sample_fragment.options(scheduling_strategy=scheduling).remote( self._columns, self._schema, sample, ) ) sample_bar = ProgressBar(name="Parquet Files Sample", total=len(futures), unit="file samples") sample_infos = sample_bar.fetch_until_complete(futures) sample_bar.close() # type: ignore[no-untyped-call] return sample_infos def get_name(self) -> str: """Return a human-readable name for this datasource. This will be used as the names of the read tasks. Note: overrides the base `ParquetBaseDatasource` method. """ return "Parquet" @property def supports_distributed_reads(self) -> bool: """If ``False``, only launch read tasks on the driver's node.""" return self._supports_distributed_reads def _read_fragments( block_udf: Callable[[Block], Block] | None, arrow_parquet_args: Any, default_read_batch_size_rows: float, columns: list[str] | None, schema: type | "pyarrow.lib.Schema" | None, path_root: str | None, serialized_fragments: list[_SerializedFragment], include_paths: bool, ) -> Iterator["pyarrow.Table"]: # This import is necessary to load the tensor extension type. from ray.data.extensions.tensor_extension import ArrowTensorType # type: ignore[attr-defined] # noqa # Deserialize after loading the filesystem class. fragments: list["pyarrow._dataset.ParquetFileFragment"] = _deserialize_fragments_with_retry(serialized_fragments) # Ensure that we're reading at least one dataset fragment. assert len(fragments) > 0 import pyarrow as pa from pyarrow.dataset import _get_partition_keys _logger.debug(f"Reading {len(fragments)} parquet fragments") use_threads = arrow_parquet_args.pop("use_threads", False) batch_size = arrow_parquet_args.pop("batch_size", default_read_batch_size_rows) for fragment in fragments: part = _get_partition_keys(fragment.partition_expression) batches = fragment.to_batches( use_threads=use_threads, columns=columns, schema=schema, batch_size=batch_size, **arrow_parquet_args, ) for batch in batches: # (AWS SDK for pandas) Table creation is wrapped inside _add_table_partitions # to add columns with partition values when dataset=True and cast them to categorical table = _add_table_partitions( table=pa.Table.from_batches([batch], schema=schema), path=f"s3://{fragment.path}", path_root=path_root, ) if part: for col, value in part.items(): if columns and col not in columns: continue table = table.set_column( table.schema.get_field_index(col), col, pa.array([value] * len(table)), ) if include_paths: table = table.append_column("path", [[fragment.path]] * len(table)) # If the table is empty, drop it. if table.num_rows > 0: if block_udf is not None: yield block_udf(table) else: yield table def _sample_fragment( columns: list[str] | None, schema: type | "pyarrow.lib.Schema" | None, file_fragment: _SerializedFragment, ) -> _SampleInfo: # Sample the first rows batch from file fragment `serialized_fragment`. fragment = _deserialize_fragments_with_retry([file_fragment])[0] # Only sample the first row group. fragment = fragment.subset(row_group_ids=[0]) batch_size = max(min(fragment.metadata.num_rows, PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS), 1) # Use the batch_size calculated above, and ignore the one specified by user if set. # This is to avoid sampling too few or too many rows. batches = fragment.to_batches( columns=columns, schema=schema, batch_size=batch_size, ) # Use first batch in-memory size for estimation. try: batch = next(batches) except StopIteration: sample_data = _SampleInfo(actual_bytes_per_row=None, estimated_bytes_per_row=None) else: if batch.num_rows > 0: metadata = fragment.metadata total_size = 0 for idx in range(metadata.num_row_groups): total_size += metadata.row_group(idx).total_byte_size sample_data = _SampleInfo( actual_bytes_per_row=batch.nbytes / batch.num_rows, estimated_bytes_per_row=total_size / metadata.num_rows, ) else: sample_data = _SampleInfo(actual_bytes_per_row=None, estimated_bytes_per_row=None) return sample_data def _estimate_files_encoding_ratio(sample_infos: list[_SampleInfo]) -> float: """Return an estimate of the Parquet files encoding ratio. To avoid OOMs, it is safer to return an over-estimate than an underestimate. """ if not DataContext.get_current().decoding_size_estimation: return PARQUET_ENCODING_RATIO_ESTIMATE_DEFAULT def compute_encoding_ratio(sample_info: _SampleInfo) -> float: if sample_info.actual_bytes_per_row is None or sample_info.estimated_bytes_per_row is None: return PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND else: return sample_info.actual_bytes_per_row / sample_info.estimated_bytes_per_row ratio = np.mean(list(map(compute_encoding_ratio, sample_infos))) _logger.debug(f"Estimated Parquet encoding ratio from sampling is {ratio}.") return max(ratio, PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND) # type: ignore[return-value] def _estimate_default_read_batch_size_rows(sample_infos: list[_SampleInfo]) -> int: def compute_batch_size_rows(sample_info: _SampleInfo) -> int: if sample_info.actual_bytes_per_row is None: return PARQUET_READER_ROW_BATCH_SIZE else: max_parquet_reader_row_batch_size_bytes = DataContext.get_current().target_max_block_size // 10 return max( 1, min( PARQUET_READER_ROW_BATCH_SIZE, max_parquet_reader_row_batch_size_bytes // sample_info.actual_bytes_per_row, ), ) return np.mean(list(map(compute_batch_size_rows, sample_infos))) # type: ignore[return-value]