def build_rank_mapping_two_stage()

in chatlearn/synchronizer/parameter_sync.py


    def build_rank_mapping_two_stage(self, add_recv_actor_fn=None):
        # Set up the rank mapping between src (trainer) and dst (inference) parameters.
        # Ranks are resolved for a single src model at a time; replicas are handled by
        # cycling through `replica_rank_iter` below.

        if add_recv_actor_fn is None:
            add_recv_actor_stage1_fn = self.add_recv_actor
            add_recv_actor_stage2_fn = self.add_recv_actor_stage2
        else:
            assert len(add_recv_actor_fn) == 2, (
                "The length of add_recv_actor_fn should be 2. The first one is a function handler for communication stage 1, "
                "while the second one is a function handler for communication stage 2."
            )
            add_recv_actor_stage1_fn = add_recv_actor_fn[0]
            add_recv_actor_stage2_fn = add_recv_actor_fn[1]

        src_ranks, dst_ranks = self.get_src_and_dst_dp_ranks(is_except_routed_experts=True)
        if self._debug and (src_ranks[0] is None or dst_ranks is None):
            # Rank information may be missing in debug mode; skip building the mapping.
            return

        replica_rank_iter = cycle(iter(src_ranks))
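        # Cycling lets src replicas be reused when dst has more replicas than src.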

        logger.debug(f"src_ranks: {src_ranks}")
        logger.debug(f"dst_ranks: {dst_ranks}")
        assert self.num_dst_tensor_parallel % self.num_src_tensor_parallel == 0, (
            "tensor_model_parallel_size of the dst model must be divisible by that of the src model, "
            f"but got src model {self.src_model.name} (TP={self.num_src_tensor_parallel}) and "
            f"dst model {self.dst_model.name} (TP={self.num_dst_tensor_parallel})"
        )
        assert self.num_src_pipeline_stage % self.num_dst_pipeline_stage == 0, (
            "pipeline_model_parallel_size of the src model must be divisible by that of the dst model, "
            f"but got src PP={self.num_src_pipeline_stage} and dst PP={self.num_dst_pipeline_stage}"
        )
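        # E.g. src TP=2, dst TP=4 gives tp_num_mapping == dst_tp // src_tp == 2;
        # src PP=4, dst PP=2 gives pipe_map_interval == src_pp // dst_pp == 2.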

        def split_ranks_by_tp_and_ep_size(ranks, tp_size, ep_size):
            # Chunk `ranks` into groups of `tp_size`. When ep_size > 1, first reorder
            # the ranks so that the TP groups belonging to one expert-parallel slice
            # are visited with a stride of ep_size * tp_size.
            if ep_size > 1:
                sort_ranks_on_grouped_tp = []
                index = 0
                tp_index = 0
                for _ in range(len(ranks)):
                    # Take the rank at the reordered position (not the bare index),
                    # so the result matches the ep_size == 1 branch below.
                    sort_ranks_on_grouped_tp.append(ranks[index])
                    if tp_index < tp_size - 1:
                        # Still inside the current TP group: step to the next rank.
                        index += 1
                        tp_index += 1
                    else:
                        # The current TP group is complete: jump one EP block ahead.
                        start_index = index + 1 - tp_size
                        index = start_index + (ep_size * tp_size)
                        tp_index = 0
                    if index >= len(ranks):
                        # Past the end: wrap around to the next unvisited TP group.
                        index = (index + tp_size) % len(ranks)
            else:
                sort_ranks_on_grouped_tp = ranks
            return [sort_ranks_on_grouped_tp[i:i + tp_size] for i in range(0, len(sort_ranks_on_grouped_tp), tp_size)]
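        # Worked example: with ranks=[0, 1, ..., 7], tp_size=2, ep_size=2, the walk
        # visits positions 0, 1, 4, 5, 2, 3, 6, 7, so the returned TP groups are
        # [[0, 1], [4, 5], [2, 3], [6, 7]].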

        pair_list = []
        p2p_list = []
        src_replica_offset = 0
        # Bookkeeping passed to get_load_balance_dst_rank to balance dst GPU assignment.
        lb_dst_offset_pq_dict = {}

        for dst_replica_ranks in dst_ranks:
            src_replica_ranks = next(replica_rank_iter)
            # For all weights except routed experts, the EP dimension acts as data
            # parallelism, so the src ranks are split with ep_size=1.
            src_replica_ranks_group = split_ranks_by_tp_and_ep_size(src_replica_ranks, self.num_src_tensor_parallel, 1)
            dst_replica_ranks_group = split_ranks_by_tp_and_ep_size(dst_replica_ranks, self.num_dst_tensor_parallel, self.num_dst_expert_parallel)
            logger.debug(f"src_replica_ranks_group: {src_replica_ranks_group}")
            logger.debug(f"dst_replica_ranks_group: {dst_replica_ranks_group}")
            pipe_map_interval = self.num_src_pipeline_stage // self.num_dst_pipeline_stage

            assert pipe_map_interval >= 1, \
                f"dst_pp is expected to divide src_pp, but got src_pp {self.num_src_pipeline_stage} and dst_pp {self.num_dst_pipeline_stage}"

            # stage 1: comm pairs that broadcast params from trainer to inference model
            # Each rank in trainer holds weights for tp_num_mapping ranks in inference model.
            # For example: trainer_tp = 2, inference_tp = 4 => tp_num_mapping = inference_tp // trainer_tp = 2
            # Weight mapping from training to inference:
            #   [0] -> [0', 1']
            #   [1] -> [2', 3']
            # To avoid p2p communication on the same gpu, we only broadcast params to the first rank in each weight_mapping_group.
            # Comm mapping from training to inference:
            #   [0] -> [0']
            #   [1] -> [2']
            # First pass: allocate dst ranks for the src ranks that collide on a GPU;
            # collision-free src ranks are recorded and allocated in the second pass.
            uncollided_index_to_start_j = {}
            for i, src_tp_group in enumerate(src_replica_ranks_group):
                if i < src_replica_offset:
                    continue
                j = (i - src_replica_offset) // pipe_map_interval
                if j == self.num_dst_pipeline_stage:
                    src_replica_offset = i
                    break
                if self.tp_num_mapping == 1:
                    start = 0
                else:
                    mod_i = (i - src_replica_offset) % self.tp_num_mapping
                    start = mod_i if (i - src_replica_offset) < self.tp_num_mapping else (self.tp_num_mapping - mod_i - 1) % self.tp_num_mapping
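                # `start` staggers which of the tp_num_mapping dst slots each src TP
                # group targets first, so consecutive src groups spread their stage-1
                # sends across slots (e.g. starts 0, 1, 1, 0, ... when tp_num_mapping == 2).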
                for s_idx, src_rank in enumerate(src_tp_group):
                    dst_rank, is_collide = self.get_load_balance_dst_rank(
                        lb_dst_offset_pq_dict,
                        s_idx,
                        start,
                        src_rank,
                        dst_replica_ranks_group,
                        j,
                        pre_allocate=True
                    )
                    if is_collide:
                        add_recv_actor_stage1_fn(src_rank, dst_rank)
                        pair_list.append((src_rank, dst_rank))
                    else:
                        assert dst_rank is None
                        uncollided_index_to_start_j[(i, s_idx)] = (start, j)

            # Then, allocate dst ranks for the remaining src ranks, which have no gpu collisions
            for i, src_tp_group in enumerate(src_replica_ranks_group):
                for s_idx, src_rank in enumerate(src_tp_group):
                    if (i, s_idx) not in uncollided_index_to_start_j:
                        continue

                    start, j = uncollided_index_to_start_j[(i, s_idx)]
                    dst_rank, _ = self.get_load_balance_dst_rank(
                        lb_dst_offset_pq_dict,
                        s_idx,
                        start,
                        src_rank,
                        dst_replica_ranks_group,
                        j,
                        pre_allocate=False
                    )
                    add_recv_actor_stage1_fn(src_rank, dst_rank)
                    pair_list.append((src_rank, dst_rank))

            # stage 2: comm pairs that broadcast params from the first rank to the other ranks within each weight_mapping_group
            # Comm mapping in each weight_mapping_group of inference:
            #   [0'] -> [1']
            #   [2'] -> [3']
            recv_ranks = [pair[1] for pair in pair_list]
            def p2p_pair_grouping(tuples):
                # Only ranks that received weights in stage 1 act as stage-2 senders;
                # pair each of them with every other rank in its weight_mapping_group.
                for s_idx, src_rank in enumerate(tuples):
                    for d_idx, dst_rank in enumerate(tuples):
                        if s_idx == d_idx or src_rank not in recv_ranks: # pylint: disable=cell-var-from-loop
                            continue
                        add_recv_actor_stage2_fn(src_rank, dst_rank)
                        p2p_list.append((src_rank, dst_rank))

            for dst_tp_group in dst_replica_ranks_group:
                dst_tp_group = split_ranks_by_tp_and_ep_size(dst_tp_group, self.tp_num_mapping, 1)
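                # Example: a dst TP group [0', 1', 2', 3'] with tp_num_mapping == 2
                # splits into weight_mapping_groups [0', 1'] and [2', 3']; stage 1
                # delivered weights to 0' and 2', so stage 2 sends 0' -> 1' and 2' -> 3'.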
                for tuples in dst_tp_group:
                    p2p_pair_grouping(tuples)

        logger.info(f"comm pair_list <train_rank, inference_rank>: {pair_list}")
        logger.info(f"comm p2p_list <inference_rank, inference_rank>: {p2p_list}")