def _synchronize_params_except_routed_experts()

in chatlearn/synchronizer/parameter_sync.py [0:0]


    def _synchronize_params_except_routed_experts(self, requires_grad=None, validate=False, dryrun=False):
        """Synchronize every parameter except routed-expert weights from send actors to recv actors.

        All sync paths apply ``self.params_except_routed_expert_filter`` with
        ``param_group="except_routed"``, so routed-expert parameters are skipped here
        (presumably synced by a sibling method — TODO confirm).

        Args:
            requires_grad: Forwarded to the sync helpers; presumably restricts the sync
                to trainable parameters when set — confirm against the helpers.
            validate: Forwarded to ``validate_sync_results_parallel`` to check sync results.
            dryrun: Forwarded only on the ``tp_num_mapping > 1`` path.
                NOTE(review): the eq-1 multi-thread and single-thread helpers are not
                passed ``dryrun`` — confirm whether they accept it and should receive it.
        """
        self.check_and_setup_collective_group()

        # Temporarily fuse LoRA weights (if enabled) so merged parameters are synced.
        self.check_and_fuse_lora(self._enable_lora, self.send_recv_actor_mappings)

        if self.concurrent_comm:
            if self.tp_num_mapping > 1:
                # Two-stage concurrent sync when the TP mapping ratio exceeds 1.
                actor_mappings_list: List = [self.send_recv_actor_mappings, self.send_recv_actor_mappings_stage2]
                self._multi_thread_sync_for_tp_num_mapping_gt_1(
                    [self.sorted_send_actors, self.sorted_send_actors_stage2],
                    actor_mappings_list,
                    requires_grad=requires_grad,
                    filter_fn=self.params_except_routed_expert_filter,
                    param_group="except_routed",
                    dryrun=dryrun
                )
            else:
                # Single-stage concurrent sync for a 1:1 TP mapping.
                actor_mappings_list = [self.send_recv_actor_mappings]
                self._multi_thread_sync_for_tp_num_mapping_eq_1(
                    [self.sorted_send_actors],
                    actor_mappings_list,
                    requires_grad=requires_grad,
                    filter_fn=self.params_except_routed_expert_filter,
                    param_group="except_routed"
                )
        else:
            # Sequential fallback when concurrent communication is disabled.
            actor_mappings_list = [self.send_recv_actor_mappings]
            self._single_thread_sync(
                actor_mappings_list,
                requires_grad=requires_grad,
                filter_fn=self.params_except_routed_expert_filter,
                param_group="except_routed"
            )

        # Undo the temporary LoRA fusion done above.
        self.check_and_unfuse_lora(self._enable_lora, self.send_recv_actor_mappings)

        self.validate_sync_results_parallel(
            actor_mappings_list,
            requires_grad,
            validate,
            filter_fn=self.params_except_routed_expert_filter,
            param_group="except_routed"
        )

        self.check_and_destroy_collective_group()

        # Message corrected: this path syncs parameters *except* routed experts,
        # not "all parameters" as the previous log text claimed.
        logger.info(f"Group {self.group_name} sync parameters except routed experts done, comm_type {self._comm_type}")