in dataflux_pytorch/benchmark/standalone_dataloader/standalone_dataloader.py [0:0]
def __iter__(self):
    """
    Yield individual rows from the dataset's parquet objects.

    Overriding __iter__ allows ``self.batch_size`` to be used to sub-divide
    the contents of individual parquet files. Files are first partitioned
    across nodes (by ``self.rank`` / ``self.world_size``) and then, when
    multi-process data loading is active, across dataloader workers.
    """
    worker_info = data.get_worker_info()
    # Contiguous slice of the object list assigned to this node.
    files_per_node = math.ceil(len(self.objects) / self.world_size)
    start_point = self.rank * files_per_node
    # May exceed len(self.objects); Python slicing clamps it safely.
    end_point = start_point + files_per_node

    if worker_info is None:
        # Single-process data loading: this process consumes the node's
        # entire slice of files.
        print("Single-process data loading detected", flush=True)
        start, end = start_point, end_point
    else:
        # Multi-process data loading. Split the workload among workers.
        # Ref: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset.
        # For the purpose of this example, this split is only performed at
        # the file level. This means that (num_nodes * num_workers) should
        # be less than or equal to filecount.
        per_worker = int(
            math.ceil(files_per_node / float(worker_info.num_workers)))
        worker_id = worker_info.id
        start = worker_id * per_worker + start_point
        end = min(start + per_worker, end_point)
        max_logging.log(
            f"-----> Worker {self.rank}.{worker_id} downloading {self.objects[start:end]}\n"
        )

    # Shared download/yield path for both single- and multi-process modes.
    yield from self._iter_rows(start, end)

def _iter_rows(self, start, end):
    """Lazily download ``self.objects[start:end]`` and yield rows.

    Each downloaded object is parsed with ``self.data_format_fn`` and then
    read in batches of ``self.batch_size`` restricted to ``self.columns``;
    rows are yielded one at a time as plain Python dicts.
    """
    for bytes_content in dataflux_iterable_dataset.dataflux_core.download.dataflux_download_lazy(
            project_name=self.project_name,
            bucket_name=self.bucket_name,
            objects=self.objects[start:end],
            storage_client=self.storage_client,
            dataflux_download_optimization_params=self.
            dataflux_download_optimization_params,
            retry_config=self.config.download_retry_config,
    ):
        table = self.data_format_fn(bytes_content)
        for batch in table.iter_batches(batch_size=self.batch_size,
                                        columns=self.columns):
            yield from batch.to_pylist()