# toolkits/model_checkpoints_convertor/qwen/hf2mcore_qwen2_moe.py
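# Module-level dependencies assumed by this function (stdlib/third-party shown
# explicitly; GPTModel and the checkpoint helpers -- get_checkpoint_name,
# save_state_dict, check_layer, generate_rank_group, split_row_parallel -- are
# imported or defined elsewhere in this module):
import os
import re

import torch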
def save_mgmodel(mgmodel: GPTModel, args):
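    """Split a full Megatron-Core model state dict into TP/ETP/EP/PP shards.

    Copies the Hugging Face config/tokenizer files into args.save, writes the
    checkpoint tracker file, then writes one Megatron checkpoint shard per
    model-parallel rank according to the target parallel sizes in args.
    """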
    args.tensor_model_parallel_size = args.target_tensor_model_parallel_size
    args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size
    args.expert_model_parallel_size = args.target_expert_model_parallel_size
    args.expert_tensor_parallel_size = args.target_expert_tensor_parallel_size
    if args.num_experts is not None:
        args.expert_model_parallel_size = args.target_expert_model_parallel_size
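    # Copy the Hugging Face config and tokenizer assets next to the Megatron
    # shards so the converted checkpoint directory is self-contained.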
    os.makedirs(args.save, exist_ok=True)
    os.system("cp -rf " + args.load + "/*config.json " + args.save)
    os.system("cp -rf " + args.load + "/tokenizer* " + args.save)
    os.system("cp -rf " + args.load + "/merges.txt " + args.save)
    os.system("cp -rf " + args.load + "/vocab.json " + args.save)

    tracker_filepath = os.path.join(args.save, 'latest_checkpointed_iteration.txt')
    with open(tracker_filepath, "w") as f:
        f.write("release")
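    # head_dim falls back to hidden_size / num_attention_heads when kv_channels is
    # unset; group_per_split is the number of KV head groups each tensor-parallel
    # rank keeps when the fused QKV weights are split below.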
    head_dim = args.hidden_size // args.num_attention_heads if args.kv_channels is None else args.kv_channels
    group_per_split = args.num_query_groups // args.target_tensor_model_parallel_size
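    # Prune the full state dict before sharding: drop None-valued entries
    # (except *_extra_state keys, which are kept as-is) and replace tensor-valued
    # *_extra_state entries with None so they are not sliced below.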
    full_model = mgmodel.state_dict_for_save_checkpoint()
    for k in list(full_model.keys()):
        if full_model[k] is None and '_extra_state' not in k:
            full_model.pop(k)
            continue
        if '_extra_state' in k and isinstance(full_model[k], torch.Tensor):
            full_model[k] = None
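    # Routed-expert weights are stored as ...weight<expert_id>; the regex below
    # recovers the global expert id so each expert can be routed to the
    # expert-parallel rank that owns it (num_local_experts experts per EP rank).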
    if args.num_experts is not None:
        pattern = r'weight(\d+)'
        assert args.num_experts % args.expert_model_parallel_size == 0
        num_local_experts = args.num_experts // args.expert_model_parallel_size if args.num_experts else 0
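    # Work out how many decoder layers each pipeline stage holds. When
    # target_decoder_first_pipeline_num_layers is set, the first stage gets that
    # many layers and the rest are divided evenly over the remaining stages;
    # otherwise every stage gets num_layers / pipeline_model_parallel_size.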
    if args.target_decoder_first_pipeline_num_layers is not None:
        remained_layers = args.num_layers - args.target_decoder_first_pipeline_num_layers
        remained_stages = args.pipeline_model_parallel_size - 1
        assert remained_layers % remained_stages == 0
        pp_layers_per_stage = [args.target_decoder_first_pipeline_num_layers] + ([remained_layers // remained_stages] * remained_stages)
    else:
        pp_layers_per_stage = [args.num_layers // args.pipeline_model_parallel_size] * args.pipeline_model_parallel_size
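    # Iterate over the model-parallel rank coordinates; generate_rank_group is
    # assumed to yield one (tp_rank, etp_rank, ep_rank, pp_rank) tuple per shard
    # that needs to be written.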
    for (tp_rank, etp_rank, ep_rank, pp_rank) in generate_rank_group(
        args.tensor_model_parallel_size,
        args.expert_tensor_parallel_size,
        args.expert_model_parallel_size,
        args.pipeline_model_parallel_size
    ):
        model_split = {}
        # Map global decoder layer names to stage-local indices for this pp_rank.
        layer_offset = sum(pp_layers_per_stage[:pp_rank])
        layers_to_copy = {}
        for layer in range(pp_layers_per_stage[pp_rank]):
            pp_layer_id = layer + layer_offset
            layers_to_copy[f"decoder.layers.{pp_layer_id}"] = layer
        checkpoint_name = get_checkpoint_name(
            args.save, 0, True,
            args.pipeline_model_parallel_size > 1,
            tp_rank,
            pp_rank,
            args.expert_model_parallel_size > 1,
            ep_rank
        )
        print(f'tensor_parallel & pipeline_parallel & expert_parallel, saving model shard to {checkpoint_name}')
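        # Walk every entry of the full state dict and decide, per key, whether this
        # rank owns it and which slice of the tensor it should keep.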
        for k, v in full_model.items():
            if check_layer(layers_to_copy, k):
                # Renumber decoder.layers.<global_id> to the stage-local index.
                layer_pattern = re.compile(r'\d+')
                res = layer_pattern.findall(k)
                k = re.sub(r"decoder.layers.\d+", "decoder.layers." + str(layers_to_copy["decoder.layers." + res[0]]), k)
            elif not ("word_embeddings" in k or "output_layer" in k or "final_layernorm" in k):
                # Keys belonging to other pipeline stages (and not one of the shared
                # global weights) are skipped on this rank.
                continue
            if not isinstance(v, torch.Tensor):
                target_v = v
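            # Fused QKV split, assuming the standard Megatron-Core grouped layout
            # [num_query_groups, q_heads_per_group + 2, head_dim, hidden_size]:
            # each TP rank keeps group_per_split contiguous query groups.
            # Illustrative numbers (not tied to a specific model): with
            # num_attention_heads=32, num_query_groups=8, head_dim=128,
            # hidden_size=4096 and TP=2, the [6144, 4096] fused weight is viewed as
            # [8, 6, 128, 4096] and each rank keeps 4 groups, i.e. [3072, 4096].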
            elif 'linear_qkv.weight' in k:
                viewed = v.view(args.num_query_groups, -1, head_dim, args.hidden_size)
                viewed = viewed[group_per_split * tp_rank: group_per_split * (tp_rank + 1)]
                target_v = viewed.view(-1, args.hidden_size)
            elif 'linear_qkv.bias' in k:
                viewed = v.view(args.num_query_groups, -1, head_dim)
                viewed = viewed[group_per_split * tp_rank: group_per_split * (tp_rank + 1)]
                target_v = viewed.view(-1)
            elif 'linear_proj' in k:
                # Row-parallel attention output projection: split along the input dim.
                seg = v.shape[1] // args.tensor_model_parallel_size
                target_v = v[:, seg * tp_rank: seg * (tp_rank + 1)]
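            # Routed experts: keep only the experts owned by this expert-parallel
            # rank, renumber them to local ids, then split fc1 column-parallel and
            # fc2 row-parallel across the expert-tensor-parallel ranks.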
            elif 'experts' in k and 'shared_experts' not in k:
                expert_rank = int(re.findall(pattern, k)[0])
                if expert_rank // num_local_experts != ep_rank:
                    continue
                expert_local_rank = expert_rank % num_local_experts
                k = k.replace(f'weight{expert_rank}', f'weight{expert_local_rank}')
                if 'linear_fc1' in k:
                    # View as (chunks, moe_ffn_hidden_size, hidden_size) and slice the ffn dim per ETP rank.
                    viewed = v.view(-1, args.moe_ffn_hidden_size, args.hidden_size)
                    seg = args.moe_ffn_hidden_size // args.expert_tensor_parallel_size
                    target_v = viewed[:, seg * etp_rank: seg * (etp_rank + 1), :].reshape(-1, args.hidden_size)
                elif 'linear_fc2' in k:
                    target_v = split_row_parallel(v, etp_rank, args.expert_tensor_parallel_size)
                else:
                    raise NotImplementedError
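            # Shared experts follow the same fc1/fc2 split, but over the regular
            # tensor-parallel ranks and moe_shared_expert_intermediate_size; the
            # shared-expert gate (if any) falls through to the default branch below.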
            elif 'shared_experts' in k and 'gate' not in k:
                if 'linear_fc1' in k:
                    viewed = v.view(-1, args.moe_shared_expert_intermediate_size, args.hidden_size)
                    seg = args.moe_shared_expert_intermediate_size // args.tensor_model_parallel_size
                    target_v = viewed[:, seg * tp_rank: seg * (tp_rank + 1), :].reshape(-1, args.hidden_size)
                elif 'linear_fc2' in k:
                    seg = v.shape[1] // args.tensor_model_parallel_size
                    target_v = v[:, seg * tp_rank: seg * (tp_rank + 1)]
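            # Word embeddings and the output layer are vocab-parallel: split along
            # the vocabulary (row) dimension per TP rank. Remaining keys
            # (layernorms, router weights, etc.) are replicated on every rank.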
elif "word_embeddings" in k or "output_layer" in k:
seg = v.shape[0] // args.tensor_model_parallel_size
target_v = v[seg * tp_rank: seg * (tp_rank + 1)]
else:
target_v = v
if "word_embeddings" in k:
if pp_rank == 0:
model_split[k] = target_v
elif "output_layer" in k or "final_layernorm" in k:
if pp_rank == args.pipeline_model_parallel_size - 1:
model_split[k] = target_v
else:
model_split[k] = target_v
        save_state_dict(args, [model_split], checkpoint_name)

    print(f'megatron model is saved to {args.save}')
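
# Note: the resulting args.save directory is expected to look like a standard
# Megatron 'release' checkpoint: latest_checkpointed_iteration.txt containing
# "release", one shard directory per (tp, pp, ep) rank as produced by
# get_checkpoint_name, plus the copied Hugging Face config/tokenizer files.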