def create_new_process_group(group_size)

in training/utils/distributed.py


import logging

import torch
import torch.distributed


def create_new_process_group(group_size):
    """
    Creates process groups of a given `group_size` and returns the
    process group that the current GPU participates in.

    `group_size` must divide the total number of GPUs (world_size).

    Modified from
    https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60

    Args:
        group_size (int): number of GPUs that collaborate for sync BN
    """

    assert group_size > 0

    world_size = torch.distributed.get_world_size()
    if world_size <= 8:
        if group_size > world_size:
            logging.warning(
                f"Requested group size [{group_size}] > world size [{world_size}]. "
                "Assuming local debug run and capping it to world size."
            )
            group_size = world_size
    assert world_size >= group_size
    assert world_size % group_size == 0

    group = None
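    # e.g. with world_size=8 and group_size=4, ranks [0, 1, 2, 3] and [4, 5, 6, 7] form two groups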
    for group_num in range(world_size // group_size):
        group_ids = range(group_num * group_size, (group_num + 1) * group_size)
        cur_group = torch.distributed.new_group(ranks=group_ids)
        if torch.distributed.get_rank() // group_size == group_num:
            group = cur_group
            # cannot return early here: new_group() is a collective call, so every
            # process must participate in the creation of all subgroups

    assert group is not None
    return group
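
A minimal usage sketch (not part of the original file), assuming the returned group is used to restrict sync BN to each subgroup, as the docstring suggests; `model` and the group size of 8 are placeholders for illustration:

# hypothetical usage: synchronize batch-norm statistics only within groups of 8 GPUs
sync_bn_group = create_new_process_group(group_size=8)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=sync_bn_group)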