def init_torch_distributed()

in bench_cluster/communication/utils.py


import os
import torch
import torch.distributed as dist

def init_torch_distributed(backend: str, local_rank: int):
    # RANK and WORLD_SIZE are set by the launcher (e.g. torchrun);
    # print_rank_0 is a logging helper assumed to be defined in this module.
    dist.init_process_group(backend, rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"]))
    torch.cuda.set_device(local_rank)
    print_rank_0(f"Initializing distributed backend: {backend}")
    print_rank_0(f"RANK: {os.environ['RANK']}")
    print_rank_0(f"WORLD_SIZE: {os.environ['WORLD_SIZE']}")