def maybe_init_distributed()

in egg/core/distributed.py


import os
import subprocess

import torch.distributed as dist

# NB: this excerpt omits the module header; DistributedContext is defined
# earlier in egg/core/distributed.py (see the sketch after the listing).


def maybe_init_distributed(args) -> DistributedContext:
    assert not hasattr(
        args, "distributed_context"
    ), "distributed context is already initialized?!"
    # default, non-distributed context
    context = DistributedContext(
        is_distributed=False, rank=0, local_rank=0, world_size=1, mode="none"
    )

    launch_keys = ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK", "LOCAL_RANK"]
    slurm_keys = [
        "SLURM_LOCALID",
        "SLURM_PROCID",
        "SLURM_NTASKS",
        "SLURM_NODEID",
        "SLURM_JOB_NODELIST",
    ]
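    # torchrun / torch.distributed.launch export the first set of variables
    # directly; under SLURM, srun exports per-task variables from which the
    # same rendezvous information is derived below.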

    # is it torch.distributed.launch (or torchrun)?
    if all(key in os.environ for key in launch_keys):
        init_method = "env://"
        world_size = int(os.environ["WORLD_SIZE"])
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        context = DistributedContext(
            is_distributed=True,
            rank=rank,
            world_size=world_size,
            local_rank=local_rank,
            mode="launch",
        )
        dist.init_process_group(
            backend="nccl", init_method=init_method, world_size=world_size, rank=rank
        )
    # is it SLURM?
    elif all(key in os.environ for key in slurm_keys):
        init_method = "env://"
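        # SLURM_LOCALID = rank within the node, SLURM_PROCID = global rank,
        # SLURM_NTASKS = total number of tasks, i.e. the world size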
        local_rank = int(os.environ["SLURM_LOCALID"])
        rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NTASKS"])

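        # expand SLURM_JOB_NODELIST (e.g. "node[01-04]") into individual
        # hostnames; the first node is elected as the rendezvous master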
        hostnames = subprocess.check_output(
            ["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]]
        )
        leader_addr = hostnames.split()[0].decode("utf-8")

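        # hand-populate the env:// rendezvous variables that a launcher
        # would otherwise provide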
        os.environ["MASTER_ADDR"] = leader_addr
        os.environ["MASTER_PORT"] = str(args.distributed_port)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["RANK"] = str(rank)

        if world_size > 1:
            # no point in being distributed if there is only one process
            context = DistributedContext(
                is_distributed=True,
                rank=rank,
                local_rank=local_rank,
                world_size=world_size,
                mode="slurm",
            )
            dist.init_process_group(
                backend="nccl",
                init_method=init_method,
                world_size=world_size,
                rank=rank,
            )

    return context
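
The function returns a DistributedContext defined earlier in the same module. A minimal sketch consistent with how its fields are constructed above (the @dataclass form is an assumption based on usage, not shown in this excerpt):

from dataclasses import dataclass


@dataclass
class DistributedContext:
    is_distributed: bool  # True once torch.distributed is initialized
    rank: int  # global rank of this process
    local_rank: int  # rank within the local node (typically the GPU index)
    world_size: int  # total number of processes
    mode: str  # "none", "launch", or "slurm"


A hedged usage sketch: a training script that calls the helper once at startup. The --distributed_port flag name and its default are assumptions, chosen only to satisfy the args.distributed_port read in the SLURM branch:

import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--distributed_port", type=int, default=29500)  # hypothetical flag/default
args = parser.parse_args()

context = maybe_init_distributed(args)
args.distributed_context = context  # a second call would now trip the assert

if context.is_distributed:
    # bind this process to its GPU before any CUDA work, as nccl expects
    torch.cuda.set_device(context.local_rank)

print(f"mode={context.mode}, rank {context.rank}/{context.world_size}")

Under torchrun the first branch fires; under srun the SLURM branch derives the same rendezvous information, but only initializes the process group when more than one task is present.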