in dataflux_pytorch/dataflux_iterable_dataset.py [0:0]
def __iter__(self):
    """Lazily download and yield formatted objects from the GCS bucket.

    Each downloaded object's raw bytes are passed through
    ``self.data_format_fn`` before being yielded. When run inside a
    PyTorch DataLoader with multiple workers, the object list is split
    into contiguous, near-equal shards — one per worker — so no object
    is downloaded twice.

    Yields:
        The result of ``self.data_format_fn(bytes_content)`` for each
        object assigned to this worker.
    """
    worker_info = data.get_worker_info()
    if worker_info is None:
        # Single-process data loading: this process handles every object.
        start, end = 0, len(self.objects)
    else:
        # Multi-process data loading. Split the workload among workers.
        # Ref: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset.
        per_worker = int(
            math.ceil(len(self.objects) / float(worker_info.num_workers)))
        start = worker_info.id * per_worker
        # Clamp so the last worker does not run past the end of the list.
        end = min(start + per_worker, len(self.objects))
    # Stream the downloads lazily; each object is formatted and yielded
    # as soon as its bytes arrive rather than after the full download.
    yield from (self.data_format_fn(bytes_content) for bytes_content in
                dataflux_core.download.dataflux_download_lazy(
                    project_name=self.project_name,
                    bucket_name=self.bucket_name,
                    objects=self.objects[start:end],
                    storage_client=self.storage_client,
                    dataflux_download_optimization_params=self.
                    dataflux_download_optimization_params,
                    retry_config=self.config.download_retry_config,
                ))