in src/datasets/packaged_modules/parquet/parquet.py:
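The method below relies on names imported at the top of the module; a minimal sketch of those module-level imports, assuming the standard layout of this file:

import itertools

import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

import datasets

logger = datasets.utils.logging.get_logger(__name__)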
def _generate_tables(self, files):
    # If both features and columns are specified, they must name the same set of columns.
    if self.config.features is not None and self.config.columns is not None:
        if sorted(field.name for field in self.info.features.arrow_schema) != sorted(self.config.columns):
            raise ValueError(
                f"Tried to load parquet data with columns '{self.config.columns}' with mismatching features '{self.info.features}'"
            )
    # Filters may be given as a DNF list of tuples; convert them to an Arrow compute expression.
    filter_expr = (
        pq.filters_to_expression(self.config.filters)
        if isinstance(self.config.filters, list)
        else self.config.filters
    )
    for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
        with open(file, "rb") as f:
            parquet_fragment = ds.ParquetFileFormat().make_fragment(f)
            # Skip empty files: no row groups means there is nothing to yield.
            if parquet_fragment.row_groups:
                # Default batch size: the number of rows in the first row group.
                batch_size = self.config.batch_size or parquet_fragment.row_groups[0].num_rows
                try:
                    for batch_idx, record_batch in enumerate(
                        parquet_fragment.to_batches(
                            batch_size=batch_size,
                            columns=self.config.columns,
                            filter=filter_expr,
                            batch_readahead=0,
                            fragment_readahead=0,
                        )
                    ):
                        pa_table = pa.Table.from_batches([record_batch])
                        # Uncomment for debugging (will print the Arrow table size and elements)
                        # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
                        # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
                        yield f"{file_idx}_{batch_idx}", self._cast_table(pa_table)
                except ValueError as e:
                    logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
                    raise
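For reference, a minimal usage sketch showing how the `columns` and `filters` options consumed above can reach this method through `load_dataset` (the file, column, and value names are illustrative, not taken from the source):

from datasets import load_dataset

# Hypothetical local file and column names, for illustration only.
# `columns` restricts the projection; a list-style `filters` value is the same
# DNF of tuples that pq.filters_to_expression converts above.
dataset = load_dataset(
    "parquet",
    data_files="data.parquet",
    columns=["text", "label"],
    filters=[("label", "=", 1)],
    split="train",
)

Passing `batch_readahead=0` and `fragment_readahead=0` in the generator disables prefetching, so batches are materialized one at a time and memory use stays bounded while streaming.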