def _iter()

in src/datasets/iterable_dataset.py


    def _iter(self):
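        # Resume support: if a checkpoint was stored, restore the wrapped iterable's state
        # and remember how many already-yielded examples must be skipped on replay.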
        current_idx = self._state_dict["previous_state_example_idx"] if self._state_dict else 0
        if self._state_dict and self._state_dict["previous_state"]:
            self.ex_iterable.load_state_dict(self._state_dict["previous_state"])
            num_examples_to_skip = self._state_dict["num_examples_since_previous_state"]
        else:
            num_examples_to_skip = 0
        iterator = iter(self.ex_iterable)

        # We use the same logic as in Dataset.map, but with fewer features/formatting steps
        # since those are handled by FormattedExamplesIterable

        if self.formatting:
            formatter = get_formatter(self.formatting.format_type)
            format_dict = formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else None
        else:
            format_dict = None

        def iter_batched_inputs():
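            """Group consecutive examples from the underlying iterator into batches of `batch_size`
            (or a single batch with everything if `batch_size` is None or <= 0) and yield (indices, (key, batch))."""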
            nonlocal current_idx
            for key, example in iterator:
                # If `batched`, first build the batch; if `batch_size` is None or <= 0, the batch is the whole dataset
                iterator_batch = (
                    iterator
                    if self.batch_size is None or self.batch_size <= 0
                    else islice(iterator, self.batch_size - 1)
                )
                key_examples_list = [(key, example)] + list(iterator_batch)
                keys, examples = zip(*key_examples_list)
                # the new key is the concatenation of the examples keys from the batch
                key = "_".join(str(key) for key in keys)
                if (
                    self.drop_last_batch
                    and self.batch_size is not None
                    and self.batch_size > 0
                    and len(examples) < self.batch_size
                ):  # ignore last batch
                    return
                batch = _examples_to_batch(examples)
                # we need to format here in case we need to stack tensors together
                batch = format_dict(batch) if format_dict else batch
                indices = [current_idx + i for i in range(len(key_examples_list))]
                current_idx += len(indices)
                yield indices, (key, batch)

        def iter_inputs():
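            """Yield (index, (key, example)) pairs one by one, copying each example so columns can be dropped safely."""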
            nonlocal current_idx
            for key, example in iterator:
                # If not batched, the transform is applied to each example downstream in iter_outputs
                # first copy the example, since we might drop some keys
                example = dict(example)
                # no need to do formatting here
                current_idx += 1
                yield current_idx - 1, (key, example)

        def validate_function_output(processed_inputs):
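            """In the batched case, check that all columns returned by the function have the same length."""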
            if self.batched and processed_inputs:
                first_col = next(iter(processed_inputs))
                bad_cols = [
                    col for col in processed_inputs if len(processed_inputs[col]) != len(processed_inputs[first_col])
                ]
                if bad_cols:
                    raise ValueError(
                        f"Column lengths mismatch: columns {bad_cols} have length {[len(processed_inputs[col]) for col in bad_cols]} "
                        f"while {first_col} has length {len(processed_inputs[first_col])}."
                    )

        def prepare_inputs(key_example, indices):
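            """Build the positional arguments for the user function, restricted to `input_columns` if set
            and extended with the example indices when `with_indices` is enabled."""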
            key, example = key_example
            fn_args = [example] if self.input_columns is None else [example[col] for col in self.input_columns]
            additional_args = ()
            if self.with_indices:
                fn_args += (indices,)
            inputs = dict(example)
            return inputs, fn_args, additional_args, self.fn_kwargs

        def prepare_outputs(key_example, inputs, processed_inputs):
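            """Validate the function output, apply `remove_columns`, and merge the output into the input example."""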
            validate_function_output(processed_inputs)
            # this logic mimics the one in Dataset.map
            if self.remove_columns:
                for c in self.remove_columns:
                    if c in inputs:
                        del inputs[c]
                    if processed_inputs is key_example[1] and c in processed_inputs:
                        del processed_inputs[c]
            transformed_inputs = {**inputs, **processed_inputs}
            # no need to do features decoding here
            return transformed_inputs

        def apply_function(key_example, indices):
            """Utility to apply the function on a selection of columns."""
            inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(key_example, indices)
            processed_inputs = self.function(*fn_args, *additional_args, **fn_kwargs)
            return prepare_outputs(key_example, inputs, processed_inputs)

        async def async_apply_function(key_example, indices):
            """Utility to apply the function on a selection of columns. Same code but async"""
            inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(key_example, indices)
            processed_inputs = await self.function(*fn_args, *additional_args, **fn_kwargs)
            return prepare_outputs(key_example, inputs, processed_inputs)

        tasks: list[asyncio.Task] = []
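        # For async map functions, coroutines run as tasks on an event loop (reusing the running loop
        # if there is one, otherwise creating a new one) and are tracked so they can be cancelled on error.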
        if inspect.iscoroutinefunction(self.function):
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
            self._owned_loops_and_tasks.append((loop, tasks))
        else:
            loop = None

        def iter_outputs():
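            """Apply the map function to the prepared inputs and yield (index, output) pairs; async functions
            are scheduled as tasks with bounded concurrency and yielded in submission order, with checkpointing."""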
            nonlocal tasks, loop
            inputs_iterator = iter_batched_inputs() if self.batched else iter_inputs()
            if inspect.iscoroutinefunction(self.function):
                if self._state_dict:
                    previous_state = self.ex_iterable.state_dict()
                    self._state_dict["previous_state"] = previous_state
                    previous_state_task = None
                    previous_state_example_idx = self._state_dict["previous_state_example_idx"]
                indices: Union[list[int], list[list[int]]] = []
                for i, key_example in inputs_iterator:
                    indices.append(i)
                    tasks.append(loop.create_task(async_apply_function(key_example, i)))
                    # keep the total active tasks under a certain number
                    if len(tasks) >= self.max_num_running_async_map_functions_in_parallel:
                        done, pending = loop.run_until_complete(
                            asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                        )
                        while tasks and len(pending) >= self.max_num_running_async_map_functions_in_parallel:
                            done, pending = loop.run_until_complete(
                                asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                            )
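                    # if the oldest task is still running while many newer ones have piled up,
                    # block on it so results can keep being yielded in submission order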
                    if len(tasks) >= 10 * self.max_num_running_async_map_functions_in_parallel:
                        loop.run_until_complete(tasks[0])
                    # yield finished tasks
                    while tasks and tasks[0].done():
                        i, task = indices.pop(0), tasks.pop(0)
                        yield i, task.result()
                        if self._state_dict and task is previous_state_task:
                            self._state_dict["previous_state"] = previous_state
                            self._state_dict["num_examples_since_previous_state"] = 0
                            self._state_dict["previous_state_example_idx"] = previous_state_example_idx
                            previous_state, previous_state_task = None, None
                    # checkpoint: snapshot the iterable state now; it becomes the resume point
                    # once the most recently scheduled task has been yielded
                    if self._state_dict and previous_state_task is None and tasks:
                        previous_state = self.ex_iterable.state_dict()
                        previous_state_task = tasks[-1]
                        previous_state_example_idx = current_idx
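                # the input iterator is exhausted: await and yield the remaining tasks in submission order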
                while tasks:
                    yield indices[0], loop.run_until_complete(tasks[0])
                    indices.pop(0)
                    tasks.pop(0)
            else:
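                # synchronous path: apply the function eagerly and update checkpoint state around each yield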
                if self._state_dict:
                    if self.batched:
                        self._state_dict["previous_state"] = self.ex_iterable.state_dict()
                        self._state_dict["num_examples_since_previous_state"] = 0
                        self._state_dict["previous_state_example_idx"] = current_idx
                for i, key_example in inputs_iterator:
                    if self._state_dict:
                        if not self.batched:
                            self._state_dict["previous_state_example_idx"] = current_idx
                    yield i, apply_function(key_example, i)
                    if self._state_dict:
                        if self.batched:
                            self._state_dict["previous_state"] = self.ex_iterable.state_dict()
                            self._state_dict["num_examples_since_previous_state"] = 0
                            self._state_dict["previous_state_example_idx"] = current_idx

        try:
            outputs = iter_outputs()
            if self.batched:
                outputs = (
                    (key, transformed_example)
                    for key, transformed_batch in outputs
                    for transformed_example in _batch_to_examples(transformed_batch)
                )
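            # when resuming from a checkpoint, examples already yielded since that checkpoint are replayed
            # by the wrapped iterable and must be skipped here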
            for key, transformed_example in outputs:
                if self._state_dict and self._state_dict["previous_state"] is not None:
                    self._state_dict["num_examples_since_previous_state"] += 1
                if num_examples_to_skip > 0:
                    num_examples_to_skip -= 1
                    continue
                yield key, transformed_example
        except (Exception, KeyboardInterrupt):
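            # on error or interruption, cancel any in-flight async tasks before re-raising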
            if loop:
                logger.debug(f"Canceling {len(tasks)} async tasks.")
                for task in tasks:
                    task.cancel(msg="KeyboardInterrupt")
                try:
                    loop.run_until_complete(asyncio.gather(*tasks))
                except (asyncio.CancelledError, ValueError):
                    logger.debug("Tasks canceled.")
            raise
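
For context, a minimal sketch (not part of the source) of how these code paths are typically exercised
through the public `IterableDataset.map` API. It assumes a recent `datasets` release; async map functions
are assumed to be supported, as the coroutine branch above suggests.

    from datasets import Dataset

    ds = Dataset.from_dict({"text": ["a", "b", "c", "d"]}).to_iterable_dataset()

    # batched path: iter_batched_inputs groups examples into batches of 2 before calling the function
    upper = ds.map(lambda batch: {"text": [t.upper() for t in batch["text"]]}, batched=True, batch_size=2)

    # non-batched path with indices: iter_inputs yields examples one by one, with_indices passes their index
    indexed = ds.map(lambda example, idx: {"idx": idx}, with_indices=True)

    # async path (assumed supported in this version): async_apply_function runs the coroutine on the event loop
    async def shout(example):
        return {"text": example["text"] + "!"}

    shouted = ds.map(shout)

    print(list(upper))  # expected: [{'text': 'A'}, {'text': 'B'}, {'text': 'C'}, {'text': 'D'}]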