def __iter__()

in oss-torch-connector/osstorchconnector/oss_iterable_dataset.py [0:0]


    def __iter__(self) -> Iterator[Any]:
        """Build the iterator over transformed objects for the current loader process.

        ``torch.utils.data.get_worker_info()`` decides the sharding position:
        when it returns ``None`` the dataset is iterated as worker 0 of 1,
        otherwise the worker's own ``id`` / ``num_workers`` are used.

        When both ``self._from_tar`` and ``self._shuffle`` are truthy, a chunk
        selection is additionally passed to ``_get_dataset_objects``:

        * single-process: ``self._chunks`` itself when it is non-empty;
        * worker process: every chunk whose index modulo ``num_workers``
          equals the worker id, provided there are at least ``num_workers``
          chunks;
        * otherwise an empty list.

        Returns:
            A lazy ``map`` applying ``self._get_transformed_object`` to each
            object produced by ``self._get_dataset_objects``.
        """
        info = torch.utils.data.get_worker_info()
        single_process = info is None

        # Resolve this process's position (rank) among the cooperating loaders.
        if single_process:
            log.info("OssIterableDataset get iter (single-process)")
            rank, world_size = 0, 1
        else:
            world_size = info.num_workers
            rank = info.id
            log.info(
                "OssIterableDataset get iter (multi-process), num_workers: %d, worker id: %d",
                world_size,
                rank,
            )

        # Without tar + shuffle no chunk selection is passed at all.
        if not (self._from_tar and self._shuffle):
            objects = self._get_dataset_objects(self._get_client(rank, world_size))
            return map(self._get_transformed_object, objects)

        if len(self._chunks) < world_size:
            # Fewer chunks than loader processes (or none): hand over an empty
            # selection. NOTE(review): how _get_dataset_objects treats an empty
            # list is not visible here -- confirm this fallback is intended.
            selected = []
        elif single_process:
            # The sole process takes the chunk collection as-is (no copy).
            selected = self._chunks
        else:
            # Round-robin assignment of chunks to workers by index.
            selected = [
                chunk
                for index, chunk in enumerate(self._chunks)
                if index % world_size == rank
            ]
        log.info("OssIterableDataset chunk num: %d", len(selected))

        objects = self._get_dataset_objects(
            self._get_client(rank, world_size), chunks=selected
        )
        return map(self._get_transformed_object, objects)