# def _set_sync_param_names()
#
# in chatlearn/synchronizer/parameter_sync.py [0:0]


    def _set_sync_param_names(self, send_actor, recv_actor, requires_grad=None, filter_fn=None, param_group="default", should_map_name=True):
        """Compute and register the (src_names, dst_names) parameter-name lists used to
        synchronize weights from ``send_actor`` to ``recv_actor``.

        Args:
            send_actor: remote actor that owns the source parameters.
            recv_actor: remote actor that receives the parameters.
            requires_grad: if None, defaults to True; forced to False when LoRA is
                enabled (see TODO below).
            filter_fn: optional callable applied to both name lists to subset them.
            param_group: one of "default", "routed", "except_routed".
            should_map_name: when True, use the name lists returned by the
                synchronizer's src→dst mapping; when False, still invoke the mapping
                for its side effects but keep the unmapped lists.

        Returns:
            Tuple ``(src_names, dst_names)`` of parameter-name lists, index-aligned.
        """
        if requires_grad is None:
            requires_grad = True
        if self._enable_lora:
            # TODO(jiangle.jl): support freeze layer.
            # LoRA currently syncs all (non-LoRA) weights regardless of grad status.
            requires_grad = False
        assert param_group in ("default", "routed", "except_routed"), (
            f"param_group must be one of 'default', 'routed', or 'except_routed', got {param_group}."
        )

        if self.num_src_pipeline_stage > 1:
            # Source model is pipeline-parallel: ask the send actor to build the
            # dst-name -> src-name mapping for the recv actor's pipeline stage,
            # accounting for the destination's layer offset.
            dst_pipe_rank = self.get_actor_pipe_rank(recv_actor)
            dst_layer_offset = self.get_or_cache(recv_actor, "get_pipeline_stage_layer_offset")
            dst_src_mappings = future.get(send_actor.build_pipeline_layer_name_mapping.remote(
                                          self.num_dst_pipeline_stage, dst_pipe_rank, dst_layer_offset,
                                          requires_grad=requires_grad))
            dst_names = list(dst_src_mappings.keys())
            src_names = list(dst_src_mappings.values())
        else:
            # No source pipeline parallelism: src and dst share the same name list.
            src_names = dst_names = future.get(send_actor.get_parameter_names.remote(requires_grad=requires_grad))

        if self._enable_lora:
            # Exclude LoRA adapter weights from synchronization.
            src_names = [ele for ele in src_names if LORA_WEIGHT_PREFIX not in ele]
            dst_names = [ele for ele in dst_names if LORA_WEIGHT_PREFIX not in ele]

        if filter_fn is not None:
            src_names = filter_fn(src_names)
            dst_names = filter_fn(dst_names)

        synchronizer = get_synchronizer(self.src_model, self.dst_model)
        if should_map_name:
            src_names, dst_names = synchronizer.map_name_from_src_to_dst(send_actor, recv_actor, src_names, dst_names)
        else:
            # For routed experts which need to regroup expert first in trainer actors.
            # The mapping call is still made for its side effects; its return value
            # is deliberately discarded.
            synchronizer.map_name_from_src_to_dst(send_actor, recv_actor, src_names, dst_names)
        # Register the synchronizer locally and on the remote send actor before any
        # name checks / sync-parameter setup that may depend on it.
        self.actor2synchronizer[send_actor] = synchronizer
        future.wait(send_actor.set_synchronizer.remote(synchronizer))

        self.check_param_names(send_actor, recv_actor, src_names, dst_names)
        dst_model = self.actor2model[recv_actor]
        if self.tp_num_mapping > 1 and ((not dst_model.use_vllm_backend and param_group != "routed") or dst_model.use_vllm_backend):
            # NOTE(review): key is (recv_actor, recv_actor, ...) and both tuple slots
            # store dst_names — looks like an intentional self-keyed cache of the
            # destination names for TP regrouping, but confirm this isn't meant to be
            # (send_actor, recv_actor, param_group) with (src_names, dst_names).
            key = (recv_actor, recv_actor, param_group)
            if key not in self._send_recv_param_names:
                self._send_recv_param_names[key] = (dst_names, dst_names)
            else:
                # Accumulate names across multiple calls for the same key.
                dst_names0 = self._send_recv_param_names[key][0]
                dst_names0 += dst_names
                self._send_recv_param_names[key] = (dst_names0, dst_names0)
        # NOTE(review): this reads self.synchronizer, not the local `synchronizer`
        # built above — verify they are guaranteed to be the same object here.
        if not self.synchronizer.is_parameter_changed:
            # Parameter names are stable: pre-register the sync lists on both actors
            # for this pipeline stage so later syncs can skip recomputation.
            pipe_stage = self.get_actor_pipe_rank(send_actor)
            refs = []
            refs.append(send_actor.set_sync_parameters.remote(src_names, pipe_stage))
            refs.append(recv_actor.set_sync_parameters.remote(dst_names, pipe_stage))
            future.get(refs)
        return src_names, dst_names