# toolkits/model_checkpoints_convertor/qwen/hf2mcore_qwen2_vl.py
def save_mgmodel(mgmodel, args):
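    """Shard the full Megatron-Core state dict by TP/PP (and optional VPP) rank
    and save one checkpoint file per rank under ``args.save``, along with the
    ``latest_checkpointed_iteration.txt`` tracker file."""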
args.tensor_model_parallel_size = args.target_tensor_model_parallel_size
args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size
vpp_size = 1 # NOTE: vpp_size=1 if vpp is not used
if args.target_num_layers_per_virtual_pipeline_stage is not None:
args.num_layers_per_virtual_pipeline_stage = args.target_num_layers_per_virtual_pipeline_stage
num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \
args.num_layers_per_virtual_pipeline_stage
vpp_size = args.virtual_pipeline_model_parallel_size
os.makedirs(args.save, exist_ok=True)
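    # Write the tracker file Megatron reads on load; "release" marks a
    # release-style checkpoint rather than a numbered training iteration.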
tracker_filepath = os.path.join(args.save, 'latest_checkpointed_iteration.txt')
with open(tracker_filepath, "w") as f:
f.write("release")
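    # Per-attention-head dimension and the number of query (KV) groups each
    # TP rank owns; both drive the fused-QKV sharding below.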
head_dim = args.hidden_size // args.num_attention_heads
group_per_split = args.num_query_groups // args.target_tensor_model_parallel_size
full_model = mgmodel.state_dict_for_save_checkpoint()
for k in list(full_model.keys()):
if 'extra_state' in k:
            # NOTE: since TE 1.14, fp8 metadata is saved as tensors.
            # Always drop these values from the MG ckpt to avoid potential issues.
            # This is safe because fp8 metadata is not supported by the HF ckpt.
full_model[k] = None
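        # Drop keys that are already None so they do not end up in the shards.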
elif full_model[k] is None:
full_model.pop(k)
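    # Three save layouts: a single file (TP=1, PP=1), TP-only shards, and
    # TP x PP (optionally virtual-PP) shards.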
if (
args.tensor_model_parallel_size == 1
and args.pipeline_model_parallel_size == 1
):
checkpoint_name = get_checkpoint_name(args.save, 0, True)
save_state_dict(args, [full_model], checkpoint_name)
elif (
args.tensor_model_parallel_size > 1
and args.pipeline_model_parallel_size == 1
):
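        # Shard the vision tower per TP rank; vision_state_dicts is keyed by
        # (tp_rank, pp_rank), with the whole vision tower on pp stage 0.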
vision_state_dicts = split_vision_model(mgmodel.vision_model, args)
for tp_rank in range(args.tensor_model_parallel_size):
model_part = {}
checkpoint_name = get_checkpoint_name(args.save, 0, True, None, tp_rank)
            print(f'tensor_parallel, saving model to {checkpoint_name}')
for k, v in full_model.items():
if not isinstance(v, torch.Tensor):
target_v = v
elif 'vision_model' in k:
vision_part = vision_state_dicts[(tp_rank, 0)]
assert k in vision_part, f"Cannot find key {k} in vision model split!"
target_v = vision_part[k]
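                # Fused QKV is grouped by query group; take this rank's
                # contiguous slice of groups so Q/K/V stay aligned.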
elif 'linear_qkv.weight' in k:
viewed = v.view(args.num_query_groups, -1, head_dim, args.hidden_size)
viewed = viewed[group_per_split*tp_rank : group_per_split*(tp_rank + 1)]
target_v = viewed.view(-1, args.hidden_size)
elif 'linear_qkv.bias' in k:
viewed = v.view(args.num_query_groups, -1, head_dim)
viewed = viewed[group_per_split * tp_rank: group_per_split * (tp_rank + 1)]
target_v = viewed.view(-1)
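                # Row-parallel weights (attention out-proj, MLP down-proj)
                # are split along the input (second) dimension.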
elif 'linear_proj' in k or 'linear_fc2' in k:
seg = v.shape[1] // args.tensor_model_parallel_size
target_v = v[:, seg*tp_rank : seg*(tp_rank + 1)]
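                # Embeddings and the output layer are split along the vocab
                # (first) dimension.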
elif 'embedding' in k or 'output_layer' in k:
seg = v.shape[0] // args.tensor_model_parallel_size
target_v = v[seg*tp_rank : seg*(tp_rank + 1)]
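                # linear_fc1 may fuse the gate and up projections (SwiGLU);
                # view as (-1, ffn_hidden_size, hidden_size) so each fused
                # half is sharded along ffn_hidden_size.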
elif 'linear_fc1' in k and 'norm' not in k:
viewed = v.view(-1, args.ffn_hidden_size, args.hidden_size)
seg = args.ffn_hidden_size // args.tensor_model_parallel_size
target_v = viewed[:, seg*tp_rank: seg*(tp_rank+1), :].reshape(-1, args.hidden_size)
else:
target_v = v
model_part[k] = target_v
save_state_dict(args, [model_part], checkpoint_name)
elif (
args.pipeline_model_parallel_size > 1
):
vision_state_dicts = split_vision_model(mgmodel.vision_model, args)
ltog, _ = build_layer_id_mapping(args)
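        # ltog maps (pp_rank, vpp_id, local_layer_id) -> global layer id.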
for tp_rank in range(args.tensor_model_parallel_size):
for pp_rank in range(args.pipeline_model_parallel_size):
model_chunk = []
checkpoint_name = get_checkpoint_name(args.save, 0, True, True, tp_rank, pp_rank)
                print(f'tensor_parallel & pipeline_parallel, saving model to {checkpoint_name}')
for vpp_id in range(vpp_size):
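                    # Gather the global layer ids owned by this (pp, vpp)
                    # chunk, remembering each one's local index for renaming.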
layers_to_copy = {}
local_id = 0
while (pp_rank, vpp_id, local_id) in ltog:
                        global_layer_id = ltog[(pp_rank, vpp_id, local_id)]
                        layers_to_copy[global_layer_id] = local_id
local_id += 1
model_part = {}
for k, v in full_model.items():
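                        # Keep only this chunk's decoder layers (renamed to
                        # local indices) plus the shared embedding/head/vision
                        # weights; everything else belongs to another stage.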
if check_layer(layers_to_copy, k):
pattern = re.compile(r'\d+')
res = pattern.findall(k)
k = re.sub(r"decoder.layers.\d+", f"decoder.layers.{layers_to_copy[int(res[0])]}", k)
elif not ("word_embeddings" in k or "output_layer" in k or "final_layernorm" in k or 'vision_model' in k):
continue
if 'vision_model' in k:
if pp_rank > 0 or vpp_id > 0:
# NOTE: The vision model will only be attached to the first model_part of pp stage 0
continue
vision_part = vision_state_dicts[(tp_rank, 0)]
assert k in vision_part, f"Cannot find key {k} in vision model split!"
target_v = vision_part[k]
elif not isinstance(v, torch.Tensor):
target_v = v
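                            # The TP sharding below mirrors the TP-only branch above.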
elif 'linear_qkv.weight' in k:
viewed = v.view(args.num_query_groups, -1, head_dim, args.hidden_size)
viewed = viewed[group_per_split*tp_rank : group_per_split*(tp_rank + 1)]
target_v = viewed.view(-1, args.hidden_size)
elif 'linear_qkv.bias' in k:
viewed = v.view(args.num_query_groups, -1, head_dim)
viewed = viewed[group_per_split * tp_rank: group_per_split * (tp_rank + 1)]
target_v = viewed.view(-1)
elif 'linear_proj' in k or 'linear_fc2' in k:
seg = v.shape[1] // args.tensor_model_parallel_size
target_v = v[:, seg*tp_rank : seg*(tp_rank + 1)]
elif 'embedding' in k or 'output_layer' in k:
seg = v.shape[0] // args.tensor_model_parallel_size
target_v = v[seg*tp_rank : seg*(tp_rank + 1)]
elif 'linear_fc1' in k and 'norm' not in k:
viewed = v.view(-1, args.ffn_hidden_size, args.hidden_size)
seg = args.ffn_hidden_size // args.tensor_model_parallel_size
target_v = viewed[:, seg*tp_rank: seg*(tp_rank+1), :].reshape(-1, args.hidden_size)
else:
target_v = v
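                            # Word embeddings go only to the first chunk of pp
                            # stage 0; the output layer and final layernorm only
                            # to the last chunk of the last stage.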
if "word_embeddings" in k:
if pp_rank == 0 and vpp_id == 0:
model_part[k] = target_v
elif 'vision_model' not in k and ("output_layer" in k or "final_layernorm" in k):
if pp_rank == args.pipeline_model_parallel_size - 1 and vpp_id == vpp_size - 1:
model_part[k] = target_v
else:
model_part[k] = target_v
model_chunk.append(model_part)
save_state_dict(args, model_chunk, checkpoint_name, args.target_num_layers_per_virtual_pipeline_stage is not None)
else:
raise ValueError(f'Got invalid TP/PP: {args.tensor_model_parallel_size}/{args.pipeline_model_parallel_size}')
    print(f'megatron model is saved to {args.save}')