in mmf/utils/distributed.py [0:0]
import os
import subprocess


def infer_init_method(config):
    # Nothing to infer if an init method was provided explicitly.
    if config.distributed.init_method is not None:
        return

    # Support torch.distributed.launch, which exports these variables
    # and starts one process per GPU itself (so no spawning is needed).
    if all(
        key in os.environ
        for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"]
    ):
        config.distributed.init_method = "env://"
        config.distributed.world_size = int(os.environ["WORLD_SIZE"])
        config.distributed.rank = int(os.environ["RANK"])
        config.distributed.no_spawn = True
    # Otherwise, we can determine the init method automatically for Slurm.
    elif config.distributed.port > 0:
        node_list = os.environ.get("SLURM_STEP_NODELIST")
        if node_list is None:
            node_list = os.environ.get("SLURM_JOB_NODELIST")
        if node_list is not None:
            try:
                # The first hostname in the allocation acts as the
                # rendezvous host for TCP initialization.
                hostnames = subprocess.check_output(
                    ["scontrol", "show", "hostnames", node_list]
                )
                config.distributed.init_method = "tcp://{host}:{port}".format(
                    host=hostnames.split()[0].decode("utf-8"),
                    port=config.distributed.port,
                )
                nnodes = int(os.environ.get("SLURM_NNODES"))
                ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
                if ntasks_per_node is not None:
                    ntasks_per_node = int(ntasks_per_node)
                else:
                    ntasks = int(os.environ.get("SLURM_NTASKS"))
                    nnodes = int(os.environ.get("SLURM_NNODES"))
                    assert ntasks % nnodes == 0
                    ntasks_per_node = int(ntasks / nnodes)
                if ntasks_per_node == 1:
                    # One Slurm task per node: this process will spawn one
                    # worker per GPU, so the node's base rank is node_id * GPUs.
                    assert config.distributed.world_size % nnodes == 0
                    gpus_per_node = config.distributed.world_size // nnodes
                    node_id = int(os.environ.get("SLURM_NODEID"))
                    config.distributed.rank = node_id * gpus_per_node
                else:
                    # One Slurm task per GPU: rank and local device id come
                    # straight from Slurm, so do not spawn extra processes.
                    assert ntasks_per_node == config.distributed.world_size // nnodes
                    config.distributed.no_spawn = True
                    config.distributed.rank = int(os.environ.get("SLURM_PROCID"))
                    config.device_id = int(os.environ.get("SLURM_LOCALID"))
            except subprocess.CalledProcessError as e:  # scontrol failed
                raise e
            except FileNotFoundError:  # Slurm is not installed
                pass
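
For context, here is a minimal usage sketch. It assumes an OmegaConf-style config (MMF's configuration system) carrying the distributed fields the function reads; the driver code and the concrete values are hypothetical and only illustrate how the inferred init method, rank, and spawn flag come back out of the config.

from omegaconf import OmegaConf

# Hypothetical config with the fields infer_init_method expects.
config = OmegaConf.create(
    {
        "distributed": {
            "init_method": None,  # left unset so it gets inferred
            "port": 12345,        # used for the Slurm TCP rendezvous
            "world_size": 8,
            "rank": -1,
            "no_spawn": False,
        },
        "device_id": 0,
    }
)

infer_init_method(config)

# If torch.distributed.launch or Slurm variables were found, the config
# now carries the init method and rank to pass to torch.distributed.
if config.distributed.init_method is not None:
    print("init_method:", config.distributed.init_method)
    print("rank:", config.distributed.rank, "no_spawn:", config.distributed.no_spawn)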