def _iter_arrow()

in src/datasets/iterable_dataset.py [0:0]


    def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[tuple[Key, pa.Table]]:
        """Apply `self.function` to Arrow tables from `self.ex_iterable` and yield the results.

        Input tables come from `self.ex_iterable.iter_arrow()` when that attribute is truthy,
        otherwise from `_convert_to_arrow(self.ex_iterable, ...)`. Each input table is passed to
        `self.function` (formatted as a batch, or as the selected `self.input_columns`), the
        output is converted with `_table_output_to_arrow`, and the columns listed in
        `self.remove_columns` are dropped from it.

        When `self._state_dict` is set, it is updated while iterating so iteration can be resumed:
        `previous_state` holds a state of `self.ex_iterable`, `previous_state_example_idx` counts
        the input rows consumed so far, and `num_examples_since_previous_state` counts the chunks
        emitted since `previous_state` was recorded (those are skipped again when resuming).

        Args:
            max_chunksize: If `None`, every transformed table is yielded whole under its original
                key. Otherwise the transformed table is split with
                `to_reader(max_chunksize=max_chunksize)` and every chunk is yielded under the key
                `f"{key}_{i}"`, where `i` is the position of the chunk in its table.

        Yields:
            `(key, table)` pairs as described above.

        Raises:
            TypeError: If the output of `self.function`, after `_table_output_to_arrow`, is not a
                `pa.Table`.
        """
        formatter: TableFormatter = get_formatter(self.formatting.format_type) if self.formatting else ArrowFormatter()
        if self.ex_iterable.iter_arrow:
            iterator = self.ex_iterable.iter_arrow()
        else:
            iterator = _convert_to_arrow(
                self.ex_iterable,
                batch_size=self.batch_size if self.batched else 1,
                drop_last_batch=self.drop_last_batch,
            )
        # When resuming, restore the source iterable and remember how many chunks
        # of the current table were already emitted before the checkpoint.
        if self._state_dict and self._state_dict["previous_state"]:
            self.ex_iterable.load_state_dict(self._state_dict["previous_state"])
            num_examples_to_skip = self._state_dict["num_examples_since_previous_state"]
        else:
            num_examples_to_skip = 0
        if self._state_dict and max_chunksize is not None:
            self._state_dict["previous_state"] = self.ex_iterable.state_dict()
            self._state_dict["num_examples_since_previous_state"] = 0
        # Index of the first row of the next input table; passed to `function` when `with_indices` is set.
        current_idx = self._state_dict["previous_state_example_idx"] if self._state_dict else 0
        for key, pa_table in iterator:
            # An undersized table ends the iteration when `drop_last_batch` is requested.
            if (
                self.batched
                and self.batch_size is not None
                and len(pa_table) < self.batch_size
                and self.drop_last_batch
            ):
                return
            # first build the batch
            function_args = (
                [formatter.format_batch(pa_table)]
                if self.input_columns is None
                else [pa_table[col] for col in self.input_columns]
            )
            if self.with_indices:
                if self.batched:
                    function_args.append(list(range(current_idx, current_idx + len(pa_table))))
                else:
                    function_args.append(current_idx)
            # then apply the transform
            output = self.function(*function_args, **self.fn_kwargs)
            output_table = _table_output_to_arrow(output)
            if not isinstance(output_table, pa.Table):
                raise TypeError(
                    f"Provided `function` which is applied to {formatter.table_type} returns a variable of type "
                    f"{type(output)}. Make sure provided `function` returns a {formatter.table_type} to update the dataset."
                )
            # we don't need to merge results for consistency with Dataset.map which merges iif both input and output are dicts
            # then remove the unwanted columns
            if self.remove_columns:
                for column in self.remove_columns:
                    if column in output_table.column_names:
                        output_table = output_table.remove_column(output_table.column_names.index(column))
            # return output
            if max_chunksize is None:
                current_idx += len(pa_table)
                if self._state_dict:
                    self._state_dict["previous_state_example_idx"] += len(pa_table)
                yield key, output_table
            else:
                for i, pa_subtable in enumerate(output_table.to_reader(max_chunksize=max_chunksize)):
                    if self._state_dict:
                        self._state_dict["num_examples_since_previous_state"] += 1
                    if num_examples_to_skip > 0:
                        num_examples_to_skip -= 1
                        continue
                    yield f"{key}_{i}", pa_subtable
                # Advance by the number of *input* rows, exactly like the unchunked branch and
                # like `previous_state_example_idx` below. Counting emitted chunks instead would
                # make the indices passed to `function` depend on `max_chunksize` and on how many
                # rows `function` returns, and would differ from the indices used after a resume.
                current_idx += len(pa_table)
                if self._state_dict:
                    self._state_dict["previous_state"] = self.ex_iterable.state_dict()
                    self._state_dict["num_examples_since_previous_state"] = 0
                    self._state_dict["previous_state_example_idx"] += len(pa_table)