# toolkits/model_checkpoints_convertor/llama/hf2mcore.py
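"""Convert LLaMA checkpoints between Hugging Face and Megatron-Core (mcore) formats.

Supports resharding across tensor-parallel and expert-parallel ranks, and can
"upcycle" a dense LLaMA model into a Mixtral-style MoE by splitting each MLP's
intermediate dimension across experts (see --num_expert_split_size).
"""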
import os
import re
import json
import torch
import transformers
import torch.nn as nn
from functools import partial
from collections import defaultdict
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
)
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.modeling_utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, shard_checkpoint
from megatron.initialize import initialize_megatron
from megatron import get_args
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.checkpointing import get_checkpoint_name, get_checkpoint_tracker_filename, read_metadata
import sys
path_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))
sys.path.append(os.path.join(path_dir, "examples"))
from llama2.pretrain_mcore_llama import model_provider
from megatron_patch.arguments import get_patch_args
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
import numpy as np
from collections.abc import Mapping, Sequence
@torch.inference_mode()
def clone_state_dict(elem):
"""clone all tensors in the elem to cpu device.
"""
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
elem = elem.clone()
elif isinstance(elem, (np.ndarray, str)):
pass
elif isinstance(elem, Mapping):
elem = dict(elem)
for k, v in elem.items():
elem[k] = clone_state_dict(v)
elem = elem_type(elem)
elif isinstance(elem, Sequence):
elem = list(elem)
for i in range(len(elem)):
elem[i] = clone_state_dict(elem[i])
elem = elem_type(elem)
return elem
def add_checkpointing_args(parser):
parser.add_argument('--megatron-path',
type=str,
default=None,
help='Base directory of Megatron repository')
parser.add_argument(
'--convert_checkpoint_from_megatron_to_transformers',
action='store_true',
help=
('If True, convert a Megatron checkpoint to a Transformers checkpoint. '
'If False, convert a Transformers checkpoint to a Megatron checkpoint.'
),
)
parser.add_argument(
'--load_path',
type=str,
required=True,
help='Path to the checkpoint to convert.',
)
parser.add_argument(
'--save_path',
type=str,
required=True,
help='Path to the converted checkpoint.',
)
    parser.add_argument(
        '--huggingface_model_path',
        type=str,
        required=True,
        help='Path to the Hugging Face model providing the config and tokenizer files.',
    )
return parser
def add_megatron_checkpoint_args(parser):
parser.add_argument(
"--target_tensor_model_parallel_size",
type=int,
default=1,
help=(
"The tensor model parallel size of the converted checkpoint. "
"Only used when converting a Transformers checkpoint to a Megatron checkpoint."
),
)
parser.add_argument(
"--target_pipeline_model_parallel_size",
type=int,
default=1,
help=(
"The pipeline model parallel size of the converted checkpoint. "
"Only used when converting a Transformers checkpoint to a Megatron checkpoint."
),
)
parser.add_argument(
"--target_expert_model_parallel_size",
type=int,
default=1,
help=(
"The data parallel size of the converted checkpoint. "
"Only used when converting a Transformers checkpoint to a Megatron checkpoint."
),
)
    parser.add_argument(
        "--num_expert_split_size",
        type=int,
        default=1,
        help=(
            "Number of chunks to split the dense MLP intermediate dimension into "
            "when seeding MoE experts from a dense checkpoint."
        ),
    )
return parser
def add_transformers_checkpoint_args(parser):
parser.add_argument(
"--max_shard_size",
type=str,
default="10GB",
help=(
"The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size "
"lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`). "
"Only used when converting a Megatron checkpoint to a Transformers checkpoint."
),
)
return parser
def build_huggingface_model(model_to_load, compute_dtype, random_init=False):
config = AutoConfig.from_pretrained(
model_to_load,
trust_remote_code=True,
)
if random_init:
model = AutoModelForCausalLM.from_config(
config=config,
torch_dtype=compute_dtype,
trust_remote_code=True
)
else:
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
torch_dtype=compute_dtype,
trust_remote_code=True
)
return config, model.eval()
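# Replace every dense LLaMA MLP with a randomly initialized Mixtral MoE block.
# MixtralSparseMoeBlock.forward returns (hidden_states, router_logits), while
# the surrounding decoder layer expects a single tensor, so a forward hook
# unwraps the tuple.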
def replace_mlp_with_moe(hf_config, model):
    config = MixtralConfig(
        intermediate_size=hf_config.intermediate_size,
        hidden_size=hf_config.hidden_size,
        num_attention_heads=hf_config.num_attention_heads,
        num_local_experts=hf_config.num_local_experts,
        num_key_value_heads=hf_config.num_key_value_heads,
        rope_theta=hf_config.rope_theta,
        rms_norm_eps=hf_config.rms_norm_eps,
        num_experts_per_tok=1,
    )
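    # Forward hook: keep only the hidden states so the MoE block is a drop-in
    # replacement for the original MLP's single-tensor output.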
def get_hidden_output(module, args, output):
return output[0]
for layer in model.model.layers:
        mlp = MixtralSparseMoeBlock(config).to(hf_config.torch_dtype)
mlp.register_forward_hook(get_hidden_output)
layer.mlp = mlp
return model
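# Build the HF-side model. For mcore -> HF MoE conversion, a randomly
# initialized MoE skeleton is created from the config (and custom modeling
# code) copied into save_path; its weights are later filled from the Megatron
# checkpoint.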
def create_huggingface_model(args):
if not args.convert_checkpoint_from_megatron_to_transformers or args.num_experts is None:
copy_huggingface_tokenizer(args.huggingface_model_path, args.save_path)
config, model = build_huggingface_model(args.huggingface_model_path, args.params_dtype)
else:
copy_huggingface_tokenizer(args.huggingface_model_path, args.save_path, with_code=True)
config, model = build_huggingface_model(args.save_path, args.params_dtype, random_init=True)
model = replace_mlp_with_moe(config, model)
return config, model.eval()
def create_megatron_model(args, hf_config):
args.hidden_size = hf_config.hidden_size
args.num_layers = hf_config.num_hidden_layers
args.num_attention_heads = hf_config.num_attention_heads
args.kv_channels = args.hidden_size // args.num_attention_heads
if not args.convert_checkpoint_from_megatron_to_transformers:
if args.num_expert_split_size == 1:
args.ffn_hidden_size = hf_config.intermediate_size
else:
args.ffn_hidden_size = hf_config.intermediate_size // args.num_expert_split_size
else:
args.ffn_hidden_size = hf_config.intermediate_size
args.num_query_groups = hf_config.num_key_value_heads
model = model_provider()
return model.eval()
def copy_huggingface_tokenizer(src_path, dst_path, with_code=False):
assert os.path.exists(src_path)
os.makedirs(dst_path, exist_ok=True)
os.system("cp -rf " + src_path + "/config*.json " + dst_path)
os.system("cp -rf " + src_path + "/tokenizer* " + dst_path)
if with_code:
cur_dir = os.path.dirname(os.path.abspath(__file__))
code_path = os.path.join(cur_dir, 'hf_llama_moe')
os.system("cp -rf " + code_path + "/*.py " + dst_path)
os.system("cp -rf " + code_path + "/*.json " + dst_path)
def name_to_expert_rank(key):
pattern = r'local_experts\.(\d+)\.'
expert_rank = int(re.findall(pattern, key)[0])
return expert_rank
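# Load a (possibly sharded) Megatron checkpoint and merge it into a single
# unsharded state dict. Merge rules mirror Megatron's sharding:
#   - embeddings / output_layer: concat along dim 0 (vocab)
#   - linear_proj / linear_fc2:  concat along dim 1 (row-parallel input)
#   - linear_qkv:                concat per query group, then re-flatten
#   - linear_fc1:                gate and up halves are stacked per rank, so
#                                shards are viewed as (2, -1, hidden) first
#   - expert weights:            local expert indices are mapped back to global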
def load_megatron_model(args, model):
model_path = args.load_path
tracker_filename = get_checkpoint_tracker_filename(model_path)
iteration, release = read_metadata(tracker_filename)
head_dim = args.hidden_size // args.num_attention_heads
group_per_split = args.num_query_groups // args.target_tensor_model_parallel_size
num_local_experts = args.num_experts // args.target_expert_model_parallel_size if args.num_experts else 0
state_dict = {}
mid_state = defaultdict(list)
if (
args.target_tensor_model_parallel_size == 1
and args.target_pipeline_model_parallel_size == 1
and args.target_expert_model_parallel_size == 1
):
checkpoint_name = get_checkpoint_name(model_path, iteration, release, None, None, None, None, None)
        state_dict = torch.load(checkpoint_name, map_location="cpu")['model']
elif (
args.target_tensor_model_parallel_size == 1
and args.target_pipeline_model_parallel_size == 1
and args.num_experts
and args.num_experts % args.target_expert_model_parallel_size == 0
):
for ep_rank in range(args.target_expert_model_parallel_size):
checkpoint_name = get_checkpoint_name(model_path, iteration, release, None, None, None, True, ep_rank)
print(f'load {checkpoint_name}')
split_state = torch.load(checkpoint_name, map_location="cpu")['model']
for k, v in split_state.items():
if 'local_experts' in k:
expert_local_rank = name_to_expert_rank(k)
expert_rank = expert_local_rank + num_local_experts * ep_rank
k = k.replace(f'local_experts.{expert_local_rank}', f'local_experts.{expert_rank}')
state_dict[k] = v
elif (
args.target_tensor_model_parallel_size > 1
and args.target_pipeline_model_parallel_size == 1
and args.num_experts is None
):
for tp_rank in range(args.target_tensor_model_parallel_size):
checkpoint_name = get_checkpoint_name(model_path, iteration, release, None, tp_rank, None, None, None)
print(f'load {checkpoint_name}')
split_state = torch.load(checkpoint_name, map_location="cpu")['model']
for k, v in split_state.items():
mid_state[k].append(v)
for k, v in mid_state.items():
if not isinstance(v[0], torch.Tensor) or 'norm' in k:
target_v = v[0]
elif 'embedding' in k or 'output_layer' in k:
target_v = torch.cat(v, dim=0)
elif 'linear_proj' in k or 'linear_fc2' in k:
target_v = torch.cat(v, dim=1)
elif 'linear_qkv.weight' in k:
viewed = [x.view(group_per_split, -1, head_dim, args.hidden_size) for x in v]
target_v = torch.cat(viewed, dim=0).view(-1, args.hidden_size)
elif 'linear_qkv.bias' in k:
viewed = [x.view(group_per_split, -1) for x in v]
target_v = torch.cat(viewed, dim=0).view(-1)
elif 'linear_fc1' in k:
viewed = [x.view(2, -1, args.hidden_size) for x in v]
target_v = torch.cat(viewed, dim=1).view(-1, args.hidden_size)
else:
                raise ValueError(f'unexpected key {k}')
state_dict[k] = target_v
elif (
args.target_tensor_model_parallel_size > 1
and args.target_pipeline_model_parallel_size == 1
and args.num_experts
and args.num_experts % args.target_expert_model_parallel_size == 0
):
for tp_rank in range(args.target_tensor_model_parallel_size):
for ep_rank in range(args.target_expert_model_parallel_size):
checkpoint_name = get_checkpoint_name(model_path, iteration, release, None, tp_rank, None, True,
ep_rank)
print(f'load {checkpoint_name}')
split_state = torch.load(checkpoint_name, map_location="cpu")['model']
for k, v in split_state.items():
if 'local_experts' in k and 'norm' not in k:
local_expert_rank = name_to_expert_rank(k)
expert_rank = local_expert_rank + num_local_experts * ep_rank
k = k.replace(f'local_experts.{local_expert_rank}', f'local_experts.{expert_rank}')
mid_state[k].append(v)
elif ep_rank == 0:
mid_state[k].append(v)
for k, v in mid_state.items():
if not isinstance(v[0], torch.Tensor) or 'norm' in k or 'router' in k:
target_v = v[0]
elif 'embedding' in k or 'output_layer' in k:
target_v = torch.cat(v, dim=0)
elif 'linear_proj' in k or 'linear_fc2' in k:
target_v = torch.cat(v, dim=1)
elif 'linear_qkv.weight' in k:
viewed = [x.view(group_per_split, -1, head_dim, args.hidden_size) for x in v]
target_v = torch.cat(viewed, dim=0).view(-1, args.hidden_size)
elif 'linear_qkv.bias' in k:
viewed = [x.view(group_per_split, -1) for x in v]
target_v = torch.cat(viewed, dim=0).view(-1)
elif 'linear_fc1' in k:
viewed = [x.view(2, -1, args.hidden_size) for x in v]
target_v = torch.cat(viewed, dim=1).view(-1, args.hidden_size)
            else:
                # Unknown keys would otherwise be assigned the previous
                # iteration's stale target_v; skip them explicitly instead.
                print(f'skipped unexpected key {k}')
                continue
            state_dict[k] = target_v
else:
        raise ValueError('unsupported combination of target parallel sizes')
model.load_state_dict(state_dict)
return model
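# Map merged mcore weights onto the HF model. The fused linear_qkv is split per
# query group into q/k/v projections, linear_fc1 splits into gate_proj/up_proj,
# and the layernorms that mcore fuses into its linear modules
# (layer_norm_weight) map back to HF's standalone layernorm modules.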
def convert_checkpoint_from_megatron_to_transformers(mgmodel, hgmodel, args):
query_group = args.num_query_groups
hidden_size = args.hidden_size
head_dim = hidden_size // args.num_attention_heads
num_experts = args.num_experts
value_num_per_group = args.num_attention_heads // query_group
with torch.no_grad():
hgmodel.model.embed_tokens.weight.copy_(mgmodel.embedding.word_embeddings.weight)
for mglayer, hglayer in zip(mgmodel.decoder.layers, hgmodel.model.layers):
hglayer.input_layernorm.weight.copy_(mglayer.self_attention.linear_qkv.layer_norm_weight)
qkv_weight = mglayer.self_attention.linear_qkv.weight.view(query_group, -1, head_dim, hidden_size)
q_weight, k_weight, v_weight = torch.split(qkv_weight, split_size_or_sections=[value_num_per_group, 1, 1],
dim=1)
hglayer.self_attn.q_proj.weight.copy_(q_weight.reshape(-1, hidden_size))
hglayer.self_attn.k_proj.weight.copy_(k_weight.reshape(-1, hidden_size))
hglayer.self_attn.v_proj.weight.copy_(v_weight.reshape(-1, hidden_size))
hglayer.self_attn.o_proj.weight.copy_(mglayer.self_attention.linear_proj.weight)
if num_experts is None:
gate_weight, fc1_weight = torch.split(mglayer.mlp.linear_fc1.weight,
split_size_or_sections=args.ffn_hidden_size)
hglayer.mlp.gate_proj.weight.copy_(gate_weight)
hglayer.mlp.up_proj.weight.copy_(fc1_weight)
hglayer.mlp.down_proj.weight.copy_(mglayer.mlp.linear_fc2.weight)
hglayer.post_attention_layernorm.weight.copy_(mglayer.mlp.linear_fc1.layer_norm_weight)
else:
hglayer.post_attention_layernorm.weight.copy_(mglayer.pre_mlp_layernorm.weight)
hglayer.mlp.gate.weight.copy_(mglayer.mlp.router.weight)
                # The per-expert mapping is the same for both values of
                # num_expert_split_size: mcore's fused linear_fc1 splits into
                # Mixtral's w1 (gate) and w3 (up), and linear_fc2 maps to w2.
                for mgexpert, hgexpert in zip(mglayer.mlp.experts.local_experts, hglayer.mlp.experts):
                    gate_weight, fc1_weight = torch.split(mgexpert.linear_fc1.weight,
                                                          split_size_or_sections=args.ffn_hidden_size)
                    hgexpert.w1.weight.copy_(gate_weight)
                    hgexpert.w3.weight.copy_(fc1_weight)
                    hgexpert.w2.weight.copy_(mgexpert.linear_fc2.weight)
hgmodel.model.norm.weight.copy_(mgmodel.decoder.final_layernorm.weight)
hgmodel.lm_head.weight.copy_(mgmodel.output_layer.weight)
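# Reverse mapping: HF -> mcore. With --num_expert_split_size > 1, the dense
# MLP's intermediate dimension is chunked and chunk i seeds expert i, so the
# freshly built MoE initially computes pieces of the original dense network.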
def convert_checkpoint_from_transformers_to_megatron(mgmodel, hgmodel, args, hf_config):
num_query_groups = hf_config.num_key_value_heads
hidden_dim = hf_config.hidden_size
head_dim = hidden_dim // hf_config.num_attention_heads
num_experts = args.num_experts
with torch.no_grad():
mgmodel.embedding.word_embeddings.weight.copy_(hgmodel.model.embed_tokens.weight)
for mglayer, hglayer in zip(mgmodel.decoder.layers, hgmodel.model.layers):
mglayer.self_attention.linear_qkv.layer_norm_weight.copy_(hglayer.input_layernorm.weight)
q = hglayer.self_attn.q_proj.weight.view([num_query_groups, -1, head_dim, hidden_dim])
k = hglayer.self_attn.k_proj.weight.view([num_query_groups, -1, head_dim, hidden_dim])
v = hglayer.self_attn.v_proj.weight.view([num_query_groups, -1, head_dim, hidden_dim])
qkv = torch.cat([q, k, v], dim=1).view(-1, hidden_dim).contiguous()
mglayer.self_attention.linear_qkv.weight.copy_(qkv)
mglayer.self_attention.linear_proj.weight.copy_(hglayer.self_attn.o_proj.weight)
fc1_weight = torch.cat([hglayer.mlp.gate_proj.weight, hglayer.mlp.up_proj.weight])
if num_experts is None:
mglayer.mlp.linear_fc1.weight.copy_(fc1_weight)
mglayer.mlp.linear_fc2.weight.copy_(hglayer.mlp.down_proj.weight)
mglayer.mlp.linear_fc1.layer_norm_weight.copy_(hglayer.post_attention_layernorm.weight)
            else:
                mglayer.pre_mlp_layernorm.weight.copy_(hglayer.post_attention_layernorm.weight)
                # The dense HF checkpoint has no router; initialize one randomly.
                nn.init.normal_(mglayer.mlp.router.weight, mean=0, std=0.02)
                if args.num_expert_split_size == 1:
                    # Every expert starts as an exact copy of the dense MLP.
                    for expert in mglayer.mlp.experts.local_experts:
                        expert.linear_fc1.weight.copy_(fc1_weight)
                        expert.linear_fc2.weight.copy_(hglayer.mlp.down_proj.weight)
                else:
split_size = hf_config.intermediate_size // args.num_expert_split_size
gate_proj_splits = torch.split(hglayer.mlp.gate_proj.weight, split_size_or_sections=split_size)
up_proj_splits = torch.split(hglayer.mlp.up_proj.weight, split_size_or_sections=split_size)
down_proj_splits = torch.split(hglayer.mlp.down_proj.weight, split_size_or_sections=split_size, dim=1)
for idx, expert in enumerate(mglayer.mlp.experts.local_experts):
expert.linear_fc1.weight.copy_(torch.cat([gate_proj_splits[idx], up_proj_splits[idx]]))
expert.linear_fc2.weight.copy_(down_proj_splits[idx])
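                    # Disabled variant kept for reference: pads each split
                    # expert with extra randomly initialized weights.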
"""
for idx, expert in enumerate(mglayer.mlp.experts.local_experts):
base_linear_fc1 = torch.cat([gate_proj_splits[idx], up_proj_splits[idx]])
extra_linear_fc1 = torch.empty(32, base_linear_fc1.shape[1])
extra_linear_fc2 = torch.empty(base_linear_fc1.shape[1], 16)
nn.init.normal_(extra_linear_fc1, mean=0, std=0.02)
nn.init.normal_(extra_linear_fc2, mean=0, std=0.02)
expert.linear_fc1.weight.copy_(torch.cat([base_linear_fc1, extra_linear_fc1.to(torch.float16)]))
expert.linear_fc2.weight.copy_(torch.cat([down_proj_splits[idx], extra_linear_fc2.to(torch.float16)], dim=1))
"""
mgmodel.decoder.final_layernorm.weight.copy_(hgmodel.model.norm.weight)
mgmodel.output_layer.weight.copy_(hgmodel.lm_head.weight)
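# Wrap a (possibly resharded) model state dict in Megatron's on-disk checkpoint
# format, recording the parallelism arguments and checkpoint version.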
def save_state_dict(args, model, checkpoint_name):
args.tensor_model_parallel_size = args.target_tensor_model_parallel_size
args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size
state_dict = {}
state_dict['args'] = args
state_dict['checkpoint_version'] = 3.0
state_dict['iteration'] = 0
state_dict['model'] = model
os.makedirs(os.path.dirname(checkpoint_name), exist_ok=True)
print(f'save model part {checkpoint_name}')
torch.save(clone_state_dict(state_dict), checkpoint_name)
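# Reshard the merged model per (tp_rank, ep_rank) and save it in Megatron's
# directory layout. Writing "release" into latest_checkpointed_iteration.txt
# makes Megatron load the result as a release (iteration-independent) checkpoint.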
def save_mgmodel(args, mgmodel, load_path, save_path):
    # Save config and tokenizer files alongside the Megatron checkpoint
copy_huggingface_tokenizer(load_path, save_path)
tracker_filepath = os.path.join(save_path, 'latest_checkpointed_iteration.txt')
with open(tracker_filepath, "w") as f:
f.write("release")
head_dim = args.hidden_size // args.num_attention_heads
group_per_split = args.num_query_groups // args.target_tensor_model_parallel_size
full_model = mgmodel.state_dict_for_save_checkpoint()
for k in list(full_model.keys()):
if full_model[k] is None or "_extra_state" in k:
full_model.pop(k)
pattern = r'local_experts\.(\d+)\.'
num_local_experts = args.num_experts // args.target_expert_model_parallel_size if args.num_experts else 0
if (
args.target_tensor_model_parallel_size == 1
and args.target_pipeline_model_parallel_size == 1
and args.target_expert_model_parallel_size == 1
):
checkpoint_name = get_checkpoint_name(save_path, 0, True)
save_state_dict(args, full_model, checkpoint_name)
elif (
args.target_tensor_model_parallel_size == 1
and args.target_pipeline_model_parallel_size == 1
and args.num_experts
and args.num_experts % args.target_expert_model_parallel_size == 0
):
for ep_rank in range(args.target_expert_model_parallel_size):
model_split = {}
checkpoint_name = get_checkpoint_name(save_path, 0, True, None, None, None, True, ep_rank)
print(f'save ep_rank {ep_rank} model to {checkpoint_name}')
for k, v in full_model.items():
if 'local_experts' in k:
expert_rank = int(re.findall(pattern, k)[0])
if expert_rank // num_local_experts != ep_rank:
continue
                    # Map the global expert index back to a local index on this
                    # ep_rank (mirrors the load path above).
                    expert_local_rank = expert_rank % num_local_experts
k = k.replace(f'local_experts.{expert_rank}', f'local_experts.{expert_local_rank}')
model_split[k] = v
save_state_dict(args, model_split, checkpoint_name)
elif (
args.target_tensor_model_parallel_size > 1
and args.target_pipeline_model_parallel_size == 1
and args.num_experts is None
):
for tp_rank in range(args.target_tensor_model_parallel_size):
model_split = {}
checkpoint_name = get_checkpoint_name(save_path, 0, True, None, tp_rank)
print(f'tensor_parallel, save model to {checkpoint_name}')
for k, v in full_model.items():
if not isinstance(v, torch.Tensor):
target_v = v
elif 'linear_qkv.weight' in k and 'norm' not in k:
viewed = v.view(args.num_query_groups, -1, head_dim, args.hidden_size)
viewed = viewed[group_per_split * tp_rank: group_per_split * (tp_rank + 1)]
target_v = viewed.view(-1, args.hidden_size)
elif 'linear_qkv.bias' in k and 'norm' not in k:
viewed = v.view(args.num_query_groups, -1, head_dim)
viewed = viewed[group_per_split * tp_rank: group_per_split * (tp_rank + 1)]
target_v = viewed.view(-1)
elif 'linear_proj' in k or 'linear_fc2' in k:
seg = v.shape[1] // args.target_tensor_model_parallel_size
target_v = v[:, seg * tp_rank: seg * (tp_rank + 1)]
elif 'embedding' in k or 'output_layer' in k:
seg = v.shape[0] // args.target_tensor_model_parallel_size
target_v = v[seg * tp_rank: seg * (tp_rank + 1)]
elif 'linear_fc1' in k and 'norm' not in k:
viewed = v.view(-1, args.ffn_hidden_size, args.hidden_size)
seg = args.ffn_hidden_size // args.target_tensor_model_parallel_size
target_v = viewed[:, seg * tp_rank: seg * (tp_rank + 1), :].reshape(-1, args.hidden_size)
else:
target_v = v
model_split[k] = target_v
save_state_dict(args, model_split, checkpoint_name)
elif (
args.target_tensor_model_parallel_size > 1
and args.target_pipeline_model_parallel_size == 1
and args.num_experts
and args.num_experts % args.target_expert_model_parallel_size == 0
):
for tp_rank in range(args.target_tensor_model_parallel_size):
for ep_rank in range(args.target_expert_model_parallel_size):
model_split = {}
checkpoint_name = get_checkpoint_name(save_path, 0, True, None, tp_rank, None, True, ep_rank)
for k, v in full_model.items():
if not isinstance(v, torch.Tensor):
target_v = v
elif 'linear_qkv.weight' in k and 'norm' not in k:
viewed = v.view(args.num_query_groups, -1, head_dim, args.hidden_size)
viewed = viewed[group_per_split * tp_rank: group_per_split * (tp_rank + 1)]
target_v = viewed.view(-1, args.hidden_size)
elif 'linear_qkv.bias' in k and 'norm' not in k:
viewed = v.view(args.num_query_groups, -1, head_dim)
viewed = viewed[group_per_split * tp_rank: group_per_split * (tp_rank + 1)]
target_v = viewed.view(-1)
elif 'linear_proj' in k:
seg = v.shape[1] // args.target_tensor_model_parallel_size
target_v = v[:, seg * tp_rank: seg * (tp_rank + 1)]
elif 'embedding' in k or 'output_layer' in k:
seg = v.shape[0] // args.target_tensor_model_parallel_size
target_v = v[seg * tp_rank: seg * (tp_rank + 1)]
elif 'local_experts' in k:
expert_rank = int(re.findall(pattern, k)[0])
if expert_rank // num_local_experts != ep_rank:
continue
expert_local_rank = expert_rank % num_local_experts
if 'linear_fc1' in k and 'norm' not in k:
viewed = v.view(-1, args.ffn_hidden_size, args.hidden_size)
seg = args.ffn_hidden_size // args.target_tensor_model_parallel_size
target_v = viewed[:, seg * tp_rank: seg * (tp_rank + 1), :].reshape(-1, args.hidden_size)
elif 'linear_fc2' in k:
seg = v.shape[1] // args.target_tensor_model_parallel_size
target_v = v[:, seg * tp_rank: seg * (tp_rank + 1)]
k = k.replace(f'local_experts.{expert_rank}', f'local_experts.{expert_local_rank}')
else:
target_v = v
model_split[k] = target_v
save_state_dict(args, model_split, checkpoint_name)
else:
        raise ValueError('pipeline-parallel conversion is not supported yet')
    print(f'megatron model is saved to {save_path}')
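# Save the HF model, sharding the state dict via transformers' shard_checkpoint
# and writing an index file when more than one shard is produced.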
def save_hgmodel(args, model):
output_state_dict = model.state_dict()
max_shard_size = args.max_shard_size
shards, index = shard_checkpoint(output_state_dict, max_shard_size=max_shard_size)
os.makedirs(args.save_path, exist_ok=True)
for shard_file, shard in shards.items():
target_file = os.path.join(args.save_path, shard_file)
        print(f'saving huggingface model shard to {target_file}')
torch.save(clone_state_dict(shard), target_file)
if index is None:
print(f"Model weights saved in {os.path.join(args.save_path, WEIGHTS_NAME)}")
else:
save_index_file = os.path.join(args.save_path, WEIGHTS_INDEX_NAME)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
print(
f"The model is bigger than the maximum size per checkpoint ({args.max_shard_size}) and is going to be "
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
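# Debug utility: register forward hooks on matched submodules of both models,
# run the same tiny prompt through each, and report elementwise activation
# differences layer by layer to locate where the implementations diverge.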
def check_mg_eg_forward(mgmodel, hgmodel, mgargs):
hg_hiddens = [{} for _ in range(mgargs.num_layers)]
mg_hiddens = [{} for _ in range(mgargs.num_layers)]
head_dim = mgargs.hidden_size // mgargs.num_attention_heads
hidden_size = mgargs.hidden_size
def print_input_hook(module, args, kwargs, layer_idx, mode):
frame, name = mode.split('-')
if frame == 'hg':
hg_hiddens[layer_idx][name] = args[0].transpose(0, 1)
elif frame == 'mg' and 'layer' in mode:
mg_hiddens[layer_idx][name] = kwargs.get('hidden_states')
elif frame == 'mg':
mg_hiddens[layer_idx][name] = args[0]
def print_output_hook(module, args, kwargs, output, layer_idx, mode):
frame, name = mode.split('-')
if mode in ['hg-q_proj_out', 'hg-k_proj_out', 'hg-v_proj_out']:
hg_hiddens[layer_idx][name] = output
hg_hiddens[layer_idx][name + '_weight'] = module.weight
elif mode in ['hg-lmhead']:
hg_hiddens[layer_idx][name] = output.transpose(0, 1)
hg_hiddens[layer_idx][name + '_token'] = output.transpose(0, 1).max(dim=-1)[1]
print(output.transpose(0, 1).max(dim=-1))
elif mode == 'hg-attn_out':
hg_hiddens[layer_idx][name] = output[0].transpose(0, 1)
elif mode in ['mg-lmhead']:
mg_hiddens[layer_idx][name] = output[0]
mg_hiddens[layer_idx][name + '_token'] = output[0].max(dim=-1)[1]
print(output[0].max(dim=-1))
elif mode == 'mg-attn_out':
mg_hiddens[layer_idx][name] = output[0]
elif mode == 'mg-qkv':
mixed_qkv = output[0]
sq, b, _ = mixed_qkv.shape
mixed_qkv = mixed_qkv.view(sq, b, mgargs.num_query_groups, -1)
qh = mgargs.num_attention_heads // mgargs.num_query_groups
qo, ko, vo = torch.split(mixed_qkv, [qh * head_dim, head_dim, head_dim], dim=3)
qo = qo.reshape(b, -1, hidden_size)
ko = ko.reshape(b, -1, hidden_size // qh)
vo = vo.reshape(b, -1, hidden_size // qh)
mg_hiddens[layer_idx]['q_proj_out'] = qo
mg_hiddens[layer_idx]['k_proj_out'] = ko
mg_hiddens[layer_idx]['v_proj_out'] = vo
weight = module.weight.view(mgargs.num_query_groups, -1, head_dim, hidden_size)
qw, kw, vw = weight.split([qh, 1, 1], dim=1)
mg_hiddens[layer_idx]['q_proj_out_weight'] = qw.reshape(-1, hidden_size)
mg_hiddens[layer_idx]['k_proj_out_weight'] = kw.reshape(-1, hidden_size // qh)
mg_hiddens[layer_idx]['v_proj_out_weight'] = vw.reshape(-1, hidden_size // qh)
hgmodel.lm_head.register_forward_hook(partial(print_output_hook, layer_idx=mgargs.num_layers - 1, mode='hg-lmhead'),
with_kwargs=True)
mgmodel.output_layer.register_forward_hook(
partial(print_output_hook, layer_idx=mgargs.num_layers - 1, mode='mg-lmhead'), with_kwargs=True)
for idx, layer in enumerate(hgmodel.model.layers):
layer.register_forward_pre_hook(partial(print_input_hook, layer_idx=idx, mode='hg-layer_in'), with_kwargs=True)
layer.self_attn.o_proj.register_forward_pre_hook(partial(print_input_hook, layer_idx=idx, mode='hg-o_proj_in'),
with_kwargs=True)
layer.self_attn.q_proj.register_forward_hook(partial(print_output_hook, layer_idx=idx, mode='hg-q_proj_out'),
with_kwargs=True)
layer.self_attn.k_proj.register_forward_hook(partial(print_output_hook, layer_idx=idx, mode='hg-k_proj_out'),
with_kwargs=True)
layer.self_attn.v_proj.register_forward_hook(partial(print_output_hook, layer_idx=idx, mode='hg-v_proj_out'),
with_kwargs=True)
layer.self_attn.register_forward_hook(partial(print_output_hook, layer_idx=idx, mode='hg-attn_out'),
with_kwargs=True)
for idx, layer in enumerate(mgmodel.decoder.layers):
layer.register_forward_pre_hook(partial(print_input_hook, layer_idx=idx, mode='mg-layer_in'), with_kwargs=True)
layer.self_attention.linear_qkv.register_forward_hook(partial(print_output_hook, layer_idx=idx, mode='mg-qkv'),
with_kwargs=True)
layer.self_attention.linear_proj.register_forward_pre_hook(
partial(print_input_hook, layer_idx=idx, mode='mg-o_proj_in'), with_kwargs=True)
layer.self_attention.register_forward_hook(partial(print_output_hook, layer_idx=idx, mode='mg-attn_out'),
with_kwargs=True)
input_ids = torch.tensor([[1, 2, 3]]).long().cuda()
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(input_ids, -100, True, True, True)
with torch.inference_mode():
try:
hgmodel.cuda()
hgmodel(input_ids=input_ids)
except torch.cuda.OutOfMemoryError:
print('oom for huggingface model forward')
hgmodel.cpu()
del hgmodel
with torch.inference_mode():
try:
mgmodel.cuda()
mgmodel(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
except torch.cuda.OutOfMemoryError:
print('oom for megatron model forward')
mgmodel.cpu()
del mgmodel
    epsilon = 1e-5
    for idx, (hgh, mgh) in enumerate(zip(hg_hiddens, mg_hiddens)):
        if len(hgh) != len(mgh):
            continue
        for k, hgv in hgh.items():
            mgv, hgv = mgh[k].cpu(), hgv.cpu()
            # Count exact mismatches and mismatches beyond the tolerance; use
            # abs() so negative deviations are counted too.
            diff_num = (hgv != mgv).sum()
            diff_gt_eps = ((hgv - mgv).abs() > epsilon).sum()
            diff_max = (hgv - mgv).abs().max()
            print(f'layer:{idx}, {k}, diff:{diff_num}, diff>{epsilon}:[{diff_gt_eps}/{hgv.numel()}] diff_max:{diff_max}')
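# Sanity-check that both tokenizers render the same chat template output.
# apply_chat_template needs a recent transformers release, hence the version
# guard below.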
def check_tokenizer_is_same(hgtokenizer, mgtokenizer):
    if transformers.__version__ <= '4.33.2':
        print('transformers is too old for apply_chat_template; please update transformers')
        return
if mgtokenizer is None:
return
conversation = [
{"role": "user", "content": "what's your name"},
{"role": "bot", "content": "cold"},
]
hgres = hgtokenizer.apply_chat_template(conversation)
mgres = mgtokenizer.apply_chat_template(conversation)
for x, y in zip(hgres, mgres):
assert x == y, 'tokenizer is different for huggingface and megatron'
def add_ckpt_args(parser):
parser = get_patch_args(parser)
parser = add_checkpointing_args(parser)
parser = add_megatron_checkpoint_args(parser)
parser = add_transformers_checkpoint_args(parser)
return parser
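# Illustrative invocation (paths are hypothetical; the exact set of required
# Megatron/model flags depends on your megatron-patch setup):
#
#   python hf2mcore.py \
#       --load_path /path/to/hf_llama \
#       --save_path /path/to/mcore_ckpt \
#       --huggingface_model_path /path/to/hf_llama \
#       --target_tensor_model_parallel_size 2 \
#       --target_pipeline_model_parallel_size 1 \
#       ... (plus the usual Megatron initialization arguments)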
def main():
initialize_megatron(extra_args_provider=add_ckpt_args)
args = get_args()
hf_config, hf_model = create_huggingface_model(args)
mg_model = create_megatron_model(args, hf_config)
if args.convert_checkpoint_from_megatron_to_transformers:
load_megatron_model(args, mg_model)
convert_checkpoint_from_megatron_to_transformers(mg_model, hf_model, args)
# check_mg_eg_forward(mg_model, hf_model, args)
save_hgmodel(args, hf_model)
else:
        # from_pretrained is a classmethod that returns a new model; assign the
        # result, otherwise the weights loaded from load_path are discarded.
        hf_model = AutoModelForCausalLM.from_pretrained(
            args.load_path, torch_dtype=args.params_dtype, trust_remote_code=True
        ).eval()
convert_checkpoint_from_transformers_to_megatron(mg_model, hf_model, args, hf_config)
# check_mg_eg_forward(mg_model, hf_model, args)
save_mgmodel(args, mg_model, args.load_path, args.save_path)
if __name__ == "__main__":
main()