in src/utils.py
import os
import logging

import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)  # module-level logger used below (name assumed)


def init_distributed(port=40101, rank_and_world_size=(None, None)):
    # Re-use the existing process group if one has already been initialized.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size(), dist.get_rank()

    rank, world_size = rank_and_world_size
    os.environ['MASTER_ADDR'] = 'localhost'

    # If rank/world-size were not passed in, fall back to the SLURM environment.
    if (rank is None) or (world_size is None):
        try:
            world_size = int(os.environ['SLURM_NTASKS'])
            rank = int(os.environ['SLURM_PROCID'])
            os.environ['MASTER_ADDR'] = os.environ['HOSTNAME']
        except Exception:
            logger.info('distributed training not available')
            world_size, rank = 1, 0
            return world_size, rank

    # Initialize the default NCCL process group; fall back to single-process mode on failure.
    try:
        os.environ['MASTER_PORT'] = str(port)
        torch.distributed.init_process_group(
            backend='nccl',
            world_size=world_size,
            rank=rank)
    except Exception:
        world_size, rank = 1, 0
        logger.info('distributed training not available')

    return world_size, rank
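
For reference, a minimal single-node launch sketch. This is assumed usage, not part of utils.py: the `worker` entry point, the `mp.spawn` launch, and the `from src.utils import ...` path are illustrative. Passing an explicit `rank_and_world_size` skips the SLURM lookup, so all ranks rendezvous on localhost at the given port.

import torch
import torch.multiprocessing as mp

from src.utils import init_distributed  # assumed import path for the snippet above


def worker(rank, num_gpus):
    # Explicit rank/world-size bypass the SLURM branch; ranks meet on localhost:40101.
    world_size, rank = init_distributed(port=40101, rank_and_world_size=(rank, num_gpus))
    torch.cuda.set_device(rank)
    # ... build the model, wrap it in DistributedDataParallel, run training ...


if __name__ == '__main__':
    num_gpus = torch.cuda.device_count()
    mp.spawn(worker, args=(num_gpus,), nprocs=num_gpus)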