in dataflux_pytorch/benchmark/standalone_dataloader/standalone_dataloader.py [0:0]
def main(argv: Sequence[str]) -> None:
    """Run the standalone dataloader benchmark from the command line.

    Enables JAX's Gloo-based CPU collectives, raises TensorFlow's C++ log
    verbosity, builds the run configuration via ``pyconfig`` and logs the
    JAX device/process topology. If the dataset type is ``"tfds"``, it also
    exports the dataset path as ``TFDS_DATA_DIR``. It then hands the
    configuration to ``data_load_loop``.

    Args:
        argv: Command-line arguments, forwarded unchanged to
            ``pyconfig.initialize``.
    """
    # Turn on Gloo-backed collectives for JAX's CPU backend.
    jax.config.update("jax_cpu_enable_gloo_collectives", True)
    # "0" is TensorFlow's least-filtered C++ logging threshold.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

    pyconfig.initialize(argv)
    cfg = pyconfig.config
    # NOTE(review): training-config validation (validate_train_config) is
    # intentionally not run here; confirm this is still desired.

    # Each probe is queried immediately before its own log line, matching
    # the original one-call-then-log ordering.
    topology_probes = (
        (jax.device_count, "devices"),
        (jax.process_count, "processes"),
        (jax.devices, "devices"),
    )
    for probe, noun in topology_probes:
        max_logging.log(f"Found {probe()} {noun}.")

    if cfg.dataset_type == "tfds":
        # Let TFDS locate its data under the configured dataset path.
        os.environ["TFDS_DATA_DIR"] = cfg.dataset_path

    # Presumably runs the benchmark's data-loading loop; see its definition.
    data_load_loop(cfg)