def _setup_ddp_vars_from_slurm_env()

in ignite/distributed/comp_models/native.py [0:0]


    def _setup_ddp_vars_from_slurm_env(environ: Dict[str, str]) -> Dict[str, Union[str, int]]:
        """Method to setup DDP env vars required by PyTorch from SLURM env"""
        # 1) Tools like enroot can have hooks to translate slurm env vars to RANK, LOCAL_RANK, WORLD_SIZE etc
        # See https://github.com/NVIDIA/enroot/blob/v3.1.0/conf/hooks/extra/50-slurm-pytorch.sh
        # 2) User can use torch.distributed.launch tool to schedule on N local GPUs using 1 node, 1 task by SLURM
        # To cover case 1), let's ensure that defined RANK == SLURM_PROCID, LOCAL_RANK == SLURM_LOCALID,
        #   WORLD_SIZE == SLURM_NTASKS. We will use defined MASTER_ADDR and MASTER_PORT instead of defining
        #   them by our means
        # To cover case 2), let's check that defined RANK >= SLURM_PROCID, LOCAL_RANK >= SLURM_LOCALID,
        #   WORLD_SIZE >= SLURM_NTASKS, SLURM_JOB_NUM_NODES == 1

        ddp_vars: Dict[str, Union[str, int, None]] = {
            "RANK": int(environ["SLURM_PROCID"]),
            "LOCAL_RANK": int(environ["SLURM_LOCALID"]),
            "WORLD_SIZE": int(environ["SLURM_NTASKS"]),
            "MASTER_ADDR": None,
            "MASTER_PORT": None,
        }

        pth_ddp_env_vars = {key: environ.get(key, None) for key in ddp_vars}
        defined_pth_ddp_env_vars = [v is not None for v in pth_ddp_env_vars.values()]
        if all(defined_pth_ddp_env_vars):
            nnodes = int(environ["SLURM_JOB_NUM_NODES"])
            if nnodes > 1:
                # ensure that all pth_ddp_env_vars are consistent with slurm vars
                for key in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]:
                    slurm_var = cast(int, ddp_vars[key])
                    pth_var = int(cast(str, pth_ddp_env_vars[key]))
                    if slurm_var != pth_var:
                        raise RuntimeError(
                            "Environment variable defined for PyTorch Distributed context is inconsistent with "
                            f"equivalent SLURM env variable. {key}: {pth_var} vs {slurm_var}\n"
                            f"SLURM vars: {ddp_vars}\n"
                            f"PTH vars: {pth_ddp_env_vars}\n"
                        )
            else:
                # ensure that PTH RANK >= SLURM_PROCID, PTH LOCAL_RANK >= SLURM_LOCALID,
                # PTH WORLD_SIZE >= SLURM_NTASKS
                for key in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]:
                    slurm_var = cast(int, ddp_vars[key])
                    pth_var = int(cast(str, pth_ddp_env_vars[key]))
                    if pth_var < slurm_var:
                        raise RuntimeError(
                            "Environment variable defined for PyTorch Distributed context is "
                            "inconsistent with equivalent SLURM env variable. "
                            f"We expect that {key}: {pth_var} >= {slurm_var}\n"
                            f"SLURM vars: {ddp_vars}\n"
                            f"PTH vars: {pth_ddp_env_vars}\n"
                        )
                    ddp_vars[key] = pth_var
            # set up MASTER_ADDR and MASTER_PORT from PTH
            ddp_vars["MASTER_ADDR"] = cast(str, pth_ddp_env_vars["MASTER_ADDR"])
            ddp_vars["MASTER_PORT"] = int(cast(str, pth_ddp_env_vars["MASTER_PORT"]))
        elif any(defined_pth_ddp_env_vars):
            # Let's warn user about PTH env variables that we could not taken into account
            warnings.warn(
                "We detected the following env variables: "
                f"{[(k, v) for k, v in pth_ddp_env_vars.items() if v is not None]},\n"
                "but will not take them into account as the following env vars are missing:"
                f"{[k for k, v in pth_ddp_env_vars.items() if v is None]},\n"
            )

        if ddp_vars["MASTER_ADDR"] is None:
            nodelist = environ["SLURM_JOB_NODELIST"]
            try:
                # use scontrol to expand hostname list
                hostnames = subprocess.check_output(["scontrol", "show", "hostnames", nodelist])
                method = "scontrol"
            except FileNotFoundError:
                # expand hostname list as scontrol
                hostnames = " ".join(_expand_hostlist(nodelist)).encode("utf-8")
                method = "ignite"
            # at least one hostname should be defined
            hostname_list = hostnames.split()
            if len(hostname_list) < 1:
                raise RuntimeError(f"No hostname detected in SLURM_JOB_NODELIST by {method} (nodelist={nodelist})")
            # master address is the first hostname of nodes list
            ddp_vars["MASTER_ADDR"] = str(hostname_list[0].decode("utf-8"))

        if ddp_vars["MASTER_PORT"] is None:
            # port should be the same over all process
            slurm_port = environ["SLURM_JOB_ID"]
            slurm_port = slurm_port[-4:]
            ddp_vars["MASTER_PORT"] = int(slurm_port) + 15000

        return cast(Dict[str, Union[str, int]], ddp_vars)