def get_src_and_dst_dp_ranks()

in chatlearn/synchronizer/parameter_sync.py


    def get_src_and_dst_dp_ranks(self, is_except_routed_experts=False):
        """
        Returns:
            The per-DP-rank group lists for the src and dst models,
            each in the form [[ranks of DP-0], [ranks of DP-1], ..., [ranks of DP-N]].
        """
        dst_dp_ranks = self.dst_model.all_ranks
        local_src_ranks = future.get(self.src_model.replicas[0].get_local_param_ranks())
        if local_src_ranks[0] is None or dst_dp_ranks is None:
            if self._debug:
                logger.warning(
                    f"DEBUG MODE! src_dp_ranks {local_src_ranks} or dst_dp_ranks: {dst_dp_ranks} is None, "
                    "make sure they have values in real application.")
                return local_src_ranks, dst_dp_ranks
            else:
                raise Exception(f"src_dp_ranks {local_src_ranks} or dst_dp_ranks {dst_dp_ranks} should not be None")
        # Group the src model's global ranks by their data-parallel (DP) rank.
        dp_rank_to_ranks = defaultdict(list)
        for local_ranks, dp_rank in local_src_ranks:
            dp_rank_to_ranks[dp_rank].append(local_ranks[dp_rank])
        if is_except_routed_experts:
            # For weights other than routed experts, the expert-parallel (EP) dimension is reused for data parallelism.
            # TODO-1 The logic here is a little complicated; it would be better to move it to a separate function.
            # TODO-2 The logic here is HEP-specific; it would be better invoked from class ParameterSyncGroupwithHEP.
            src_hep_size = self.num_src_expert_parallel * self.num_src_tensor_parallel
            new_dict = defaultdict(list)
            idx = 0
            for dp_rank, values in dp_rank_to_ranks.items():
                assert len(values) % src_hep_size == 0, (
                    f"len of values({len(values)}) for dp_rank {dp_rank} must be divisible by hep size({src_hep_size})"
                    f" when calling get_src_and_dst_dp_ranks_for_except_routed_experts."
                )
                # Each chunk of src_hep_size consecutive ranks corresponds to one pipeline stage's HEP group.
                pp_blocks = [values[i:i + src_hep_size] for i in range(0, len(values), src_hep_size)]
                sub_blocks_per_pp = []
                for block in pp_blocks:
                    # Split each HEP block into num_src_expert_parallel sub-blocks of tensor-parallel size.
                    sub_block_size = src_hep_size // self.num_src_expert_parallel
                    sub_blocks = [block[i:i + sub_block_size] for i in range(0, src_hep_size, sub_block_size)]
                    sub_blocks_per_pp.append(sub_blocks)
                # Merge the sub-blocks that share the same EP index across all pipeline stages into one new DP group.
                for i in range(self.num_src_expert_parallel):
                    merged_group = []
                    for sub_blocks in sub_blocks_per_pp:
                        merged_group.extend(sub_blocks[i])
                    new_dict[idx].extend(merged_group)
                    idx += 1
            # Order the regrouped src DP groups by their new DP rank.
            src_dp_ranks = [ranks for _, ranks in sorted(new_dict.items())]
        else:
            src_dp_ranks = [ranks for _, ranks in sorted(dp_rank_to_ranks.items())]
        return src_dp_ranks, dst_dp_ranks
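
For reference, the following is a minimal, self-contained sketch that replays the is_except_routed_experts branch on toy inputs. The helper name regroup_for_except_routed_experts, the EP/TP sizes, and the sample rank layout are illustrative assumptions rather than part of ChatLearn's API; the sketch only demonstrates how each original DP group is split into EP-many new DP groups by merging same-EP-index sub-blocks across pipeline stages.

from collections import defaultdict

def regroup_for_except_routed_experts(dp_rank_to_ranks, ep_size, tp_size):
    """Split each original DP group into ep_size new DP groups.

    Mirrors the loop above: per pipeline stage, ranks are chunked into HEP
    blocks of size ep_size * tp_size, each block is split into ep_size
    TP-sized sub-blocks, and sub-blocks sharing the same EP index are merged
    across pipeline stages.
    """
    hep_size = ep_size * tp_size
    new_dict = defaultdict(list)
    idx = 0
    for _, values in sorted(dp_rank_to_ranks.items()):
        assert len(values) % hep_size == 0
        pp_blocks = [values[i:i + hep_size] for i in range(0, len(values), hep_size)]
        sub_blocks_per_pp = [
            [block[i:i + tp_size] for i in range(0, hep_size, tp_size)]
            for block in pp_blocks
        ]
        for ep_index in range(ep_size):
            for sub_blocks in sub_blocks_per_pp:
                new_dict[idx].extend(sub_blocks[ep_index])
            idx += 1
    return [ranks for _, ranks in sorted(new_dict.items())]

if __name__ == "__main__":
    # One original DP group spanning two pipeline stages, with assumed EP=2 and TP=2.
    dp_rank_to_ranks = {0: [0, 1, 2, 3, 4, 5, 6, 7]}
    print(regroup_for_except_routed_experts(dp_rank_to_ranks, ep_size=2, tp_size=2))
    # -> [[0, 1, 4, 5], [2, 3, 6, 7]]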