def build_rank_mapping_for_ep()

in chatlearn/synchronizer/parameter_sync.py [0:0]


    def build_rank_mapping_for_ep(self, add_recv_actor_fn=None):
        # Set up the rank mapping between src and dst parameters.
        # Get ranks for one src_model, without model replicas.

        if add_recv_actor_fn is None:
            add_recv_actor_fn = self.add_recv_actor

        src_dp_ranks, dst_dp_ranks = self.get_src_and_dst_dp_ranks()
        if self._debug and (src_dp_ranks[0] is None or dst_dp_ranks is None):
            return

        assert len(src_dp_ranks[0]) % len(dst_dp_ranks[0]) == 0, (
            "The number of src training model ranks should be a multiple of the number of dst ranks, "
            f"but got {len(src_dp_ranks[0])} and {len(dst_dp_ranks[0])}"
        )
        if self.src_model.colocate_with(self.dst_model) and self.num_src_tensor_parallel % 2 == 1:
            replica_rank_iter = cycle(reversed(src_dp_ranks))
        else:
            replica_rank_iter = cycle(iter(src_dp_ranks))
        logger.debug(f"src_dp_ranks: {src_dp_ranks}")
        logger.debug(f"dst_dp_ranks: {dst_dp_ranks}")

        assert self.num_src_pipeline_stage % self.num_dst_pipeline_stage == 0, (
            "PP size of the src model should be a multiple of the dst model's PP size, "
            f"but got {self.num_src_pipeline_stage} and {self.num_dst_pipeline_stage}"
        )

        def split_ranks_by_ep_and_tp_size(ranks,
                                          tp_size: int = 1,
                                          ep_size: int = 1):
            tp_and_ep_size = tp_size * ep_size
            return [
                [ranks[i:i + tp_size] for i in range(j, j + tp_and_ep_size, tp_size)]
                for j in range(0, len(ranks), tp_and_ep_size)
            ]
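        # Example (hypothetical ranks), e.g. tp_size=2, ep_size=2:
        #   split_ranks_by_ep_and_tp_size(list(range(8)), tp_size=2, ep_size=2)
        #   -> [[[0, 1], [2, 3]], [[4, 5], [6, 7]]], i.e. ranks grouped as [pp][ep][tp].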

        src_replica_ranks2offset = {}
        is_first_time_set_send_actors = True
        for dst_replica_ranks in dst_dp_ranks:
            src_replica_ranks = next(replica_rank_iter)
            if tuple(src_replica_ranks) not in src_replica_ranks2offset:
                src_replica_ranks2offset[tuple(src_replica_ranks)] = 0
                is_first_time_set_send_actors = True
            else:
                is_first_time_set_send_actors = False

            src_replica_ranks_group = split_ranks_by_ep_and_tp_size(src_replica_ranks, self.num_src_tensor_parallel, self.num_src_expert_parallel)
            # Since the dst replica is vLLM, which has no EP, the helper naturally organizes
            # dst_replica_ranks_group as [pp][1][tp], i.e. a single EP group per pipeline stage.
            dst_replica_ranks_group = split_ranks_by_ep_and_tp_size(dst_replica_ranks, self.num_dst_tensor_parallel, self.num_dst_expert_parallel)
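            # Example (hypothetical ranks), e.g. num_dst_tensor_parallel=4,
            # num_dst_expert_parallel=1, num_dst_pipeline_stage=1:
            #   dst_replica_ranks = [8, 9, 10, 11] -> dst_replica_ranks_group == [[[8, 9, 10, 11]]]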

            if is_first_time_set_send_actors:
                self.set_send_actors_to_regroup_routed_experts(src_replica_ranks_group)
                self.add_routed_experts_regrouping_actor(self.src_model, src_replica_ranks_group)

            if add_recv_actor_fn is self.empty_add_recv_actor:
                continue

            pipe_map_interval = self.num_src_pipeline_stage // self.num_dst_pipeline_stage
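            # Example: with num_src_pipeline_stage=4 and num_dst_pipeline_stage=2,
            # pipe_map_interval == 2, so src stages 0-1 map to dst stage 0 and
            # src stages 2-3 map to dst stage 1 (via j = i // pipe_map_interval below).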
            for i, src_ep_and_tp_group in enumerate(src_replica_ranks_group):
                j = i // pipe_map_interval
                first_src_tp_group = src_ep_and_tp_group[0]
                assert len(dst_replica_ranks_group[j][0]) % len(first_src_tp_group) == 0, (
                    "TP size of dst model should be times of src model, "
                    f"but got {len(dst_replica_ranks_group[j][0])} and {len(first_src_tp_group)}"
                )
                len_dst_div_src = len(dst_replica_ranks_group[j][0]) // len(first_src_tp_group)
                concated_src_tp_group = []
                offset = src_replica_ranks2offset[tuple(src_replica_ranks)]
                # Cyclically concatenate src TP groups so that len(concated_src_tp_group) == len(dst_replica_ranks_group[j][0]).
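                # Example (hypothetical ranks): with src_ep_and_tp_group == [[0, 1], [2, 3]],
                # a dst TP group of size 4 and offset == 0, the loop below yields
                # concated_src_tp_group == [0, 1, 2, 3]; the offset is then advanced by
                # len_dst_div_src (modulo the number of src EP/TP groups) for the next
                # stage/replica that reuses these src ranks.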
                for k in range(len_dst_div_src):
                    concated_src_tp_group.extend(src_ep_and_tp_group[int((offset + k) % len(src_ep_and_tp_group))])
                for src_rank, dst_rank in zip(concated_src_tp_group, dst_replica_ranks_group[j][0]):
                    add_recv_actor_fn(src_rank, dst_rank)
                src_replica_ranks2offset[tuple(src_replica_ranks)] = int(
                    (src_replica_ranks2offset[tuple(src_replica_ranks)] + len_dst_div_src) % len(src_ep_and_tp_group)
                )

        if self._debug:
            def debug_msg_for_actor_mappings(actor_mapping):
                if actor_mapping is None:
                    return

                for k, v_list in actor_mapping.items():
                    for v in v_list:
                        logger.debug(f"actor_mappings: {self.actor2rank[k]} -> {self.actor2rank[v]}")

            debug_msg_for_actor_mappings(self.send_recv_actor_mappings)
            debug_msg_for_actor_mappings(self.send_recv_actor_mappings_for_routed_experts)

            for regroup_actors in self.send_actors_to_regroup_routed_experts:
                cat_str = "_".join(str(self.actor2rank[actor]) for actor in regroup_actors)
                logger.info(f"{self._comm_type_to_regroup_routed_experts} actors: {cat_str}")
            for k, v_list in self.send_recv_actor_mappings.items():
                for v in v_list:
                    logger.info(f"send_recv_actor_mappings: {self.actor2rank[k]} -> {self.actor2rank[v]}")