maga_transformer/distribute/gang_test_util.py (40 lines of code) (raw):

import time import logging from datetime import timedelta from torch.distributed.rendezvous import rendezvous def create_store(master_url, world_rank, world_size): rendezvous_iterator = rendezvous( master_url, world_rank, world_size ) store, rank, world_size = next(rendezvous_iterator) logging.info(f"store: {store}, rank: {rank}, world_size: {world_size}") store.set_timeout(timedelta(seconds=10)) return store def store_based_barrier(rank: int, world_size: int, store, timeout: timedelta): store_key = "{}:{}".format("store_barrier", 0) store.add(store_key, 1) logging.info("Added key: {} to store for rank: {}".format(store_key, rank)) # Now wait for all workers to check in with the store. # Use 'add' instead of 'get' since for some store implementations 'add' # doesn't work well with 'get'. Ideally the store implementations should # be fixed, but for backward compatiblity reasons it is risky to change # the store implementations. Once, we completely migrate away from these # legacy stores, we can use 'get' here instead. worker_count = store.add(store_key, 0) start = time.time() log_time = time.time() while worker_count != world_size: time.sleep(0.01) worker_count = store.add(store_key, 0) # Print status periodically to keep track. if timedelta(seconds=(time.time() - log_time)) > timedelta(seconds=10): logging.info( "Waiting in store based barrier to initialize process group for " "rank: {}, key: {} (world_size={}, worker_count={}, timeout={})".format( rank, store_key, world_size, worker_count, timeout ) ) log_time = time.time() if timedelta(seconds=(time.time() - start)) > timeout: raise RuntimeError( "Timed out initializing process group in store based barrier on " "rank: {}, for key: {} (world_size={}, worker_count={}, timeout={})".format( rank, store_key, world_size, worker_count, timeout ) ) logging.info( f"Rank {rank}: Completed store-based barrier for key:{store_key} with {world_size} nodes." )