# toolkits/model_checkpoints_convertor/qwen/hf2mcore_qwen1.5_moe.py
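import os
import re

import torch

# NOTE: `get_checkpoint_name` (a Megatron checkpointing utility) and `save_state_dict`
# are assumed to be imported/defined elsewhere in this file; only `save_mgmodel` is shown here.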
def save_mgmodel(mgmodel, args):
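    """Split the full Megatron-Core state dict along TP/EP and write one shard per rank under args.save."""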
    args.tensor_model_parallel_size = args.target_tensor_model_parallel_size
    args.expert_model_parallel_size = args.target_expert_model_parallel_size
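    # Copy the HF config/tokenizer assets so the converted checkpoint directory is self-contained.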
    os.makedirs(args.save, exist_ok=True)
    os.system("cp -rf " + args.load + "/config*.json " + args.save)
    os.system("cp -rf " + args.load + "/tokenizer* " + args.save)
    os.system("cp -rf " + args.load + "/vocab.json " + args.save)
    os.system("cp -rf " + args.load + "/merges.txt " + args.save)
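    # Megatron reads this tracker file to pick the checkpoint to load; "release" marks a converted, non-training checkpoint.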
    tracker_filepath = os.path.join(args.save, 'latest_checkpointed_iteration.txt')
    with open(tracker_filepath, "w") as f:
        f.write("release")
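    # Attention heads are split evenly across TP ranks; group_per_split is the number of heads each TP shard keeps.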
    head_dim = args.hidden_size // args.num_attention_heads
    group_per_split = args.num_attention_heads // args.tensor_model_parallel_size
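    # Drop empty entries and serialized `_extra_state` blobs before sharding.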
    full_model = mgmodel.state_dict_for_save_checkpoint()
    for k in list(full_model.keys()):
        if full_model[k] is None or "_extra_state" in k:
            full_model.pop(k)
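    # The global expert index is encoded in the parameter name, e.g. "...local_experts.17.linear_fc1.weight".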
    pattern = r'local_experts\.(\d+)\.'
    num_local_experts = args.num_experts // args.expert_model_parallel_size if args.num_experts else 0
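    # Case 1: no model parallelism -- save the full state dict as a single shard.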
    if (
        args.tensor_model_parallel_size == 1
        and args.pipeline_model_parallel_size == 1
        and args.expert_model_parallel_size == 1
    ):
        checkpoint_name = get_checkpoint_name(args.save, 0, True)
        save_state_dict(args, full_model, checkpoint_name)
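    # Case 2: expert parallelism only -- each EP rank keeps num_local_experts experts, renumbered to local indices.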
    elif (
        args.tensor_model_parallel_size == 1
        and args.pipeline_model_parallel_size == 1
        and args.expert_model_parallel_size > 1
        and args.num_experts % args.expert_model_parallel_size == 0
    ):
        for ep_rank in range(args.expert_model_parallel_size):
            model_split = {}
            checkpoint_name = get_checkpoint_name(args.save, 0, True, None, None, None, True, ep_rank)
            print(f'save ep_rank {ep_rank} model to {checkpoint_name}')
            for k, v in full_model.items():
                if 'local_experts' in k:
                    expert_rank = int(re.findall(pattern, k)[0])
                    if expert_rank // num_local_experts != ep_rank:
                        continue
                    expert_local_rank = expert_rank % num_local_experts
                    k = k.replace(f'local_experts.{expert_rank}', f'local_experts.{expert_local_rank}')
                model_split[k] = v
            save_state_dict(args, model_split, checkpoint_name)
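    # Case 3: tensor parallelism (optionally combined with expert parallelism) -- slice every tensor along its TP dimension.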
    elif (
        args.tensor_model_parallel_size > 1
        and args.pipeline_model_parallel_size == 1
        and args.num_experts % args.expert_model_parallel_size == 0
    ):
        for tp_rank in range(args.tensor_model_parallel_size):
            for ep_rank in range(args.expert_model_parallel_size):
                model_split = {}
                if args.expert_model_parallel_size > 1:
                    checkpoint_name = get_checkpoint_name(args.save, 0, True, None, tp_rank, None, True, ep_rank)
                elif args.expert_model_parallel_size == 1:
                    checkpoint_name = get_checkpoint_name(args.save, 0, True, None, tp_rank, None, False)
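                # Walk the full state dict and build the shard for this (tp_rank, ep_rank);
                # non-tensor entries are copied verbatim to every shard.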
                for k, v in full_model.items():
                    if not isinstance(v, torch.Tensor):
                        target_v = v
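                    # Column-parallel fused QKV: view as one group per attention head and keep this rank's contiguous head slice.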
                    elif 'linear_qkv.weight' in k and 'norm' not in k:
                        viewed = v.view(args.num_attention_heads, -1, head_dim, args.hidden_size)
                        viewed = viewed[group_per_split*tp_rank : group_per_split*(tp_rank + 1)]
                        target_v = viewed.view(-1, args.hidden_size)
                    elif 'linear_qkv.bias' in k and 'norm' not in k:
                        viewed = v.view(args.num_attention_heads, -1, head_dim)
                        viewed = viewed[group_per_split*tp_rank : group_per_split*(tp_rank + 1)]
                        target_v = viewed.view(-1)
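                    # Row-parallel attention output projection is split along its input (second) dim;
                    # the embedding and output layer are split along the vocab (first) dim.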
                    elif 'linear_proj' in k:
                        seg = v.shape[1] // args.tensor_model_parallel_size
                        target_v = v[:, seg*tp_rank : seg*(tp_rank + 1)]
                    elif 'embedding' in k or 'output_layer' in k:
                        seg = v.shape[0] // args.tensor_model_parallel_size
                        target_v = v[seg*tp_rank : seg*(tp_rank + 1)]
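                    # Routed experts: keep only the experts owned by this ep_rank and renumber them to local
                    # indices. fc1 fuses the gate and up projections, so it is viewed as (2, moe_ffn, hidden)
                    # and each half is sliced along the FFN dim; fc2 is row-parallel and sliced along its input dim.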
                    elif 'local_experts' in k:
                        expert_rank = int(re.findall(pattern, k)[0])
                        if expert_rank // num_local_experts != ep_rank:
                            continue
                        expert_local_rank = expert_rank % num_local_experts
                        if 'linear_fc1' in k and 'norm' not in k:
                            viewed = v.view(-1, args.moe_ffn_hidden_size, args.hidden_size)
                            seg = args.moe_ffn_hidden_size // args.tensor_model_parallel_size
                            target_v = viewed[:, seg*tp_rank: seg*(tp_rank+1), :].reshape(-1, args.hidden_size)
                        elif 'linear_fc2' in k:
                            seg = v.shape[1] // args.tensor_model_parallel_size
                            target_v = v[:, seg*tp_rank : seg*(tp_rank + 1)]
                        k = k.replace(f'local_experts.{expert_rank}', f'local_experts.{expert_local_rank}')
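                    # Shared expert: same fc1/fc2 slicing with its own FFN width; the shared-expert
                    # gate is excluded here and falls through to the else branch, i.e. it is replicated on every TP rank.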
                    elif 'shared_expert' in k and 'gate' not in k:
                        if 'linear_fc1' in k:
                            viewed = v.view(-1, args.shared_moe_ffn_hidden_size, args.hidden_size)
                            seg = args.shared_moe_ffn_hidden_size // args.tensor_model_parallel_size
                            target_v = viewed[:, seg*tp_rank: seg*(tp_rank+1), :].reshape(-1, args.hidden_size)
                        elif 'linear_fc2' in k:
                            seg = v.shape[1] // args.tensor_model_parallel_size
                            target_v = v[:, seg*tp_rank : seg*(tp_rank + 1)]
                    else:
                        target_v = v
                    model_split[k] = target_v
                save_state_dict(args, model_split, checkpoint_name)
    else:
        raise ValueError('Conversion with pipeline parallelism (pipeline_model_parallel_size > 1) is not supported')
    print(f'megatron model is saved to {args.save}')