def init_processes()

in dataflux_pytorch/benchmark/checkpointing/simulated/llama2.py [0:0]


import os

import torch.distributed


def init_processes() -> None:
    """Initializes the distributed environment."""
    world_size = int(os.environ["WORLD_SIZE"])
    job_index = int(os.environ.get("JOB_INDEX", 0))
    job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX", 0))
    processes_in_job = int(os.environ.get("PROCESSES_IN_JOB", 1))
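    # Derive this process's global rank: jobs are laid out consecutively,
    # each contributing PROCESSES_IN_JOB ranks, and JOB_COMPLETION_INDEX
    # (typically injected by a Kubernetes Indexed Job) selects the slot
    # within the job.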
    rank = job_index * processes_in_job + job_completion_index
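    # Expose the derived rank as NODE_RANK for downstream tooling that
    # reads this variable (e.g. PyTorch Lightning's environment detection).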
    os.environ["NODE_RANK"] = str(rank)

    # configure_master_addr() is a helper defined elsewhere in this file; it
    # sets up the master address used by init_process_group's default
    # env:// rendezvous.
    configure_master_addr()
    # Use the gloo backend since the simulated version runs on CPU.
    torch.distributed.init_process_group(backend='gloo',
                                         rank=rank,
                                         world_size=world_size)
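
Below is a minimal, self-contained sketch (not the benchmark's own launcher) of how the same rendezvous plays out for a two-process smoke test on one machine. The env-var names mirror the ones read above; the _worker helper, the 127.0.0.1 address, and port 29500 are illustrative assumptions, and MASTER_ADDR/MASTER_PORT are set inline here in place of configure_master_addr().

import os

import torch.distributed
import torch.multiprocessing as mp


def _worker(job_completion_index: int, world_size: int) -> None:
    # Environment a Kubernetes Indexed Job would normally provide.
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["JOB_INDEX"] = "0"
    os.environ["PROCESSES_IN_JOB"] = str(world_size)
    os.environ["JOB_COMPLETION_INDEX"] = str(job_completion_index)
    # Stand-in for configure_master_addr(): point every rank at the same
    # local rendezvous endpoint (address and port chosen arbitrarily).
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"

    # Same rank derivation as init_processes() above.
    rank = (int(os.environ["JOB_INDEX"]) * int(os.environ["PROCESSES_IN_JOB"])
            + int(os.environ["JOB_COMPLETION_INDEX"]))
    torch.distributed.init_process_group(backend="gloo",
                                         rank=rank,
                                         world_size=world_size)
    print(f"rank {torch.distributed.get_rank()} of "
          f"{torch.distributed.get_world_size()} is up")
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    # Spawn two local workers; each computes its own rank and joins the group.
    mp.spawn(_worker, args=(2,), nprocs=2)

Because init_process_group is called without an explicit init_method, rendezvous falls back to the env:// default, so each process only needs a consistent rank/world_size pair plus the shared MASTER_ADDR/MASTER_PORT endpoint.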