optimum/habana/distributed/__init__.py (17 lines of code) (raw):
import os
import torch
from .distributed_runner import DistributedRunner
from .fast_ddp import all_reduce_gradients


def rank_and_world(group=None):
    """
    Returns (rank, world_size) for the optionally specified process group,
    otherwise for the default group; returns (0, 1) when running
    non-distributed.
    """
    if torch.distributed.is_initialized() and group is None:
        group = torch.distributed.GroupMember.WORLD

    if group is None:
        world_size = 1
        rank = 0
    else:
        world_size = group.size()
        rank = group.rank()

    return rank, world_size


# The local rank is read once from the environment at import time.
_LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0))


def local_rank():
    return _LOCAL_RANK
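
A minimal usage sketch (hypothetical, not part of the module): the file name and the describe_process helper are illustrative only; the import path follows the module path shown above. The helpers work whether or not torch.distributed has been initialized, since rank_and_world() falls back to (0, 1) in the non-distributed case.

# example_usage.py (hypothetical)
from optimum.habana.distributed import local_rank, rank_and_world


def describe_process():
    # rank_and_world() queries the default process group if one exists,
    # otherwise it reports a single-process run as (0, 1).
    rank, world_size = rank_and_world()
    print(f"rank={rank} world_size={world_size} local_rank={local_rank()}")


if __name__ == "__main__":
    describe_process()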