def _iter_pytorch()

in src/datasets/iterable_dataset.py [0:0]


    def _iter_pytorch(self):
        """Yield the examples assigned to the current PyTorch DataLoader worker.

        The prepared examples iterable is split across DataLoader workers with
        ``split_shard_indices_by_worker`` / ``shard_data_sources`` (``contiguous=False``),
        and this worker iterates only over its own subset of shards. A worker that is
        assigned no shard yields nothing and only logs a debug message.

        Side effects:
            - Resets the fsspec async lock (see the comment below).
            - Sets ``self._state_dict`` to the initial state of this worker's iterable,
              tagged with ``self.epoch``.
            - Loads ``self._starting_state_dict["examples_iterable"]`` into the iterable
              when ``self._starting_state_dict["epoch"] == self.epoch``.

        Yields:
            Rows produced by ``get_formatter(...).format_row`` when formatting is set and
            the iterable supports Arrow iteration (or the format is a table format);
            otherwise the examples produced by the iterable, unchanged.
        """
        ex_iterable = self._prepare_ex_iterable_for_iteration()
        # Fix for fsspec when using multiprocess to avoid hanging in the ML training loop. (only required for fsspec >= 0.9.0)
        # See https://github.com/fsspec/gcsfs/issues/379
        fsspec.asyn.reset_lock()
        # check if there aren't too many workers
        import torch.utils.data

        # NOTE(review): this method assumes it runs inside a DataLoader worker. In the
        # main process get_worker_info() returns None and the attribute accesses below
        # would raise AttributeError; presumably the caller guards against that.
        worker_info = torch.utils.data.get_worker_info()
        # Capture the total shard count now: ``ex_iterable`` is rebound below to the
        # worker-local iterable, whose ``num_shards`` only covers this worker's shards.
        total_num_shards = ex_iterable.num_shards
        if self._is_main_process() and total_num_shards < worker_info.num_workers:
            logger.warning(
                f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.num_shards={total_num_shards}). "
                f"Stopping {worker_info.num_workers - total_num_shards} dataloader workers."
            )
            logger.info(
                f"To parallelize data loading, we give each process some shards (or data sources) to process. "
                f"Therefore it's unnecessary to have a number of workers greater than dataset.num_shards={total_num_shards}. "
                f"To enable more parallelism, please split the dataset in more files than {total_num_shards}."
            )
        # split workload
        _log_prefix = f"node#{self._distributed.rank} " if self._distributed else ""
        shards_indices = ex_iterable.split_shard_indices_by_worker(
            num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False
        )
        if not shards_indices:
            # More workers than shards: this worker has nothing to iterate over.
            logger.debug(
                f"{_log_prefix}dataloader worker#{worker_info.id}: Stopping... Number of dataset shards < num_workers ({total_num_shards}<{worker_info.num_workers})."
            )
            return

        logger.debug(
            f"{_log_prefix}dataloader worker#{worker_info.id}: Starting to iterate over {len(shards_indices)}/{total_num_shards} shards."
        )
        ex_iterable = ex_iterable.shard_data_sources(
            num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False
        )
        self._state_dict = {
            "examples_iterable": ex_iterable._init_state_dict(),
            "epoch": self.epoch,
        }
        # Resume only if the saved state belongs to the current epoch.
        if self._starting_state_dict and self.epoch == self._starting_state_dict["epoch"]:
            ex_iterable.load_state_dict(self._starting_state_dict["examples_iterable"])

        if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table):
            formatter = get_formatter(self._formatting.format_type, features=self.features)
            if ex_iterable.iter_arrow:
                iterator = ex_iterable.iter_arrow()
            else:
                # Wrap plain examples into single-row Arrow tables so the formatter can handle them.
                iterator = _convert_to_arrow(ex_iterable, batch_size=1)
            for _key, pa_table in iterator:
                yield formatter.format_row(pa_table)
        else:
            for _key, example in ex_iterable:
                # no need to format thanks to FormattedExamplesIterable
                yield example
        # Reached on both branches (the Arrow path previously returned early and skipped this).
        logger.debug(
            f"{_log_prefix}dataloader worker#{worker_info.id}: Finished iterating over {len(shards_indices)}/{total_num_shards} shards."
        )