def step_barrier_wait()

in dataflux_pytorch/benchmark/standalone_dataloader/standalone_dataloader.py [0:0]


def step_barrier_wait(msg, step_count):
    """Wait on a cross-host barrier and log when it started, ended and how long it took.

    Logs one line before and one line after calling
    ``jax.experimental.multihost_utils.sync_global_devices(msg)``. Each line
    includes this host's JAX process index, the step number and a wall-clock
    timestamp, so per-host barrier skew can be compared across hosts' logs.

    Args:
      msg: Barrier identifier forwarded to ``sync_global_devices``; presumably
        all hosts must pass the same value for the same barrier — TODO confirm
        against callers.
      step_count: Step number; used only in the log messages.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list (not visible here); the import is cached after
    # the first call.
    import time

    host = jax.process_index()
    barrier_start = datetime.datetime.now()
    # Wall-clock time is kept for the human-readable timestamps, but it can
    # jump (NTP adjustments, manual clock changes). Elapsed time is therefore
    # measured with a monotonic, high-resolution counter over the same span.
    start_counter = time.perf_counter()
    max_logging.log(
        f"STANDALONE DATALOADER : Barrier on host {host} started on step {step_count} at {barrier_start}"
    )

    jax.experimental.multihost_utils.sync_global_devices(msg)

    elapsed_seconds = time.perf_counter() - start_counter
    barrier_end = datetime.datetime.now()
    max_logging.log(
        f"STANDALONE DATALOADER : Barrier on host {host} completed on step {step_count} at {barrier_end}, lasted {elapsed_seconds} seconds"
    )