in dataflux_pytorch/benchmark/standalone_dataloader/standalone_dataloader.py [0:0]
def parquet_data_loader(config):
    """Build a PyTorch DataLoader over Parquet objects in a GCS bucket.

    The underlying dataset is sharded across JAX processes: each process
    reads the listing slice matching its `jax.process_index()` out of
    `jax.process_count()` total shards.

    Args:
        config: object exposing `local_batch_size`,
            `data_loader_num_workers`, and `prefetch_factor`.

    Returns:
        A `torch.utils.data.DataLoader` yielding batches containing the
        "outputs" and "image_base64_str" columns.
    """
    per_host_batch = config.local_batch_size

    # Single listing process with sorted results so every JAX process
    # observes the same object ordering before taking its shard.
    listing_config = dataflux_iterable_dataset.Config(
        num_processes=1,
        sort_listing_results=True,
        prefix=os.environ["PREFIX"],
    )

    dataset = ParquetIterableDataset(
        batch_size=per_host_batch,
        columns=["outputs", "image_base64_str"],
        rank=jax.process_index(),
        world_size=jax.process_count(),
        project_name=os.environ["PROJECT"],
        bucket_name=os.environ["BUCKET"],
        config=listing_config,
    )

    return DataLoader(
        dataset=dataset,
        num_workers=config.data_loader_num_workers,
        batch_size=per_host_batch,
        prefetch_factor=config.prefetch_factor,
    )