in fairscale/utils/testing.py [0:0]
def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "") -> bool:
    """
    Initialize torch distributed (default process group and RPC) using temporary files shared across ranks,
    which makes it possible for unrelated tests to be run concurrently.

    Args:
        rank: rank of the calling process. Also exported as the ``RANK`` environment variable.
        world_size: total number of processes. Also exported as the ``WORLD_SIZE`` environment variable.
        filename: path of the file used as the ``file://`` init method of the default process group.
        filename_rpc: path of the file used as the ``file://`` init method of the RPC framework. It is used
            on torch >= 1.6, and on older torch when ``world_size > 1``. With the default ``""`` the RPC
            init method is the bare ``"file://"``, so callers that need RPC should pass a real path.

    Returns:
        ``True`` once initialization is done. ``False`` when the requested setup cannot be reached on this
        machine:

        * on torch >= 1.6 with CUDA available, fewer visible GPUs than ``world_size``;
        * on older torch, ``world_size <= 1`` without CUDA.

        ``RANK`` and ``WORLD_SIZE`` are exported to the environment even when ``False`` is returned.

    .. warning:: This limits the use case to all ranks being on the same node.
    """
    # Best effort: tear down an RPC agent possibly left over from a previous test in this process.
    # Failures are ignored on purpose (e.g. when no agent was ever initialized).
    try:
        torch.distributed.rpc.shutdown()
    except Exception as e:
        logging.debug("Ignoring failure of rpc.shutdown() before dist init: %s", e)

    print(f"dist init r={rank}, world={world_size}")
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    url = "file://" + filename
    url_rpc = "file://" + filename_rpc
    version = torch_version()

    if version >= (1, 6, 0):
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        # With NCCL every rank is expected to get its own GPU on this (single) node.
        if backend == "nccl" and torch.cuda.device_count() < world_size:
            logging.warning("Requested world size cannot be reached on this machine, not enough GPUs")
            return False
        torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=url)

        tp_options = {"init_method": url_rpc}
        # Workaround for bug in torch v1.8.0. Should be fixed in v1.8.1
        if version == (1, 8, 0):
            if torch.cuda.is_available():
                # Workaround for https://github.com/pytorch/pytorch/issues/53844
                tp_options["_transports"] = ["ibv", "uv"]  # type: ignore
            else:
                # Workaround for https://github.com/pytorch/pytorch/issues/54266
                tp_options["_channels"] = ["mpt_uv", "basic", "cuda_ipc", "cuda_gdr", "cuda_xth", "cuda_basic"]  # type: ignore

        rpc.init_rpc(
            f"Test{rank}",
            rank=rank,
            world_size=world_size,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(**tp_options),
        )
    else:
        if world_size > 1:
            # TensorPipe is not available in Torch 1.5
            # NOTE(review): the default process group is not initialized explicitly on this path;
            # presumably the ProcessGroup RPC backend sets it up itself on old torch — confirm.
            rpc.init_rpc(
                name=f"Test{rank}",
                rank=rank,
                world_size=world_size,
                rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(init_method=url_rpc),
            )
        elif torch.cuda.is_available():
            torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size, init_method=url)
        else:
            return False

    # Map each rank onto a local GPU (wrapping around when there are fewer GPUs than ranks).
    if torch.cuda.is_available() and torch.cuda.device_count():
        torch.cuda.set_device(rank % torch.cuda.device_count())
    return True