def setup_mpi(H)

in train_helpers.py


import os
import socket

import torch
import torch.distributed as dist
from mpi4py import MPI

def setup_mpi(H):
    # mpi_size, local_mpi_rank, and mpi_rank are helpers defined elsewhere
    # in train_helpers.py; they wrap MPI.COMM_WORLD queries.
    H.mpi_size = mpi_size()
    H.local_rank = local_mpi_rank()
    H.rank = mpi_rank()
    # Export the variables read by torch.distributed's env:// init method.
    os.environ["RANK"] = str(H.rank)
    os.environ["WORLD_SIZE"] = str(H.mpi_size)
    os.environ["MASTER_PORT"] = str(H.port)
    # os.environ["NCCL_LL_THRESHOLD"] = "0"
    # Broadcast rank 0's hostname so every process rendezvous at one address.
    os.environ["MASTER_ADDR"] = MPI.COMM_WORLD.bcast(socket.gethostname(), root=0)
    torch.cuda.set_device(H.local_rank)  # bind this process to its local GPU
    dist.init_process_group(backend="nccl", init_method="env://")
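
A minimal usage sketch, assuming setup_mpi is imported from train_helpers and
that the real training script builds a richer hyperparameter object; the
SimpleNamespace and the port value below are stand-ins for illustration:

from types import SimpleNamespace
from train_helpers import setup_mpi

# Hypothetical port; the real H carries many more hyperparameters.
H = SimpleNamespace(port=29500)
setup_mpi(H)
print(f"rank {H.rank}/{H.mpi_size} on local GPU {H.local_rank}")

Launched with one process per GPU, e.g. mpiexec -n 4 python train.py, each
process binds its node-local GPU and joins the same NCCL process group, with
MPI supplying the rank/size/address information that env:// initialization
expects in the environment.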