chatlearn/tools/megatron_checkpoint_utils.py (156 lines of code) (raw):
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Tools to convert megatron checkpoint with different parallel strategies.
Add support for reward model conversion.
"""
# pylint: disable=wildcard-import,exec-used
import importlib
import importlib.util
import sys
def get_indent_count(string):
    """Return the number of leading space characters in ``string``.

    Only the space character is counted; a tab or any other character ends
    the run. An empty string yields 0 and a string made only of spaces yields
    its full length (the loop-based version fell through and returned ``None``
    in both cases, which broke the arithmetic done by callers).

    :param string: a single line of text
    :return: count of leading spaces as an ``int``
    """
    return len(string) - len(string.lstrip(' '))
def repair_entry_file(source):
    """Extend the model-type ``choices`` list in the entry script with 'REWARD'.

    :param source: full text of the entry module
    :return: the text with every occurrence of the GPT/BERT choices list replaced
    """
    original_choices = "choices=['GPT', 'BERT']"
    extended_choices = "choices=['GPT', 'BERT', 'REWARD']"
    return source.replace(original_choices, extended_choices)
def detect_and_insert_code(lines, pattern, new_code, additional_indent, line_offset):
    """Splice ``new_code`` into ``lines`` relative to the first line containing ``pattern``.

    :param lines: list of source lines (no trailing newlines)
    :param pattern: substring identifying the anchor line; the first match is used
    :param new_code: snippet to insert; it is split on newlines, blank lines are
        dropped, and each remaining line keeps its own relative indentation
    :param additional_indent: number of spaces added to the anchor line's leading
        spaces to obtain the snippet's base indentation (may be negative)
    :param line_offset: insertion index relative to the anchor line's index;
        0 inserts directly before the anchor, negative values insert earlier
    :return: a new list with the snippet inserted; ``lines`` itself is not modified
    :raises IndexError: if no line contains ``pattern``
    """
    # Stop at the first match instead of scanning (and collecting) every line.
    anchor_index = next(
        (index for index, line in enumerate(lines) if pattern in line), None)
    if anchor_index is None:
        # Same exception type as the former `[...][0]` lookup, but with a message
        # that says which pattern could not be located.
        raise IndexError(f"pattern {pattern!r} was not found in the given source lines")
    anchor_line = lines[anchor_index]
    leading_spaces = len(anchor_line) - len(anchor_line.lstrip(' '))
    # A negative total yields an empty prefix because `" " * n` is "" for n <= 0.
    prefix = " " * (leading_spaces + additional_indent)
    added_lines = [prefix + line for line in new_code.split('\n') if line.strip()]
    insert_at = anchor_index + line_offset
    return lines[:insert_at] + added_lines + lines[insert_at:]
def repair_loader_model_provider(lines):
    """Insert a ``'REWARD'`` branch into the loader's model-type selection.

    The anchor is the first line containing ``'unrecognized model type'``. The
    snippet is placed ``line_offset`` lines relative to it and indented four
    spaces less than the anchor line.

    :param lines: source lines of the loader module
    :return: a new list of lines containing the extra ``elif`` branch
    """
    # Insert before following code, so line_offset=-2
    #    else:
    #        raise Exception(f'unrecognized model type: {args.model_type}')
    # NOTE(review): with exactly the two lines shown above, offset -1 is the slot
    # directly before `else:`; -2 lands one line earlier. Confirm against the
    # target file before changing it.
    pattern = 'unrecognized model type'
    # The branch body must be indented relative to its `elif` header; the
    # inserted text is compiled as part of the patched module.
    new_code = \
"""
elif args.model_type == 'REWARD':
    from examples.megatron.models.reward_model import model_provider
    margs.model_type = ModelType.encoder_or_decoder
