in egg/core/distributed.py [0:0]
def maybe_init_distributed(args) -> DistributedContext:
    """Detect how the process was launched and set up the process group.

    The environment is probed for two launchers, in this order:

    * ``"launch"`` mode: ``MASTER_ADDR``, ``MASTER_PORT``, ``WORLD_SIZE``,
      ``RANK`` and ``LOCAL_RANK`` are all set. The process group is always
      initialised, whatever the world size.
    * ``"slurm"`` mode: ``SLURM_LOCALID``, ``SLURM_PROCID``,
      ``SLURM_NTASKS``, ``SLURM_NODEID`` and ``SLURM_JOB_NODELIST`` are all
      set. ``MASTER_ADDR``, ``MASTER_PORT``, ``WORLD_SIZE`` and ``RANK`` are
      written into ``os.environ`` (even for a single task), but the process
      group is only initialised when ``SLURM_NTASKS`` is greater than 1.

    If neither set of variables is complete, nothing is initialised.

    Args:
        args: Options object. It must not already have a
            ``distributed_context`` attribute. ``args.distributed_port`` is
            read in SLURM mode and exported as ``MASTER_PORT``.

    Returns:
        A ``DistributedContext`` describing this process. When no process
        group was initialised it is the non-distributed default
        (``rank=0``, ``local_rank=0``, ``world_size=1``, ``mode="none"``).

    Raises:
        AssertionError: If ``args`` already has a ``distributed_context``.
        RuntimeError: If ``scontrol`` prints no hostnames for the job.
        subprocess.CalledProcessError: If ``scontrol`` exits with a non-zero
            status.
    """
    if hasattr(args, "distributed_context"):
        # Raised explicitly instead of through an `assert` statement so the
        # guard still fires under `python -O`; the exception type and the
        # message are the same as before.
        raise AssertionError("distributed context is already initialized?!")

    # default, non-distributed context
    context = DistributedContext(
        is_distributed=False, rank=0, local_rank=0, world_size=1, mode="none"
    )

    launch_keys = ("MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK", "LOCAL_RANK")
    slurm_keys = (
        "SLURM_LOCALID",
        "SLURM_PROCID",
        "SLURM_NTASKS",
        "SLURM_NODEID",
        "SLURM_JOB_NODELIST",
    )

    # is it torch.distributed.launch?
    if all(key in os.environ for key in launch_keys):
        world_size = int(os.environ["WORLD_SIZE"])
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        context = DistributedContext(
            is_distributed=True,
            rank=rank,
            world_size=world_size,
            local_rank=local_rank,
            mode="launch",
        )
        dist.init_process_group(
            backend="nccl", init_method="env://", world_size=world_size, rank=rank
        )
    # is it slurm?
    elif all(key in os.environ for key in slurm_keys):
        local_rank = int(os.environ["SLURM_LOCALID"])
        rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NTASKS"])

        # Expand the compact node list; the first hostname printed is used
        # as the rendezvous address for every task.
        hostnames = subprocess.check_output(
            ["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]]
        ).split()
        if not hostnames:
            raise RuntimeError(
                "scontrol returned no hostnames for SLURM_JOB_NODELIST="
                f"{os.environ['SLURM_JOB_NODELIST']!r}"
            )
        leader_addr = hostnames[0].decode("utf-8")

        # The "env://" init method reads its rendezvous settings from these.
        os.environ["MASTER_ADDR"] = leader_addr
        os.environ["MASTER_PORT"] = str(args.distributed_port)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["RANK"] = str(rank)

        if world_size > 1:
            # no point in being distributed if it is alone
            context = DistributedContext(
                is_distributed=True,
                rank=rank,
                local_rank=local_rank,
                world_size=world_size,
                mode="slurm",
            )
            dist.init_process_group(
                backend="nccl",
                init_method="env://",
                world_size=world_size,
                rank=rank,
            )

    return context