in chatlearn/synchronizer/parameter_sync.py [0:0]
def _set_sync_param_names(self, send_actor, recv_actor, requires_grad=None, filter_fn=None, param_group="default", should_map_name=True):
    """Resolve and register the parameter-name lists to synchronize from ``send_actor`` to ``recv_actor``.

    Builds the (src_names, dst_names) pair for the given actor pair, applies LoRA
    filtering and the optional ``filter_fn``, maps names through the model
    synchronizer, caches the lists for tp-mapped sync, and pre-registers them on
    both actors when the synchronizer reports parameters unchanged.

    Args:
        send_actor: remote actor holding the source (trainer) parameters.
        recv_actor: remote actor holding the destination (inference) parameters.
        requires_grad: restrict to trainable parameters; defaults to True, and is
            forced to False under LoRA (freeze-layer support pending, see TODO).
        filter_fn: optional callable applied to both name lists to subset them.
        param_group: one of "default", "routed", "except_routed".
        should_map_name: when False, still invoke the synchronizer mapping for its
            side effects but keep the unmapped names (routed-expert regrouping).

    Returns:
        Tuple (src_names, dst_names) of parameter-name lists.
    """
    if requires_grad is None:
        requires_grad = True
    if self._enable_lora:
        # TODO(jiangle.jl): support freeze layer.
        requires_grad = False
    assert param_group in ("default", "routed", "except_routed"), (
        f"param_group must be one of 'default', 'routed', or 'except_routed', got {param_group}."
    )
    if self.num_src_pipeline_stage > 1:
        # Multi-stage source pipeline: ask the send actor for a dst->src
        # layer-name mapping specific to the recv actor's pipe rank/offset.
        dst_pipe_rank = self.get_actor_pipe_rank(recv_actor)
        dst_layer_offset = self.get_or_cache(recv_actor, "get_pipeline_stage_layer_offset")
        dst_src_mappings = future.get(send_actor.build_pipeline_layer_name_mapping.remote(
            self.num_dst_pipeline_stage, dst_pipe_rank, dst_layer_offset,
            requires_grad=requires_grad))
        dst_names = list(dst_src_mappings.keys())
        src_names = list(dst_src_mappings.values())
    else:
        # Single stage: src and dst share one name list.
        src_names = dst_names = future.get(send_actor.get_parameter_names.remote(requires_grad=requires_grad))
    if self._enable_lora:
        # LoRA adapter weights are never synchronized.
        src_names = [ele for ele in src_names if LORA_WEIGHT_PREFIX not in ele]
        dst_names = [ele for ele in dst_names if LORA_WEIGHT_PREFIX not in ele]
    if filter_fn is not None:
        src_names = filter_fn(src_names)
        dst_names = filter_fn(dst_names)
    synchronizer = get_synchronizer(self.src_model, self.dst_model)
    if should_map_name:
        src_names, dst_names = synchronizer.map_name_from_src_to_dst(send_actor, recv_actor, src_names, dst_names)
    else:
        # For routed experts which need to regroup expert first in trainer actors.
        synchronizer.map_name_from_src_to_dst(send_actor, recv_actor, src_names, dst_names)
    self.actor2synchronizer[send_actor] = synchronizer
    future.wait(send_actor.set_synchronizer.remote(synchronizer))
    self.check_param_names(send_actor, recv_actor, src_names, dst_names)
    dst_model = self.actor2model[recv_actor]
    if self.tp_num_mapping > 1 and ((not dst_model.use_vllm_backend and param_group != "routed") or dst_model.use_vllm_backend):
        # FIX: key the cache on the (send, recv) pair — the original used
        # (recv_actor, recv_actor, ...), collapsing all senders targeting the
        # same recv actor into one entry — and cache (src_names, dst_names)
        # rather than dst_names twice, so the src side is retrievable later.
        key = (send_actor, recv_actor, param_group)
        if key not in self._send_recv_param_names:
            self._send_recv_param_names[key] = (src_names, dst_names)
        else:
            cached_src, cached_dst = self._send_recv_param_names[key]
            # Concatenate into fresh lists instead of `+=` so neither the
            # cached lists nor the lists returned to the caller are mutated.
            self._send_recv_param_names[key] = (cached_src + src_names, cached_dst + dst_names)
    if not self.synchronizer.is_parameter_changed:
        # Names are stable: pre-register them on both actors up front.
        pipe_stage = self.get_actor_pipe_rank(send_actor)
        refs = []
        refs.append(send_actor.set_sync_parameters.remote(src_names, pipe_stage))
        refs.append(recv_actor.set_sync_parameters.remote(dst_names, pipe_stage))
        future.get(refs)
    return src_names, dst_names