in src/nanotron/parallel/context.py [0:0]
def _init_parallel_groups(self):
    """Initialize all process groups for 5D parallelism.

    Lays out global ranks in a 5D matrix with axis order
    (expert, pipeline, data, context, tensor) — see ``self.parallel_order`` —
    then derives one process group per dimension plus the combined groups.
    Every rank must call this with identical sizes so the collective
    ``create_new_group`` calls happen in the same order on all ranks.
    """
    # Ensure all ranks enter group creation together.
    dist.barrier()
    # Axis order fixed here: (ep, pp, dp, cp, tp). tp is the innermost
    # (fastest-varying) dimension, so tp peers get consecutive global ranks.
    ranks = np.arange(0, self.world_size).reshape(
        (
            self.expert_parallel_size,
            self.pipeline_parallel_size,
            self.data_parallel_size,
            self.context_parallel_size,
            self.tensor_parallel_size,
        )
    )
    self.world_ranks_to_pg = {}
    # Ranks co-located on the same node (consecutive runs of local_world_size ranks).
    self.local_pg = self.create_new_group(ranks.reshape((-1, self.local_world_size)))
    # Fail with a clear message if the launcher did not set LOCAL_RANK,
    # instead of an opaque TypeError from int(None).
    local_rank = os.environ.get("LOCAL_RANK")
    assert local_rank is not None, "LOCAL_RANK environment variable must be set"
    assert int(local_rank) == dist.get_rank(self.local_pg), "Local rank mismatch"
    # Relevant process groups containing the current rank.
    # Each transpose moves the target dimension to the LAST axis, so after
    # reshape((-1, size)) every row enumerates exactly that dimension while
    # all other coordinates stay fixed.
    self.tp_pg = self.create_new_group(ranks.transpose((0, 1, 2, 3, 4)).reshape((-1, self.tensor_parallel_size)))
    self.cp_pg = self.create_new_group(ranks.transpose((4, 0, 1, 2, 3)).reshape((-1, self.context_parallel_size)))
    self.dp_pg = self.create_new_group(ranks.transpose((3, 4, 0, 1, 2)).reshape((-1, self.data_parallel_size)))
    self.pp_pg = self.create_new_group(ranks.transpose((2, 3, 4, 0, 1)).reshape((-1, self.pipeline_parallel_size)))
    self.ep_pg = self.create_new_group(
        ranks.transpose((1, 2, 3, 4, 0)).reshape((-1, self.expert_parallel_size))
    )  # TODO: ep should be a subset of dp
    # Model parallel group = combination of ep, pp and tp for a given (dp, cp) pair.
    self.mp_pg = self.create_new_group(
        [
            ranks[:, :, dp_rank, cp_rank, :].reshape(-1)
            for cp_rank in range(self.context_parallel_size)
            for dp_rank in range(self.data_parallel_size)
        ]
    )
    # Combined data + context parallel group for a given (ep, pp, tp) triple,
    # e.g. for gradient reduction across both dp and cp replicas.
    self.dp_cp_pg = self.create_new_group(
        [
            ranks[ep_rank, pp_rank, :, :, tp_rank].reshape(-1)
            for tp_rank in range(self.tensor_parallel_size)
            for pp_rank in range(self.pipeline_parallel_size)
            for ep_rank in range(self.expert_parallel_size)
        ]
    )
    # Combined tensor + expert parallel group for a given (pp, dp, cp) triple.
    self.tp_and_ep_pg = self.create_new_group(
        [
            ranks[:, pp_rank, dp_rank, cp_rank, :].reshape(-1)
            for cp_rank in range(self.context_parallel_size)
            for pp_rank in range(self.pipeline_parallel_size)
            for dp_rank in range(self.data_parallel_size)
        ]
    )
    # self.tp_and_cp_pg = self.create_new_group(
    #     [
    #         ranks[ep_rank, pp_rank, dp_rank, :, :].reshape(-1)
    #         for ep_rank in range(self.expert_parallel_size)
    #         for pp_rank in range(self.pipeline_parallel_size)
    #         for dp_rank in range(self.data_parallel_size)
    #     ]
    # )
    self.world_rank_matrix: np.ndarray = ranks
    self.parallel_order = ["ep", "pp", "dp", "cp", "tp"]