"""
    indent = -4
    line_offset = -2
    return detect_and_insert_code(lines, pattern, new_code, indent, line_offset)
def repair_saver_model_provider(lines):
    """Patch the saver's model-type selection; delegates to the loader repair.

    :param lines: source lines of the saver module
    :return: whatever ``repair_loader_model_provider`` returns for ``lines``
    """
    patched_lines = repair_loader_model_provider(lines)
    return patched_lines
def repair_loader_put_reward(lines):
    """Insert code that sends the reward model's pooler head over the queue.

    The snippet is placed directly before the first line containing
    ``queue.put("done")`` and at that line's indentation.

    :param lines: source lines of the loader module
    :return: a new list of lines containing the extra block
    """
    pattern = 'queue.put("done")'
    # Nested lines carry their own relative indentation; the inserted text is
    # compiled as part of the patched module, so the `if` body and the dict
    # literal must be laid out as valid Python.
    new_code = \
"""
if md.model_type == 'REWARD':
    print("Sending LM Pooler")
    message = {
        "weight1": models[0].pooler_head.dense1.weight.data,
        "bias1": models[0].pooler_head.dense1.bias.data,
        "weight2": models[0].pooler_head.dense2.weight.data,
        "bias2": models[0].pooler_head.dense2.bias.data,
    }
    queue_put("pooler_head", message)
"""
    return detect_and_insert_code(lines, pattern, new_code, 0, 0)
def repair_saver_get_reward(lines):
    """Insert code that receives the pooler head and copies it into every TP rank.

    The snippet is placed directly before the first line containing
    ``if msg != "done":`` and at that line's indentation.

    :param lines: source lines of the saver module
    :return: a new list of lines containing the extra block
    """
    pattern = 'if msg != "done":'
    # The inserted text is compiled as part of the patched module, so the nested
    # `if` / `for` bodies need their own relative indentation.
    new_code = \
"""
if msg != "done" and msg["name"] == "pooler_head":
    if not hasattr(models[0], 'pooler_head'):
        print("ERROR: got a pooler_head, but model does not have one")
        exit(1)
    print("received pooler_head")
    head_weight1 = msg.pop("weight1")
    head_bias1 = msg.pop("bias1")
    head_weight2 = msg.pop("weight2")
    head_bias2 = msg.pop("bias2")
    for tp_rank in range(args.target_tensor_parallel_size):
        models[tp_rank].pooler_head.dense1.weight.data.copy_(head_weight1)
        models[tp_rank].pooler_head.dense1.bias.data.copy_(head_bias1)
        models[tp_rank].pooler_head.dense2.weight.data.copy_(head_weight2)
        models[tp_rank].pooler_head.dense2.bias.data.copy_(head_bias2)
    check_message(msg)
    msg = queue_get()
"""
    return detect_and_insert_code(lines, pattern, new_code, 0, 0)
# MCore
def repair_mcore_block_key(source):
    """Add a ``"REWARD" : "decoder"`` entry after the ``"BERT" : "encoder",`` entry.

    The replacement keeps a trailing comma after the new entry so the patched
    text stays valid whether or not further entries follow the BERT one.

    :param source: full text of the module holding the block-key mapping
    :return: the patched text (unchanged if the BERT entry is not present)
    """
    bert_entry = '"BERT" : "encoder",'
    return source.replace(bert_entry, bert_entry + ' "REWARD" : "decoder",')
def repair_import_utils(source):
    """Rewrite the bare ``from utils import ...`` line to the package-qualified form.

    :param source: full text of a module
    :return: the text with the import of ``get_mcore_transformer_block_key`` and
        ``print_memory_usage`` pointing at ``tools.checkpoint.utils``
    """
    bare_import = 'from utils import get_mcore_transformer_block_key, print_memory_usage'
    qualified_import = 'from tools.checkpoint.utils import get_mcore_transformer_block_key, print_memory_usage'
    return source.replace(bare_import, qualified_import)
def repair_loader_mcore_import_error(source):
    """Fix the mcore loader's imports; delegates to ``repair_import_utils``.

    :param source: full text of the mcore loader module
    :return: the text returned by ``repair_import_utils``
    """
    repaired = repair_import_utils(source)
    return repaired
def repair_saver_mcore_import_error(source):
    """Fix the mcore saver's imports.

    First qualifies the ``ModelSetter`` import with ``tools.checkpoint``, then
    applies ``repair_import_utils`` to the result.

    :param source: full text of the mcore saver module
    :return: the patched text
    """
    bare_setter_import = 'from setter import ModelSetter'
    qualified_setter_import = 'from tools.checkpoint.setter import ModelSetter'
    with_setter_fixed = source.replace(bare_setter_import, qualified_setter_import)
    return repair_import_utils(with_setter_fixed)
def repair_loader_mcore_model_provider(lines):
    """Insert a ``'REWARD'`` branch (mcore reward model) into the model-type selection.

    The anchor is the first line containing ``'unrecognized model type'``. The
    snippet is placed ``line_offset`` lines relative to it and indented four
    spaces less than the anchor line.

    :param lines: source lines of the mcore loader module
    :return: a new list of lines containing the extra ``elif`` branch
    """
    # Insert before following code, so line_offset=-2
    #    else:
    #        raise Exception(f'unrecognized model type: {args.model_type}')
    # NOTE(review): with exactly the two lines shown above, offset -1 is the slot
    # directly before `else:`; -2 lands one line earlier. Confirm against the
    # target file before changing it.
    pattern = 'unrecognized model type'
    # The branch body must be indented relative to its `elif` header; the
    # inserted text is compiled as part of the patched module.
    new_code = \
"""
elif args.model_type == 'REWARD':
    from examples.megatron.models.mcore_reward_model import model_provider
    margs.model_type = ModelType.encoder_or_decoder
