chatlearn/synchronizer/megatron_megatron.py (30 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""megatron to megatron synchronizer"""

from chatlearn.utils import future
from .base import BaseSync


class MegatronMegatronSync(BaseSync):
    """megatron to megatron synchronizer

    Translates parameter names between a sending and a receiving model whose
    names differ only by a leading prefix on one of the two sides.
    """

    def _get_dst_name(self, src_name, src_prefix, dst_prefix):
        """Translate a single parameter name into the destination naming.

        Args:
            src_name: Parameter name to translate.
            src_prefix: Prefix carried by source names only. When non-empty,
                its length is sliced off the front of ``src_name``.
            dst_prefix: Prefix carried by destination names only. Prepended
                to ``src_name`` when ``src_prefix`` is empty or ``None``.

        Returns:
            The translated name.
        """
        if src_prefix:
            return src_name[len(src_prefix):]
        # set_model_prefix may yield ("", None) when a destination name is a
        # leading substring of a source name; treat a missing dst_prefix as
        # empty so the name passes through instead of raising TypeError.
        return (dst_prefix or "") + src_name

    def set_model_prefix(self, src_names, dst_names):
        """Infer which side carries an extra name prefix, and what it is.

        The first (source, destination) pair, in iteration order, where one
        name is a substring of the other decides the result; the prefix is
        whatever precedes the match in the longer name.

        Args:
            src_names: Iterable of source parameter names.
            dst_names: Iterable of destination parameter names. It is
                re-iterated for every source name.

        Returns:
            A ``(src_prefix, dst_prefix)`` tuple in which exactly one entry
            is a string (possibly empty) and the other is ``None``.

        Raises:
            RuntimeError: If no pair of names contains one another, which
                includes the case where either input is empty.
        """
        for sname in src_names:
            for dname in dst_names:
                if sname in dname:
                    # Destination name is the longer one: prefix is on dst.
                    return None, dname[:dname.index(sname)]
                if dname in sname:
                    # Source name is the longer one: prefix is on src.
                    return sname[:sname.index(dname)], None
        raise RuntimeError("Cannot find prefix")

    def map_name_from_src_to_dst(self, send_actor, recv_actor, src_names, dst_names):
        """Rewrite ``dst_names`` into the receiver's naming scheme.

        The receiver is asked for its parameter names (with
        ``requires_grad=False``); those are compared against ``src_names`` to
        detect the prefix, and every entry of the incoming ``dst_names`` is
        then translated with ``_get_dst_name``.

        Args:
            send_actor: Sending actor. Not used by this implementation.
            recv_actor: Receiving actor; must expose a remote
                ``get_parameter_names`` method.
            src_names: Source parameter names. Returned unchanged.
            dst_names: Names to translate.

        Returns:
            A ``(src_names, translated_dst_names)`` tuple.

        Raises:
            RuntimeError: Propagated from ``set_model_prefix`` when no prefix
                relation can be found.
        """
        recv_param_names = future.get(recv_actor.get_parameter_names.remote(requires_grad=False))
        src_prefix, dst_prefix = self.set_model_prefix(src_names, recv_param_names)
        mapped_names = [self._get_dst_name(name, src_prefix, dst_prefix) for name in dst_names]
        return src_names, mapped_names