in oss-torch-connector/osstorchconnector/oss_iterable_dataset.py [0:0]
def __iter__(self) -> Iterator[Any]:
    """Return an iterator of transformed dataset objects.

    When run inside a DataLoader worker process, the workload is split
    across workers; otherwise the full dataset is iterated. For the
    tar + shuffle mode, pre-computed chunks are distributed round-robin
    across workers.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # Single-process data loading: behave as worker 0 of 1 so the
        # same partitioning logic below covers both cases.
        num_workers, worker_id = 1, 0
        log.info("OssIterableDataset get iter (single-process)")
    else:
        # In a worker process: take this worker's share of the workload.
        num_workers, worker_id = worker_info.num_workers, worker_info.id
        log.info("OssIterableDataset get iter (multi-process), num_workers: %d, worker id: %d", num_workers, worker_id)

    client = self._get_client(worker_id, num_workers)
    if self._from_tar and self._shuffle:
        if len(self._chunks) >= num_workers:
            # Round-robin assignment: worker k gets chunks k, k+num_workers, ...
            chunks = [chunk for i, chunk in enumerate(self._chunks) if i % num_workers == worker_id]
        else:
            # Fewer chunks than workers: fall back to an empty chunk list
            # (presumably _get_dataset_objects then derives its own split —
            # TODO confirm against its implementation).
            chunks = []
        log.info("OssIterableDataset chunk num: %d", len(chunks))
        worker_iter = self._get_dataset_objects(client, chunks=chunks)
    else:
        worker_iter = self._get_dataset_objects(client)
    return map(self._get_transformed_object, worker_iter)