def _init_parallel_groups()

in src/nanotron/parallel/context.py


    def _init_parallel_groups(self):
        """Initialize 3D parallelism's all process groups."""
        dist.barrier()
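        # Lay out all world ranks on a 5D grid. The axis order is
        # (ep, pp, dp, cp, tp), so tp varies fastest and adjacent world ranks
        # fall into the same tp group.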
        ranks = np.arange(0, self.world_size).reshape(
            (
                self.expert_parallel_size,
                self.pipeline_parallel_size,
                self.data_parallel_size,
                self.context_parallel_size,
                self.tensor_parallel_size,
            )
        )
        self.world_ranks_to_pg = {}
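        # Node-local group: each run of local_world_size consecutive world ranks
        # is treated as one node; the assert below checks this against the
        # launcher-provided LOCAL_RANK.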
        self.local_pg = self.create_new_group(ranks.reshape((-1, self.local_world_size)))
        assert int(os.environ["LOCAL_RANK"]) == dist.get_rank(self.local_pg), "Local rank mismatch"

        # Relevant process groups containing the current rank
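        # Each transpose rotates the target axis into the last position so that
        # reshape(-1, size) slices the grid into one row per group, every row
        # holding the ranks that differ only along that axis.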
        self.tp_pg = self.create_new_group(ranks.transpose((0, 1, 2, 3, 4)).reshape((-1, self.tensor_parallel_size)))
        self.cp_pg = self.create_new_group(ranks.transpose((4, 0, 1, 2, 3)).reshape((-1, self.context_parallel_size)))
        self.dp_pg = self.create_new_group(ranks.transpose((3, 4, 0, 1, 2)).reshape((-1, self.data_parallel_size)))
        self.pp_pg = self.create_new_group(ranks.transpose((2, 3, 4, 0, 1)).reshape((-1, self.pipeline_parallel_size)))
        self.ep_pg = self.create_new_group(
            ranks.transpose((1, 2, 3, 4, 0)).reshape((-1, self.expert_parallel_size))
        )  # TODO: ep should be a subset of dp

        # model parallel group = combination of ep, pp and tp for a given (dp, cp) coordinate
        self.mp_pg = self.create_new_group(
            [
                ranks[:, :, dp_rank, cp_rank, :].reshape(-1)
                for cp_rank in range(self.context_parallel_size)
                for dp_rank in range(self.data_parallel_size)
            ]
        )

        # data + context parallel group: all (dp, cp) pairs for a fixed (ep, pp, tp) coordinate
        self.dp_cp_pg = self.create_new_group(
            [
                ranks[ep_rank, pp_rank, :, :, tp_rank].reshape(-1)
                for tp_rank in range(self.tensor_parallel_size)
                for pp_rank in range(self.pipeline_parallel_size)
                for ep_rank in range(self.expert_parallel_size)
            ]
        )

        # tensor + expert parallel group: all (ep, tp) pairs for a fixed (pp, dp, cp) coordinate
        self.tp_and_ep_pg = self.create_new_group(
            [
                ranks[:, pp_rank, dp_rank, cp_rank, :].reshape(-1)
                for cp_rank in range(self.context_parallel_size)
                for pp_rank in range(self.pipeline_parallel_size)
                for dp_rank in range(self.data_parallel_size)
            ]
        )

        # self.tp_and_cp_pg = self.create_new_group(
        #     [
        #         ranks[ep_rank, pp_rank, dp_rank, :, :].reshape(-1)
        #         for ep_rank in range(self.expert_parallel_size)
        #         for pp_rank in range(self.pipeline_parallel_size)
        #         for dp_rank in range(self.data_parallel_size)
        #     ]
        # )

        self.world_rank_matrix: np.ndarray = ranks
        self.parallel_order = ["ep", "pp", "dp", "cp", "tp"]
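
The transpose-and-reshape pattern can be checked in plain NumPy without torch.distributed. A minimal sketch, assuming a toy layout of ep=1, pp=2, dp=2, cp=1, tp=2 (the sizes are illustrative, not from the source):

    import numpy as np

    # Toy sizes (illustrative assumption): world_size = 8
    ep, pp, dp, cp, tp = 1, 2, 2, 1, 2
    ranks = np.arange(ep * pp * dp * cp * tp).reshape((ep, pp, dp, cp, tp))

    # tp is already the last (fastest-varying) axis, so a plain reshape
    # yields the contiguous tp groups:
    print(ranks.reshape((-1, tp)))  # [[0 1] [2 3] [4 5] [6 7]]

    # For dp, rotate dp into the last position, then reshape; each row holds
    # the world ranks that differ only in their dp coordinate:
    print(ranks.transpose((3, 4, 0, 1, 2)).reshape((-1, dp)))
    # [[0 2] [4 6] [1 3] [5 7]]

    # Composite groups slice the grid directly, e.g. the model-parallel group
    # for dp_rank=0, cp_rank=0 collects every (ep, pp, tp) combination:
    print(ranks[:, :, 0, 0, :].reshape(-1))  # [0 1 4 5]

    # Because the grid is built with arange, a world rank's 5D coordinates can
    # be recovered from its value, matching parallel_order:
    coords = [int(c) for c in np.unravel_index(5, ranks.shape)]
    print(dict(zip(["ep", "pp", "dp", "cp", "tp"], coords)))
    # {'ep': 0, 'pp': 1, 'dp': 0, 'cp': 0, 'tp': 1}

Emitting groups as rows of one (n_groups, group_size) matrix means every rank enumerates the same group list, which is what collective process-group creation requires (presumably why create_new_group takes the whole matrix rather than a single group).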