# parquet_data_loader()
# Located in: dataflux_pytorch/benchmark/standalone_dataloader/standalone_dataloader.py


def parquet_data_loader(config):
    """Construct a DataLoader over Parquet data sharded across JAX processes.

    Each JAX process reads its own shard: ``jax.process_index()`` is used as
    the rank and ``jax.process_count()`` as the world size.

    Args:
        config: Object providing ``local_batch_size``,
            ``data_loader_num_workers``, and ``prefetch_factor``.

    Returns:
        A ``DataLoader`` yielding batches from a ``ParquetIterableDataset``.

    NOTE(review): requires the PROJECT, BUCKET, and PREFIX environment
    variables to be set — a missing one raises KeyError here. Verify they
    are exported by the launcher.
    """
    rank = jax.process_index()
    world_size = jax.process_count()

    # Listing config for the dataflux iterable dataset; sorted listing keeps
    # shard assignment deterministic across processes.
    listing_config = dataflux_iterable_dataset.Config(
        num_processes=1,
        sort_listing_results=True,
        prefix=os.environ["PREFIX"],
    )

    dataset = ParquetIterableDataset(
        batch_size=config.local_batch_size,
        columns=["outputs", "image_base64_str"],
        rank=rank,
        world_size=world_size,
        project_name=os.environ["PROJECT"],
        bucket_name=os.environ["BUCKET"],
        config=listing_config,
    )

    return DataLoader(
        dataset=dataset,
        num_workers=config.data_loader_num_workers,
        batch_size=config.local_batch_size,
        prefetch_factor=config.prefetch_factor,
    )