in dataflux_pytorch/benchmark/standalone_dataloader/standalone_dataloader.py [0:0]
def step_barrier_wait(msg, step_count):
    """Wait on a cross-host JAX barrier and log when it starts, ends and how long it took.

    Args:
        msg: Identifier passed to
            ``jax.experimental.multihost_utils.sync_global_devices``.
        step_count: Current step number; used only in the log messages.
    """
    # Local import so this block does not depend on a file-level ``time`` import.
    import time

    process_index = jax.process_index()

    # Wall-clock timestamps are kept for the human-readable log lines, but the
    # duration is taken from a monotonic clock: wall-clock time can jump
    # (NTP/admin adjustments), which would corrupt the measured barrier time.
    barrier_start = datetime.datetime.now()
    start_tick = time.perf_counter()
    max_logging.log(
        f"STANDALONE DATALOADER : Barrier on host {process_index} started on step {step_count} at {barrier_start}"
    )

    jax.experimental.multihost_utils.sync_global_devices(msg)

    elapsed_seconds = time.perf_counter() - start_tick
    barrier_end = datetime.datetime.now()
    max_logging.log(
        f"STANDALONE DATALOADER : Barrier on host {process_index} completed on step {step_count} at {barrier_end}, lasted {elapsed_seconds} seconds"
    )