"""
    indent = -4
    line_offset = -2
    return detect_and_insert_code(lines, pattern, new_code, indent, line_offset)
def repair_saver_mcore_model_provider(lines):
    """Patch the mcore saver's model-type selection; delegates to the mcore loader repair.

    :param lines: source lines of the mcore saver module
    :return: whatever ``repair_loader_mcore_model_provider`` returns for ``lines``
    """
    patched_lines = repair_loader_mcore_model_provider(lines)
    return patched_lines
def repair_loader_mcore_put_reward(lines):
    """Insert the pooler-head sending code; delegates to ``repair_loader_put_reward``.

    :param lines: source lines of the mcore loader module
    :return: whatever ``repair_loader_put_reward`` returns for ``lines``
    """
    patched_lines = repair_loader_put_reward(lines)
    return patched_lines
def repair_saver_mcore_get_reward(lines):
    """Insert code that receives the pooler head and copies it into the local models.

    The snippet is placed directly before the first line containing
    ``if msg != "done":`` and at that line's indentation. It refers to
    ``pp_rank`` and ``pp_local_models``, which must exist in the patched module
    at the insertion point.

    :param lines: source lines of the mcore saver module
    :return: a new list of lines containing the extra block
    """
    pattern = 'if msg != "done":'
    # The inserted text is compiled as part of the patched module, so the nested
    # `if` / `for` bodies need their own relative indentation.
    new_code = \
"""
if msg != "done" and msg["name"] == "pooler_head":
    if not hasattr(models[pp_rank][0][0], 'pooler_head'):
        print("ERROR: got a pooler_head, but model does not have one")
        exit(1)
    print("received pooler_head")
    head_weight1 = msg.pop("weight1")
    head_bias1 = msg.pop("bias1")
    head_weight2 = msg.pop("weight2")
    head_bias2 = msg.pop("bias2")
    for model in pp_local_models:
        model.pooler_head.dense1.weight.data.copy_(head_weight1)
        model.pooler_head.dense1.bias.data.copy_(head_bias1)
        model.pooler_head.dense2.weight.data.copy_(head_weight2)
        model.pooler_head.dense2.bias.data.copy_(head_bias2)
    check_message(msg)
    msg = queue_get()
