in train_helpers.py
import os
import socket

import torch
import torch.distributed as dist
from mpi4py import MPI


def setup_mpi(H):
    # mpi_size/mpi_rank/local_mpi_rank are thin MPI.COMM_WORLD wrappers defined elsewhere in this module.
    H.mpi_size = mpi_size()
    H.local_rank = local_mpi_rank()
    H.rank = mpi_rank()
    # Expose the MPI topology through the environment variables torch.distributed expects.
    os.environ["RANK"] = str(H.rank)
    os.environ["WORLD_SIZE"] = str(H.mpi_size)
    os.environ["MASTER_PORT"] = str(H.port)
    # os.environ["NCCL_LL_THRESHOLD"] = "0"
    # Broadcast rank 0's hostname so every rank agrees on the rendezvous address.
    os.environ["MASTER_ADDR"] = MPI.COMM_WORLD.bcast(socket.gethostname(), root=0)
    # Bind this process to its node-local GPU before initializing NCCL.
    torch.cuda.set_device(H.local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")
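
The mpi_size, mpi_rank, and local_mpi_rank helpers are defined elsewhere in this module. A minimal sketch of what such wrappers could look like, assuming an Open MPI launcher (the OMPI_COMM_WORLD_LOCAL_RANK variable is an Open MPI convention and an assumption here, not something taken from this file):

import os

from mpi4py import MPI


def mpi_size():
    # Total number of processes across all nodes.
    return MPI.COMM_WORLD.Get_size()


def mpi_rank():
    # Global rank of this process within the job.
    return MPI.COMM_WORLD.Get_rank()


def local_mpi_rank():
    # Node-local rank, used above to pick a GPU. Open MPI exports this
    # directly; other launchers would need a different heuristic (assumption).
    return int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0))

Under this setup the job would be launched with one process per GPU, e.g. mpirun -n 8 python train.py (the script name is illustrative); each rank calls setup_mpi(H) once, after which torch.distributed collectives run over NCCL.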