in src/datasets/iterable_dataset.py
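# Note: `fsspec`, `logger`, `get_formatter` and `_convert_to_arrow` are module-level imports/helpers
# defined elsewhere in this file; `_iter_pytorch` is a method of `IterableDataset`.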
def _iter_pytorch(self):
    ex_iterable = self._prepare_ex_iterable_for_iteration()
    # Fix for fsspec when using multiprocess to avoid hanging in the ML training loop. (only required for fsspec >= 0.9.0)
    # See https://github.com/fsspec/gcsfs/issues/379
    fsspec.asyn.reset_lock()
    # check if there aren't too many workers
    import torch.utils.data

    worker_info = torch.utils.data.get_worker_info()
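    # warn only once (from the main process / first worker) rather than from every dataloader worker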
    if self._is_main_process() and ex_iterable.num_shards < worker_info.num_workers:
        logger.warning(
            f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.num_shards={ex_iterable.num_shards}). "
            f"Stopping {worker_info.num_workers - ex_iterable.num_shards} dataloader workers."
        )
        logger.info(
            f"To parallelize data loading, we give each process some shards (or data sources) to process. "
            f"Therefore it's unnecessary to have a number of workers greater than dataset.num_shards={ex_iterable.num_shards}. "
            f"To enable more parallelism, please split the dataset into more than {ex_iterable.num_shards} files."
        )
    # split workload
    _log_prefix = f"node#{self._distributed.rank} " if self._distributed else ""
    shards_indices = ex_iterable.split_shard_indices_by_worker(
        num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False
    )
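    # each worker receives a disjoint subset of shard indices; a worker with no assigned shard stops below without yielding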
    if shards_indices:
        logger.debug(
            f"{_log_prefix}dataloader worker#{worker_info.id}: Starting to iterate over {len(shards_indices)}/{ex_iterable.num_shards} shards."
        )
        ex_iterable = ex_iterable.shard_data_sources(
            num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False
        )
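        # record iteration state so streaming can be checkpointed and resumed for this epoch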
        self._state_dict = {
            "examples_iterable": ex_iterable._init_state_dict(),
            "epoch": self.epoch,
        }
        if self._starting_state_dict and self.epoch == self._starting_state_dict["epoch"]:
            ex_iterable.load_state_dict(self._starting_state_dict["examples_iterable"])
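        # if a format is set and Arrow-level iteration is available (or a table format was requested),
        # format rows directly from Arrow tables; otherwise yield already-formatted python examples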
        if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table):
            formatter = get_formatter(self._formatting.format_type, features=self.features)
            if ex_iterable.iter_arrow:
                iterator = ex_iterable.iter_arrow()
            else:
                iterator = _convert_to_arrow(ex_iterable, batch_size=1)
            for key, pa_table in iterator:
                yield formatter.format_row(pa_table)
            return
        else:
            for key, example in ex_iterable:
                # no need to format thanks to FormattedExamplesIterable
                yield example
        logger.debug(
            f"{_log_prefix}dataloader worker#{worker_info.id}: Finished iterating over {len(shards_indices)}/{ex_iterable.num_shards} shards."
        )
    else:
        logger.debug(
            f"{_log_prefix}dataloader worker#{worker_info.id}: Stopping... Number of dataset shards < num_workers ({ex_iterable.num_shards} < {worker_info.num_workers})."
        )
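
For context, here is a minimal sketch of how this method is typically exercised (the dataset contents, shard count, and worker count are illustrative, not taken from this file): iterating an IterableDataset inside a torch DataLoader worker process dispatches to _iter_pytorch, which gives each worker its own subset of shards.

from datasets import Dataset
from torch.utils.data import DataLoader

# Expose a small in-memory dataset as an IterableDataset with 4 shards,
# so that up to 4 dataloader workers can each receive at least one shard.
ds = Dataset.from_dict({"x": list(range(1000))}).to_iterable_dataset(num_shards=4)
dl = DataLoader(ds.with_format("torch"), num_workers=4, batch_size=16)

for batch in dl:
    # each batch is a dict of tensors, e.g. {"x": tensor of shape (16,)}
    pass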