in training/utils/distributed.py
import logging

import torch
import torch.distributed


def create_new_process_group(group_size):
    """
    Creates process groups of a given `group_size` and returns the
    process group that the current GPU participates in.
    `group_size` must divide the total number of GPUs (world_size).
    Modified from
    https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60
    Args:
        group_size (int): number of GPUs to collaborate for sync BN
    """
    assert group_size > 0
    world_size = torch.distributed.get_world_size()
    if world_size <= 8:
        if group_size > world_size:
            logging.warning(
                f"Requested group size [{group_size}] > world size [{world_size}]. "
                "Assuming local debug run and capping it to world size."
            )
            group_size = world_size
    assert world_size >= group_size
    assert world_size % group_size == 0
    group = None
    for group_num in range(world_size // group_size):
        group_ids = range(group_num * group_size, (group_num + 1) * group_size)
        cur_group = torch.distributed.new_group(ranks=group_ids)
        if torch.distributed.get_rank() // group_size == group_num:
            group = cur_group
        # Cannot return early here: `new_group` is a collective call, so every
        # rank must participate in the creation of all subgroups.
    assert group is not None
    return group
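
A typical consumer of the returned handle is synchronized batch norm, where
normalization statistics are all-reduced only within the subgroup instead of
across all GPUs. The following is a minimal usage sketch, not part of the
original file: it assumes `torch.distributed` has already been initialized for
this process (e.g. via `init_process_group`), and the toy model and
`group_size=2` are illustrative placeholders.

import torch.nn as nn

from training.utils.distributed import create_new_process_group

# Sync BN statistics within groups of 2 GPUs (must divide world_size).
group = create_new_process_group(group_size=2)

# Toy model with a BatchNorm layer; any module tree works.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

# convert_sync_batchnorm swaps every BatchNorm layer for a SyncBatchNorm bound
# to `group`, so batch statistics are shared only within that subgroup.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=group)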