chatlearn/tools/megatron_checkpoint_utils.py (156 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Tools to convert Megatron checkpoints between different parallel strategies.

Megatron's ``tools/checkpoint`` scripts are imported through a custom
``sys.meta_path`` hook (:class:`CheckpointUtilsImporter`) that patches their
source on the fly to add support for reward-model (``REWARD``) conversion.
"""
# pylint: disable=wildcard-import,exec-used

import sys
# ``import importlib`` alone does not guarantee the ``util`` submodule is loaded.
import importlib.util

# Patched modules whose fully qualified name contains one of these substrings are
# additionally registered in ``sys.modules`` under their bare (last-component) name,
# presumably so that later imports by bare name resolve to the patched module.
_BARE_NAME_ALIASES = (
    'loader_megatron',
    'saver_megatron',
    'loader_mcore',
    'saver_mcore',
    'loader_llama_mistral',
)


def get_indent_count(string):
    """Return the number of leading space characters (``' '``) in ``string``.

    Tabs are not counted as indentation. Empty or all-space strings return
    their length (previously the function fell through and returned ``None``).
    """
    return len(string) - len(string.lstrip(' '))


def repair_entry_file(source):
    """Add ``'REWARD'`` to the ``choices=['GPT', 'BERT']`` option of the entry script."""
    source = source.replace("choices=['GPT', 'BERT']", "choices=['GPT', 'BERT', 'REWARD']")
    return source


def detect_and_insert_code(lines, pattern, new_code, additional_indent, line_offset):
    """Insert ``new_code`` relative to the first line of ``lines`` containing ``pattern``.

    :param lines: source split into lines; not modified, a new list is returned.
    :param pattern: substring identifying the anchor line (first match wins).
    :param new_code: snippet to insert; blank lines are dropped and the remaining
        lines keep their relative indentation.
    :param additional_indent: added to the anchor line's indentation to get the
        base indentation of the inserted lines (may be negative).
    :param line_offset: insertion index relative to the anchor; ``0`` inserts
        directly before the anchor, ``-2`` before the line two above it.
    :return: a new list of lines with the snippet inserted.
    :raises RuntimeError: if no line contains ``pattern``.
    """
    anchor = next((idx for idx, line in enumerate(lines) if pattern in line), None)
    if anchor is None:
        raise RuntimeError(
            f"Pattern {pattern!r} not found; the file being patched may have a different layout")
    prefix = " " * (get_indent_count(lines[anchor]) + additional_indent)
    added_lines = [prefix + line for line in new_code.split('\n') if line.strip()]
    insert_at = anchor + line_offset
    return lines[:insert_at] + added_lines + lines[insert_at:]


def repair_loader_model_provider(lines):
    """Add a ``REWARD`` branch to the model-provider selection of ``loader_megatron``."""
    # Insert before following code, so line_offset=-2
    # else:
    #     raise Exception(f'unrecognized model type: {args.model_type}')
    # NOTE(review): index ``anchor - 2`` is the line *before* ``else:``, so the new
    # branch lands above that line; confirm this matches the patched file's layout.
    pattern = 'unrecognized model type'
    new_code = """
elif args.model_type == 'REWARD':
    from examples.megatron.models.reward_model import model_provider
    margs.model_type = ModelType.encoder_or_decoder
"""
    indent = -4
    line_offset = -2
    return detect_and_insert_code(lines, pattern, new_code, indent, line_offset)


def repair_saver_model_provider(lines):
    """Add a ``REWARD`` branch to the model-provider selection of ``saver_megatron``."""
    return repair_loader_model_provider(lines)


def repair_loader_put_reward(lines):
    """Make the loader send the reward model's ``pooler_head`` weights before ``"done"``."""
    pattern = 'queue.put("done")'
    new_code = """
if md.model_type == 'REWARD':
    print("Sending LM Pooler")
    message = {
        "weight1": models[0].pooler_head.dense1.weight.data,
        "bias1": models[0].pooler_head.dense1.bias.data,
        "weight2": models[0].pooler_head.dense2.weight.data,
        "bias2": models[0].pooler_head.dense2.bias.data,
    }
    queue_put("pooler_head", message)
"""
    return detect_and_insert_code(lines, pattern, new_code, 0, 0)


def repair_saver_get_reward(lines):
    """Make ``saver_megatron`` receive ``pooler_head`` weights and copy them into every TP rank."""
    pattern = 'if msg != "done":'
    new_code = """
if msg != "done" and msg["name"] == "pooler_head":
    if not hasattr(models[0], 'pooler_head'):
        print("ERROR: got a pooler_head, but model does not have one")
        exit(1)
    print("received pooler_head")
    head_weight1 = msg.pop("weight1")
    head_bias1 = msg.pop("bias1")
    head_weight2 = msg.pop("weight2")
    head_bias2 = msg.pop("bias2")
    for tp_rank in range(args.target_tensor_parallel_size):
        models[tp_rank].pooler_head.dense1.weight.data.copy_(head_weight1)
        models[tp_rank].pooler_head.dense1.bias.data.copy_(head_bias1)
        models[tp_rank].pooler_head.dense2.weight.data.copy_(head_weight2)
        models[tp_rank].pooler_head.dense2.bias.data.copy_(head_bias2)
    check_message(msg)
    msg = queue_get()
"""
    return detect_and_insert_code(lines, pattern, new_code, 0, 0)


# MCore
def repair_mcore_block_key(source):
    """Map model type ``REWARD`` to the ``decoder`` block key in ``utils``."""
    source = source.replace('"BERT" : "encoder",', '"BERT" : "encoder", "REWARD" : "decoder"')
    return source


def repair_import_utils(source):
    """Rewrite the bare ``from utils import ...`` into an absolute ``tools.checkpoint.utils`` import."""
    source = source.replace('from utils import get_mcore_transformer_block_key, print_memory_usage',
                            'from tools.checkpoint.utils import get_mcore_transformer_block_key, print_memory_usage')
    return source


def repair_loader_mcore_import_error(source):
    """Fix the sibling-module imports of ``loader_mcore``."""
    return repair_import_utils(source)


def repair_saver_mcore_import_error(source):
    """Fix the sibling-module imports (``setter`` and ``utils``) of ``saver_mcore``."""
    source = source.replace('from setter import ModelSetter', 'from tools.checkpoint.setter import ModelSetter')
    return repair_import_utils(source)


def repair_loader_mcore_model_provider(lines):
    """Add a ``REWARD`` branch (MCore reward model) to the model-provider selection of ``loader_mcore``."""
    # Insert before following code, so line_offset=-2
    # else:
    #     raise Exception(f'unrecognized model type: {args.model_type}')
    # NOTE(review): same insertion point as repair_loader_model_provider; confirm layout.
    pattern = 'unrecognized model type'
    new_code = """
elif args.model_type == 'REWARD':
    from examples.megatron.models.mcore_reward_model import model_provider
    margs.model_type = ModelType.encoder_or_decoder
"""
    indent = -4
    line_offset = -2
    return detect_and_insert_code(lines, pattern, new_code, indent, line_offset)


def repair_saver_mcore_model_provider(lines):
    """Add a ``REWARD`` branch to the model-provider selection of ``saver_mcore``."""
    return repair_loader_mcore_model_provider(lines)


def repair_loader_mcore_put_reward(lines):
    """Make ``loader_mcore`` send ``pooler_head`` weights (same snippet as the legacy loader)."""
    return repair_loader_put_reward(lines)


def repair_saver_mcore_get_reward(lines):
    """Make ``saver_mcore`` receive ``pooler_head`` weights and copy them into ``pp_local_models``."""
    pattern = 'if msg != "done":'
    new_code = """
if msg != "done" and msg["name"] == "pooler_head":
    if not hasattr(models[pp_rank][0][0], 'pooler_head'):
        print("ERROR: got a pooler_head, but model does not have one")
        exit(1)
    print("received pooler_head")
    head_weight1 = msg.pop("weight1")
    head_bias1 = msg.pop("bias1")
    head_weight2 = msg.pop("weight2")
    head_bias2 = msg.pop("bias2")
    for model in pp_local_models:
        model.pooler_head.dense1.weight.data.copy_(head_weight1)
        model.pooler_head.dense1.bias.data.copy_(head_bias1)
        model.pooler_head.dense2.weight.data.copy_(head_weight2)
        model.pooler_head.dense2.bias.data.copy_(head_bias2)
    check_message(msg)
    msg = queue_get()
"""
    return detect_and_insert_code(lines, pattern, new_code, 0, 0)


def exist_checkpoint_util():
    """Return whether ``tools.checkpoint.util`` can be found.

    When it cannot, ``tools.checkpoint.convert`` is used as the entry point instead.
    """
    spec = importlib.util.find_spec('tools.checkpoint.util')
    return spec is not None


def repair_loader_llama_mistral(source):
    """Replace the hard-coded ``seq_length = 4096`` with the model's ``max_position_embeddings``."""
    source = source.replace('args.seq_length = 4096', 'args.seq_length = model_args["max_position_embeddings"]')
    return source


def _repair_lines(source, *repairs):
    """Split ``source`` into lines, apply each line-based repair in order, and re-join."""
    lines = source.split('\n')
    for repair in repairs:
        lines = repair(lines)
    return '\n'.join(lines)


class CheckpointUtilsImporter:
    """Meta-path finder/loader that patches Megatron checkpoint tools at import time.

    Each fully qualified module name passed to the constructor is read from disk,
    rewritten by :meth:`repair_code` and executed instead of the unmodified file.
    Both the modern (``find_spec``/``exec_module``) and the legacy
    (``find_module``/``load_module``) import protocols are implemented; the import
    system stopped calling ``find_module`` in Python 3.12.
    """

    def __init__(self, *args):
        self.module_names = args
        # Search path of the parent package, recorded by find_spec/find_module.
        self.path = None

    @staticmethod
    def _module_file(search_path, fullname):
        """Return ``<first search-path entry>/<last name component>.py``."""
        return search_path[0] + '/' + fullname.split('.')[-1] + '.py'

    @staticmethod
    def _register_bare_name(name, module):
        """Also expose ``module`` under its bare name if it matches ``_BARE_NAME_ALIASES``."""
        if any(alias in name for alias in _BARE_NAME_ALIASES):
            sys.modules[name.split('.')[-1]] = module

    def _exec_repaired_source(self, module, module_path):
        """Read ``module_path``, repair its source and execute it in ``module``'s namespace."""
        module_name = module.__spec__.name.split('.')[-1]
        with open(module_path, encoding="utf-8") as f:
            source = f.read()
        new_source = self.repair_code(source, module_name)
        # Compile with the real file name so tracebacks point at the original file.
        codeobj = compile(new_source, module.__spec__.origin, 'exec')
        # module.__dict__ is required for referencing variables in the module
        exec(codeobj, module.__dict__)  # pylint: disable=exec-used

    def find_spec(self, fullname, path=None, target=None):  # pylint: disable=unused-argument
        """Return a spec loaded by this importer for patched modules, otherwise ``None``."""
        if fullname not in self.module_names:
            return None
        self.path = path
        return importlib.util.spec_from_file_location(
            fullname, self._module_file(path, fullname), loader=self)

    def create_module(self, spec):  # pylint: disable=unused-argument
        """Use the import system's default module creation."""
        return None

    def exec_module(self, module):
        """Execute the repaired source of ``module`` (the import system registers it first)."""
        self._exec_repaired_source(module, module.__spec__.origin)
        self._register_bare_name(module.__spec__.name, module)

    def find_module(self, fullname, path=None):
        """Legacy finder hook: return ``self`` for patched modules, otherwise ``None``."""
        if fullname in self.module_names:
            # save the path so that it could be used later by `load_module`
            self.path = path
            return self
        return None

    def repair_code(self, source, module_name):
        """Return ``source`` of ``tools.checkpoint.<module_name>`` with the REWARD patches applied.

        :raises RuntimeError: for a module name this importer does not know how to patch.
        """
        if module_name in ('util', 'convert'):
            return repair_entry_file(source)
        if module_name == 'loader_megatron':
            return _repair_lines(source, repair_loader_model_provider, repair_loader_put_reward)
        if module_name == 'saver_megatron':
            return _repair_lines(source, repair_saver_model_provider, repair_saver_get_reward)
        if module_name == 'loader_llama_mistral':
            return repair_loader_llama_mistral(source)
        if module_name == 'loader_mcore':
            return _repair_lines(repair_loader_mcore_import_error(source),
                                 repair_loader_mcore_model_provider,
                                 repair_loader_mcore_put_reward)
        if module_name == 'saver_mcore':
            return _repair_lines(repair_saver_mcore_import_error(source),
                                 repair_saver_mcore_model_provider,
                                 repair_saver_mcore_get_reward)
        if module_name == 'utils':
            return repair_mcore_block_key(source)
        raise RuntimeError(f"Unrecognized module_name {module_name}")

    def load_module(self, name):
        """
        Legacy loader hook: load the module source code, fix the source and import manually
        :param name: fully qualified module name
        :return: the loaded (patched) module
        """
        if name in sys.modules:
            return sys.modules[name]
        module_path = self._module_file(self.path, name)
        spec = importlib.util.spec_from_file_location(name, module_path)
        module = importlib.util.module_from_spec(spec)
        self._exec_repaired_source(module, module_path)
        # put the loaded module into sys.modules so that if the module is imported
        # again it could be found.
        sys.modules[name] = module
        self._register_bare_name(name, module)
        return module


if __name__ == '__main__':
    # Evaluate once, before the hook is installed, so the hook cannot influence the answer.
    has_util_entry = exist_checkpoint_util()
    if has_util_entry:
        packages = ['tools.checkpoint.util',
                    'tools.checkpoint.loader_megatron', 'tools.checkpoint.saver_megatron']
    else:
        packages = ['tools.checkpoint.convert',
                    'tools.checkpoint.loader_megatron', 'tools.checkpoint.saver_megatron',
                    'tools.checkpoint.loader_mcore', 'tools.checkpoint.saver_mcore',
                    'tools.checkpoint.utils', 'tools.checkpoint.loader_llama_mistral']
    checkpoint_utils_import = CheckpointUtilsImporter(*packages)
    sys.meta_path.insert(0, checkpoint_utils_import)
    try:
        if has_util_entry:
            from tools.checkpoint import loader_megatron, saver_megatron  # pylint: disable=unused-import
            from tools.checkpoint import util
            util.main()
        else:
            from tools.checkpoint import loader_megatron, saver_megatron  # pylint: disable=unused-import
            from tools.checkpoint import utils  # pylint: disable=unused-import
            from tools.checkpoint import loader_mcore, saver_mcore  # pylint: disable=unused-import
            from tools.checkpoint import loader_llama_mistral  # pylint: disable=unused-import
            from tools.checkpoint import convert
            convert.main()
    finally:
        sys.meta_path.remove(checkpoint_utils_import)

# pylint: enable=wildcard-import,exec-used