def main()

in dataflux_pytorch/benchmark/standalone_dataloader/standalone_dataloader.py [0:0]


def main(argv: Sequence[str]) -> None:
    """Run the standalone dataloader benchmark.

    Steps, in order:
      1. Enable gloo-based CPU collectives in JAX.
      2. Set ``TF_CPP_MIN_LOG_LEVEL`` to ``"0"``.
      3. Build the global config from *argv* via ``pyconfig``.
      4. Log the visible JAX device count, process count and device list.
      5. For ``dataset_type == "tfds"``, export ``dataset_path`` as
         ``TFDS_DATA_DIR``.
      6. Hand the config to ``data_load_loop``.

    Args:
        argv: Command-line arguments, forwarded unchanged to
            ``pyconfig.initialize``.
    """
    # Process-wide JAX / TensorFlow settings.
    jax.config.update("jax_cpu_enable_gloo_collectives", True)
    # NOTE(review): TF usually reads TF_CPP_MIN_LOG_LEVEL when it is first
    # imported. If TensorFlow is already imported at module level, this
    # assignment may have no effect. Confirm.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

    pyconfig.initialize(argv)
    cfg = pyconfig.config
    # Train-config validation is currently disabled here:
    # validate_train_config(cfg) is intentionally not called.

    # Report the JAX topology this process can see.
    device_total = jax.device_count()
    max_logging.log(f"Found {device_total} devices.")
    process_total = jax.process_count()
    max_logging.log(f"Found {process_total} processes.")
    device_list = jax.devices()
    max_logging.log(f"Found {device_list} devices.")

    # For TFDS datasets, export the configured dataset path as TFDS_DATA_DIR.
    is_tfds = cfg.dataset_type == "tfds"
    if is_tfds:
        os.environ["TFDS_DATA_DIR"] = cfg.dataset_path

    data_load_loop(cfg)