in benchmarks/dlrm/ootb/extend_distributed.py
def init_distributed(rank=-1, local_rank=-1, size=-1, use_gpu=False, backend=""):
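    """Initialize torch.distributed, guessing rank, world size, and backend
    from MPI-style environment variables when they are not passed in.

    Sets the module-level globals (my_rank, my_size, my_local_rank,
    my_local_size, a2a_impl, alltoall_supported, myreq) that the rest of
    this module relies on.
    """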
    global myreq
    global my_rank
    global my_size
    global my_local_rank
    global my_local_size
    global a2a_impl
    global alltoall_supported

    # guess MPI ranks from env (works for IMPI, OMPI and MVAPICH2)
    num_mpi_ranks = env2int(
        ["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"]
    )
    if backend == "" and num_mpi_ranks > 1:
        # no backend requested explicitly: pick one based on what is available
        if torch_ccl and env2int(["CCL_WORKER_COUNT"]) > 0:
            backend = "ccl"
        elif use_gpu and dist.is_nccl_available():
            backend = "nccl"
        elif dist.is_mpi_available():
            backend = "mpi"
        else:
            print(
                "WARNING: MPI multi-process launch detected but PyTorch MPI backend not available."
            )
            backend = "gloo"

    if backend != "":
        # guess rank and size from env if not provided explicitly
        if rank == -1:
            rank = env2int(
                ["PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK", "RANK"], 0
            )
        if size == -1:
            size = env2int(
                [
                    "PMI_SIZE",
                    "OMPI_COMM_WORLD_SIZE",
                    "MV2_COMM_WORLD_SIZE",
                    "WORLD_SIZE",
                ],
                1,
            )
        # fill in the env vars torch.distributed expects if they are missing
        if not os.environ.get("RANK", None) and rank != -1:
            os.environ["RANK"] = str(rank)
        if not os.environ.get("WORLD_SIZE", None) and size != -1:
            os.environ["WORLD_SIZE"] = str(size)
        if not os.environ.get("MASTER_PORT", None):
            os.environ["MASTER_PORT"] = "29500"
        if not os.environ.get("MASTER_ADDR", None):
            local_size = env2int(
                [
                    "MPI_LOCALNRANKS",
                    "OMPI_COMM_WORLD_LOCAL_SIZE",
                    "MV2_COMM_WORLD_LOCAL_SIZE",
                ],
                1,
            )
            if local_size != size and backend != "mpi":
                print(
                    "Warning: Looks like a distributed multi-node run but MASTER_ADDR env is not set, using '127.0.0.1' as default"
                )
                print(
                    "If this run hangs, try exporting rank 0's hostname as MASTER_ADDR"
                )
            os.environ["MASTER_ADDR"] = "127.0.0.1"

    if size > 1:
        if local_rank == -1:
            my_local_rank = env2int(
                [
                    "MPI_LOCALRANKID",
                    "OMPI_COMM_WORLD_LOCAL_RANK",
                    "MV2_COMM_WORLD_LOCAL_RANK",
                    "LOCAL_RANK",
                ],
                0,
            )
        else:
            my_local_rank = local_rank
        my_local_size = env2int(
            [
                "MPI_LOCALNRANKS",
                "OMPI_COMM_WORLD_LOCAL_SIZE",
                "MV2_COMM_WORLD_LOCAL_SIZE",
            ],
            1,
        )
        if use_gpu:
            if my_local_size > torch.cuda.device_count():
                print(
                    "Not enough GPUs available... local_size = %d, ngpus = %d"
                    % (my_local_size, torch.cuda.device_count())
                )
                sys.exit(1)
            # pin this process to its local GPU
            torch.cuda.set_device(my_local_rank)
        dist.init_process_group(backend, rank=rank, world_size=size)
        my_rank = dist.get_rank()
        my_size = dist.get_world_size()
        if my_rank == 0:
            print("Running on %d ranks using %s backend" % (my_size, backend))

        # probe whether the selected backend actually supports all_to_all_single
        if hasattr(dist, "all_to_all_single"):
            try:
                t = torch.zeros([4])
                if use_gpu:
                    t = t.cuda()
                dist.all_to_all_single(t, t)
                alltoall_supported = True
            except RuntimeError as err:
                print("failed to enable all_to_all_single primitive: %s" % err)
        if a2a_impl == "alltoall" and not alltoall_supported:
            print(
                "Requested DLRM_ALLTOALL_IMPL=%s but backend %s does not support it; using scatter/gather based alltoall"
                % (a2a_impl, backend)
            )
            a2a_impl = "scatter"
        if a2a_impl != "":
            print("Using DLRM_ALLTOALL_IMPL=%s" % a2a_impl)
    else:
        my_rank = 0
        my_size = 1
        my_local_rank = 0
        my_local_size = 1
    print_all(
        "world size: %d, current rank: %d, local rank: %d"
        % (my_size, my_rank, my_local_rank)
    )
    myreq = Request()
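
# A minimal usage sketch (an assumption about how a launcher script would call
# into this module; the import alias and argument values below are illustrative,
# not taken from this file):
#
#   import extend_distributed as ext_dist
#
#   use_gpu = torch.cuda.is_available()
#   ext_dist.init_distributed(use_gpu=use_gpu, backend="")  # backend guessed from env
#   if ext_dist.my_size > 1:
#       print(
#           "rank %d of %d (local rank %d)"
#           % (ext_dist.my_rank, ext_dist.my_size, ext_dist.my_local_rank)
#       )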