def _synchronize_all_moe_parameters()

in chatlearn/synchronizer/parameter_sync.py [0:0]
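Synchronizes all MoE parameters from the training model to the vLLM inference model: routed experts are first regrouped across ranks (via allgather or alltoall), and the weights are then synchronized to the inference replicas, either in a single pass (equal tp sizes) or in two stages (routed experts, then everything else) when the destination tp size is larger.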


    def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False, dryrun=False):
        self.check_and_setup_collective_group()

        send_actors_list: List = [
            self.sorted_send_actors,
            self.sorted_send_actors_stage2
        ]
        actor_mappings_list: List = [
            self.send_recv_actor_mappings,
            self.send_recv_actor_mappings_stage2,
            self.send_actors_to_regroup_routed_experts,
        ]

        self.check_and_fuse_lora(self._enable_lora, actor_mappings_list)

        if self.concurrent_comm:
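            # Only the concurrent-communication path is implemented for HEP-enabled
            # training models served by vLLM; the else branch below raises otherwise.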
            assert self.dst_model.use_vllm_backend

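            # Stage 1: regroup routed experts across ranks, using either allgather or
            # alltoall depending on the configured regrouping comm type.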
            max_workers = self._calculate_max_workers(self.send_actors_to_regroup_routed_experts)
            if self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER:
                # allgather routed experts only
                self.sync_allgather_multi_threads(
                    [self.send_actors_to_regroup_routed_experts],
                    max_workers=max_workers,
                    requires_grad=requires_grad,
                    group_name=self.group_name + "_allgather",
                    filter_fn=self.routed_experts_filter)
            elif self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL:
                if not dryrun:
                    logger.info("start to alltoall routed experts.")
                    start_time = time.time()
                    # alltoall routed experts only
                    self.sync_alltoall_multi_threads(
                        [self.send_actors_to_regroup_routed_experts],
                        max_workers=max_workers,
                        requires_grad=requires_grad,
                        filter_fn=self.routed_experts_filter)
                    logger.info(f"complete to alltoall routed experts using {time.time()-start_time:.2f} seconds")
            # sync everything to inference model
            if self.tp_num_mapping == 1:
                logger.info("start to sync all moe experts")
                send_actors_list = [self.sorted_send_actors]
                actor_mappings_list = [self.send_recv_actor_mappings]
                self._multi_thread_sync_for_tp_num_mapping_eq_1(
                    send_actors_list,
                    actor_mappings_list,
                    requires_grad=requires_grad,
                    filter_fn=None,
                    param_group="default",
                    dryrun=dryrun
                )
                logger.info("complete to sync all moe experts")

            elif self.tp_num_mapping > 1:
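                # The destination tp size exceeds the source tp size, so routed experts
                # and the remaining parameters are synchronized in two separate stages.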
                # First, synchronize routed experts.
                logger.info("start to sync routed expert weights.")
                start_time = time.time()
                self._synchronize_routed_experts(requires_grad=requires_grad, validate=validate, dryrun=dryrun)
                logger.info(f"complete to sync routed expert weights. [stage1-1] using {time.time()-start_time:.2f} seconds")
                self.clear_cache(
                    sorted_send_actors_list=[
                        self.send_actors_to_regroup_routed_experts,
                        self.sorted_send_actors_for_routed_experts
                    ],
                    rank_mapping_list=[
                        self.send_recv_actor_mappings_for_routed_experts
                    ]
                )

                # Then, synchronize parameters except routed experts
                logger.info("start to sync parameters except routed experts.")
                self._synchronize_params_except_routed_experts(requires_grad=requires_grad, validate=validate, dryrun=dryrun)
                logger.info("complete to sync parameters except routed experts.")

                self.reset_synchronizer()

                self.clear_cache(
                    sorted_send_actors_list=[
                        self.sorted_send_actors,
                        self.sorted_send_actors_stage2,
                    ],
                    rank_mapping_list=[
                        self.send_recv_actor_mappings,
                        self.send_recv_actor_mappings_stage2
                    ]
                )
            else:
                raise NotImplementedError(
                    f"ChatLearn does not support synchronizing from larger tp size ({self.num_src_tensor_parallel}) "
                    f"to smaller tp size ({self.num_dst_tensor_parallel}) currently."
                )

        else:
            raise NotImplementedError(
                "ChatLearn supports only concurrent_comm for training models with HEP enabled and inference with vLLM"
            )

        self.check_and_unfuse_lora(self._enable_lora, actor_mappings_list)

        self.validate_sync_results_parallel(actor_mappings_list, requires_grad, validate)

        self.check_and_destroy_collective_group()

        self.reset_synchronizer()

        logger.info(f"Group {self.group_name} sync all parameters done, comm_type {self._comm_type}")
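
For reference, the multi-thread sync helpers invoked above (sync_allgather_multi_threads, sync_alltoall_multi_threads, _multi_thread_sync_for_tp_num_mapping_eq_1) all follow the same fan-out shape: one worker thread per group of send actors, bounded by max_workers. The snippet below is a minimal sketch of that pattern only, not the actual ChatLearn implementation; sync_groups_in_threads and sync_one_group are hypothetical names.

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def sync_groups_in_threads(actor_groups, sync_one_group, max_workers, **kwargs):
        # Submit one sync task per group of send actors; max_workers bounds concurrency.
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(sync_one_group, group, **kwargs)
                for group in actor_groups
            ]
            # Re-raise the first failure instead of losing it inside a worker thread.
            for future in as_completed(futures):
                future.result()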