"""
    return detect_and_insert_code(lines, pattern, new_code, 0, 0)
def exist_checkpoint_util():
    """Report whether the module ``tools.checkpoint.util`` can be located.

    :return: ``True`` if a spec is found for ``tools.checkpoint.util``, ``False``
        if it is not found or if one of its parent packages is missing
    """
    try:
        spec = importlib.util.find_spec('tools.checkpoint.util')
    except ModuleNotFoundError:
        # For a dotted name find_spec imports the parent package first and
        # raises when that parent cannot be found; treat it as "not present".
        return False
    return spec is not None
def repair_loader_llama_mistral(source):
    """Replace the hard-coded sequence length with the model's max position embeddings.

    :param source: full text of the llama/mistral loader module
    :return: the text with ``args.seq_length = 4096`` rewritten to read
        ``model_args["max_position_embeddings"]``
    """
    fixed_length_stmt = 'args.seq_length = 4096'
    from_config_stmt = 'args.seq_length = model_args["max_position_embeddings"]'
    return source.replace(fixed_length_stmt, from_config_stmt)
class CheckpointUtilsImporter:
    """Meta-path hook that rewrites selected modules' source while importing them.

    An instance is meant to be placed on ``sys.meta_path``. For every fully
    qualified module name given to the constructor it reads
    ``<parent path>/<module>.py``, passes the text through :meth:`repair_code`
    and executes the patched text as the module body.

    Two import protocols are implemented:

    * ``find_spec`` / ``create_module`` / ``exec_module`` -- the protocol the
      import system uses; it is the only one consulted from Python 3.12 on.
    * ``find_module`` / ``load_module`` -- the legacy protocol, kept so that
      existing direct callers keep working.
    """

    # Substrings of module names that are additionally registered in
    # ``sys.modules`` under their bare (last-component) name.
    _ALIASED_MODULES = (
        'loader_megatron',
        'saver_megatron',
        'loader_mcore',
        'saver_mcore',
        'loader_llama_mistral',
    )

    def __init__(self, *args):
        """
        :param args: fully qualified names of the modules to intercept
        """
        self.module_names = args
        # Parent package search path seen by the most recent finder call.
        self.path = None

    def find_spec(self, fullname, path=None, target=None):  # pylint: disable=unused-argument
        """
        Return a spec (with this object as loader) for an intercepted module.
        :param fullname: fully qualified module name
        :param path: search path of the parent package
        :param target: unused
        :return: a module spec, or None to let other finders handle the import
        """
        if fullname not in self.module_names or not path:
            return None
        self.path = path
        module_path = path[0] + '/' + fullname.split('.')[-1] + '.py'
        return importlib.util.spec_from_file_location(fullname, module_path, loader=self)

    def find_module(self, fullname, path=None):
        """
        Legacy finder hook.
        :param fullname: fully qualified module name
        :param path: search path of the parent package
        :return: this object for intercepted modules, otherwise None
        """
        if fullname in self.module_names:
            # save the path so that it could be used later by `load_module`
            self.path = path
            return self
        return None

    def create_module(self, spec):  # pylint: disable=unused-argument
        """Return None so the import system creates a default module object."""
        return None

    def exec_module(self, module):
        """
        Execute the repaired source of ``module`` in the module's namespace.
        :param module: module object created by the import system from our spec
        :return:
        """
        spec = module.__spec__
        self._exec_repaired(module, spec.name, spec.origin)
        self._register_alias(spec.name, module)

    @staticmethod
    def _repair_lines(source, *line_repairs):
        """Apply line-based repair functions, in order, to ``source``."""
        lines = source.split('\n')
        for repair in line_repairs:
            lines = repair(lines)
        return '\n'.join(lines)

    def repair_code(self, source, module_name):
        """
        Apply the repairs registered for ``module_name`` to ``source``.
        :param source: original module text
        :param module_name: last component of the module's dotted name
        :return: patched module text
        :raises RuntimeError: if no repair is defined for ``module_name``
        """
        if module_name in ['util', 'convert']:
            source = repair_entry_file(source)
        elif module_name == 'loader_megatron':
            source = self._repair_lines(
                source, repair_loader_model_provider, repair_loader_put_reward)
        elif module_name == 'saver_megatron':
            source = self._repair_lines(
                source, repair_saver_model_provider, repair_saver_get_reward)
        elif module_name == 'loader_llama_mistral':
            source = repair_loader_llama_mistral(source)
        elif module_name == 'loader_mcore':
            source = repair_loader_mcore_import_error(source)
            source = self._repair_lines(
                source, repair_loader_mcore_model_provider, repair_loader_mcore_put_reward)
        elif module_name == 'saver_mcore':
            source = repair_saver_mcore_import_error(source)
            source = self._repair_lines(
                source, repair_saver_mcore_model_provider, repair_saver_mcore_get_reward)
        elif module_name == 'utils':
            source = repair_mcore_block_key(source)
        else:
            raise RuntimeError(f"Unrecognized module_name {module_name}")
        return source

    def _exec_repaired(self, module, name, module_path):
        """Read ``module_path``, repair its source and execute it inside ``module``."""
        module_name = name.split('.')[-1]
        # read the source code and modify on-the-fly
        with open(module_path, encoding="utf-8") as f:
            source = f.read()
        new_source = self.repair_code(source, module_name)
        # compile the source code into a code object where it
        # could be imported with `exec` call.
        codeobj = compile(new_source, module.__spec__.origin, 'exec')
        # module.__dict__ is required for referencing variables in the module
        exec(codeobj, module.__dict__) # pylint: disable=exec-used

    def _register_alias(self, name, module):
        """Also expose loader/saver modules in ``sys.modules`` under their bare name."""
        if any(alias in name for alias in self._ALIASED_MODULES):
            sys.modules[name.split('.')[-1]] = module

    def load_module(self, name):
        """
        Load the module source code, fix the source and import manually
        (legacy loader hook; uses the path saved by `find_module`).
        :param name: fully qualified module name
        :return: the loaded (or previously loaded) module
        """
        if name in sys.modules:
            return sys.modules[name]
        module_name = name.split('.')[-1]
        module_path = self.path[0] + '/' + module_name + '.py'
        # create the module spec object
        spec = importlib.util.spec_from_file_location(name, module_path)
        # create the module object based off the module spec
        module = importlib.util.module_from_spec(spec)
        self._exec_repaired(module, name, module_path)
        # put the loaded module into sys.modules so that if the module is imported
        # again it could be found.
        sys.modules[name] = module
        self._register_alias(name, module)
        # return the module itself so that it could be used
        return module
if __name__ == '__main__':
    # Probe the checkpoint-tool layout once and reuse the answer for both the
    # list of modules to patch and the choice of entry point.
    has_checkpoint_util = exist_checkpoint_util()
    if has_checkpoint_util:
        packages = [
            'tools.checkpoint.util',
            'tools.checkpoint.loader_megatron',
            'tools.checkpoint.saver_megatron',
        ]
    else:
        packages = [
            'tools.checkpoint.convert',
            'tools.checkpoint.loader_megatron',
            'tools.checkpoint.saver_megatron',
            'tools.checkpoint.loader_mcore',
            'tools.checkpoint.saver_mcore',
            'tools.checkpoint.utils',
            'tools.checkpoint.loader_llama_mistral',
        ]
    # Create and install the importer before entering `try`, so the `finally`
    # clause only ever removes a hook that was actually installed.
    checkpoint_utils_import = CheckpointUtilsImporter(*packages)
    sys.meta_path.insert(0, checkpoint_utils_import)
    try:
        if has_checkpoint_util:
            from tools.checkpoint import loader_megatron, saver_megatron # pylint: disable=unused-import
            from tools.checkpoint import util
            util.main()
        else:
            from tools.checkpoint import loader_megatron, saver_megatron # pylint: disable=unused-import
            from tools.checkpoint import utils # pylint: disable=unused-import
            from tools.checkpoint import loader_mcore, saver_mcore # pylint: disable=unused-import
            from tools.checkpoint import loader_llama_mistral # pylint: disable=unused-import
            from tools.checkpoint import convert
            convert.main()
    finally:
        # Always take the hook off the meta path again, even if the tool fails.
        sys.meta_path.remove(checkpoint_utils_import)
# pylint: enable=wildcard-import,exec-used