in src/datasets/iterable_dataset.py
def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[tuple[Key, pa.Table]]:
    formatter: TableFormatter = get_formatter(self.formatting.format_type) if self.formatting else ArrowFormatter()
    if self.ex_iterable.iter_arrow:
        iterator = self.ex_iterable.iter_arrow()
    else:
        iterator = _convert_to_arrow(
            self.ex_iterable,
            batch_size=self.batch_size if self.batched else 1,
            drop_last_batch=self.drop_last_batch,
        )
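    # when resuming from a checkpoint, restore the wrapped iterable's state and remember
    # how many already-yielded examples must be skipped to reach the exact stopping point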
    if self._state_dict and self._state_dict["previous_state"]:
        self.ex_iterable.load_state_dict(self._state_dict["previous_state"])
        num_examples_to_skip = self._state_dict["num_examples_since_previous_state"]
    else:
        num_examples_to_skip = 0
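    # chunked iteration checkpoints at the level of the source table: snapshot the source
    # state now and count the sub-tables yielded since, so a resume can skip exactly that many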
    if self._state_dict and max_chunksize is not None:
        self._state_dict["previous_state"] = self.ex_iterable.state_dict()
        self._state_dict["num_examples_since_previous_state"] = 0
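    # absolute index of the next example, passed to `function` when with_indices=True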
    current_idx = self._state_dict["previous_state_example_idx"] if self._state_dict else 0
    for key, pa_table in iterator:
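        # drop the final partial batch if drop_last_batch was requested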
        if (
            self.batched
            and self.batch_size is not None
            and len(pa_table) < self.batch_size
            and self.drop_last_batch
        ):
            return
        # first build the batch
        function_args = (
            [formatter.format_batch(pa_table)]
            if self.input_columns is None
            else [pa_table[col] for col in self.input_columns]
        )
        if self.with_indices:
            if self.batched:
                function_args.append([current_idx + i for i in range(len(pa_table))])
            else:
                function_args.append(current_idx)
        # then apply the transform
        output = self.function(*function_args, **self.fn_kwargs)
        output_table = _table_output_to_arrow(output)
        if not isinstance(output_table, pa.Table):
            raise TypeError(
                f"Provided `function` which is applied to {formatter.table_type} returns a variable of type "
                f"{type(output)}. Make sure provided `function` returns a {formatter.table_type} to update the dataset."
            )
        # no need to merge input and output here: this stays consistent with Dataset.map, which merges only if both input and output are dicts
        # then remove the unwanted columns
        if self.remove_columns:
            for column in self.remove_columns:
                if column in output_table.column_names:
                    output_table = output_table.remove_column(output_table.column_names.index(column))
        # finally yield the output, either whole or split into chunks of at most `max_chunksize` rows
        if max_chunksize is None:
            current_idx += len(pa_table)
            if self._state_dict:
                self._state_dict["previous_state_example_idx"] += len(pa_table)
            yield key, output_table
        else:
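            # split the output table into sub-tables of at most `max_chunksize` rows,
            # skipping any sub-tables that were already yielded before the checkpoint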
            for i, pa_subtable in enumerate(output_table.to_reader(max_chunksize=max_chunksize)):
                current_idx += 1
                if self._state_dict:
                    self._state_dict["num_examples_since_previous_state"] += 1
                if num_examples_to_skip > 0:
                    num_examples_to_skip -= 1
                    continue
                yield f"{key}_{i}", pa_subtable
            if self._state_dict:
                self._state_dict["previous_state"] = self.ex_iterable.state_dict()
                self._state_dict["num_examples_since_previous_state"] = 0
                self._state_dict["previous_state_example_idx"] += len(pa_table)
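
A minimal sketch of the chunk-splitting and skip bookkeeping in the `else` branch above, using only pyarrow (the helper name `split_with_resume` and the sample table are illustrative, not part of the library): the output table is re-yielded in sub-tables of at most `max_chunksize` rows, and a skip counter fast-forwards past sub-tables that a restored checkpoint already emitted.

import pyarrow as pa

def split_with_resume(key, output_table, max_chunksize, num_examples_to_skip=0):
    # to_reader yields record batches of at most `max_chunksize` rows each
    for i, pa_subtable in enumerate(output_table.to_reader(max_chunksize=max_chunksize)):
        if num_examples_to_skip > 0:
            # this sub-table was already yielded before the checkpoint was taken
            num_examples_to_skip -= 1
            continue
        yield f"{key}_{i}", pa_subtable

table = pa.table({"a": list(range(10))})
print([k for k, _ in split_with_resume("0", table, max_chunksize=3, num_examples_to_skip=1)])
# ['0_1', '0_2', '0_3']: the first sub-table is skipped, the remaining three are yielded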