in src/transformers/models/glm4v/convert_glm4v_mgt_weights_to_hf.py
def merge_tp_weights(model_path, output_path, vllm_config_path=None):
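"""Merge Megatron-style tensor-parallel (TP) checkpoint shards into a single TP=1 state dict
and save it as a Hugging Face GLM-4V checkpoint (sharded weights plus config.json). The TP
degree is inferred from the mp_rank_XX shard directories found under model_path."""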
tp_size = 0
for item in Path(model_path).iterdir():
if item.is_dir():
match = re.match(r"mp_rank_(\d{2})", item.name)
if match:
tp = int(match.group(1))
tp_size = max(tp_size, tp + 1)
print(f"Detected tensor parallel degree TP={tp_size}")
if tp_size <= 1:
print("Model is already at TP=1, no need to merge")
return
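# The accompanying config file supplies the layer/head counts needed to split the
# fused QKV projections and to emit config.json at the end.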
print(f"Loading vLLM configuration file: {vllm_config_path}")
with open(vllm_config_path, "r") as f:
model_config = json.load(f)
num_layers = model_config.get("num_layers", 40)
vision_num_layers = model_config.get("vision_config", {}).get("num_hidden_layers", 24)
num_heads = model_config.get("num_attention_heads", 32)
num_kv_heads = model_config.get("num_query_groups", 2)  # some configs name this "multi_query_group_num" (as read for the HF config below)
hidden_size = model_config.get("hidden_size", 4096)
head_dim = model_config.get("attention_dim", hidden_size // num_heads)
print(
f"Model parameters: num_layers={num_layers}, vision_num_layers={vision_num_layers}, "
f"num_heads={num_heads}, multi_query_group_num={num_kv_heads}, hidden_size={hidden_size}"
)
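# First pass: load each TP shard once purely as a sanity check that the expected
# "model" key is present (the tensors are re-loaded into mgt_sd below for merging).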
weights = []
for tp_rank in range(tp_size):
print(f"Loading TP shard {tp_rank}...")
weight_path = Path(model_path) / f"mp_rank_{tp_rank:02d}" / "model_optim_rng.pt"
sd = torch.load(weight_path, map_location="cpu", pickle_module=pickle)
for k in list(sd.keys()):
if "_extra_state" in k or "dummy_parameter" in k:
sd.pop(k)
if "model" in sd:
weights.append(sd["model"])
else:
raise ValueError(f"'model' key not found in {weight_path}")
if not weights:
raise ValueError("No valid weight files found")
print("Merging tensor parallel weights...")
original_pp_enabled = (Path(model_path) / "mp_rank_00_000").exists()
original_tp, original_pp = tp_size, 1
target_tp = 1
print(f"TP and PP info: original_tp={original_tp}, original_pp={original_pp}, target_tp={target_tp}")
mgt_sd = [
[
torch.load(
Path(model_path)
/ (f"mp_rank_{j:02d}_{i:03d}" if original_pp_enabled else f"mp_rank_{j:02d}")
/ "model_optim_rng.pt",
map_location="cpu",
pickle_module=pickle,
)
for j in range(original_tp)
]
for i in range(original_pp)
]
interleaved_qkv = False
multi_query_attention = True
num_attention_heads = num_heads
multi_query_group_num = num_kv_heads
attention_dim = head_dim
complete_state_dict = {}
keys = ["model"]
rank = 0
# LLM
for pp in range(original_pp):
layer_i = 0
mgt_encoder_tp_0 = dict_access_multi(mgt_sd[pp][rank], keys)
while f"decoder.layers.{layer_i}.self_attention.linear_qkv.layer_norm_weight" in mgt_encoder_tp_0:
complete_state_dict.update(
{
f"model.language_model.layers.{layer_i}.input_layernorm.weight": mgt_encoder_tp_0[
f"decoder.layers.{layer_i}.self_attention.linear_qkv.layer_norm_weight"
],
f"model.language_model.layers.{layer_i}.post_attention_layernorm.weight": mgt_encoder_tp_0[
f"decoder.layers.{layer_i}.mlp.linear_fc1.layer_norm_weight"
],
f"model.language_model.layers.{layer_i}.post_self_attn_layernorm.weight": mgt_encoder_tp_0[
f"decoder.layers.{layer_i}.post_self_attn_layernorm.weight"
],
f"model.language_model.layers.{layer_i}.post_mlp_layernorm.weight": mgt_encoder_tp_0[
f"decoder.layers.{layer_i}.post_mlp_layernorm.weight"
],
}
)
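# Re-assemble Q/K/V from the fused, column-parallel linear_qkv weight. Sketch of the
# assumed per-rank layout (grouped-query attention, non-interleaved): each TP rank holds
# num_attention_heads/TP query heads plus multi_query_group_num/TP key and value heads,
# and merge_qkv (defined elsewhere in this file) is expected to slice each shard into its
# q/k/v blocks and concatenate them across ranks, roughly
#   q = torch.cat(per_rank_q_blocks, dim=0)   # likewise for k and v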
q, k, v = merge_tensors(
tp_sd=mgt_sd[pp],
keys=keys + [f"decoder.layers.{layer_i}.self_attention.linear_qkv.weight"],
original_tp=original_tp,
target_tp=target_tp,
current_tp=0,
merge_fn=lambda sd_list: merge_qkv(
sd_list,
original_tp,
num_attention_heads,
multi_query_group_num,
attention_dim,
multi_query_attention,
interleaved_qkv,
),
)
complete_state_dict[f"model.language_model.layers.{layer_i}.self_attn.q_proj.weight"] = q.clone()
complete_state_dict[f"model.language_model.layers.{layer_i}.self_attn.k_proj.weight"] = k.clone()
complete_state_dict[f"model.language_model.layers.{layer_i}.self_attn.v_proj.weight"] = v.clone()
if f"decoder.layers.{layer_i}.self_attention.linear_qkv.bias" in mgt_encoder_tp_0:
q_bias, k_bias, v_bias = merge_tensors(
tp_sd=mgt_sd[pp],
keys=keys + [f"decoder.layers.{layer_i}.self_attention.linear_qkv.bias"],
original_tp=original_tp,
target_tp=target_tp,
current_tp=0,
merge_fn=lambda sd_list: merge_qkv(
sd_list,
original_tp,
num_attention_heads,
multi_query_group_num,
attention_dim,
multi_query_attention,
interleaved_qkv,
),
)
complete_state_dict[f"model.language_model.layers.{layer_i}.self_attn.q_proj.bias"] = q_bias.clone()
complete_state_dict[f"model.language_model.layers.{layer_i}.self_attn.k_proj.bias"] = k_bias.clone()
complete_state_dict[f"model.language_model.layers.{layer_i}.self_attn.v_proj.bias"] = v_bias.clone()
o_proj = merge_tensors(
tp_sd=mgt_sd[pp],
keys=keys + [f"decoder.layers.{layer_i}.self_attention.linear_proj.weight"],
original_tp=original_tp,
target_tp=target_tp,
current_tp=0,
slice_dim=1,
)
complete_state_dict[f"model.language_model.layers.{layer_i}.self_attn.o_proj.weight"] = o_proj.clone()
# MLP - Use gate_up_proj
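# linear_fc1 fuses the gate and up projections of the gated MLP and is column-parallel;
# merge_glu is expected to undo the per-rank gate/up split so the result matches HF's
# fused gate_up_proj layout. linear_fc2 (down_proj) is row-parallel, hence slice_dim=1.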
complete_state_dict[f"model.language_model.layers.{layer_i}.mlp.gate_up_proj.weight"] = merge_tensors(
tp_sd=mgt_sd[pp],
keys=keys + [f"decoder.layers.{layer_i}.mlp.linear_fc1.weight"],
original_tp=original_tp,
target_tp=target_tp,
current_tp=0,
merge_fn=merge_glu,
).clone()
complete_state_dict[f"model.language_model.layers.{layer_i}.mlp.down_proj.weight"] = merge_tensors(
tp_sd=mgt_sd[pp],
keys=keys + [f"decoder.layers.{layer_i}.mlp.linear_fc2.weight"],
original_tp=original_tp,
target_tp=target_tp,
current_tp=0,
slice_dim=1,
)
layer_i += 1
# Embedding, LM head, and final norm
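# Word embeddings and the output layer are vocab-parallel (split along dim 0); the
# embedding is taken from the first pipeline stage (mgt_sd[0]) and the LM head and
# final norm from the last (mgt_sd[-1]).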
embed_tokens = merge_tensors(
tp_sd=mgt_sd[0],
keys=["model", "embedding.word_embeddings.weight"],
original_tp=original_tp,
target_tp=target_tp,
current_tp=0,
slice_dim=0,
)
complete_state_dict["model.language_model.embed_tokens.weight"] = embed_tokens.clone()
lm_head = merge_tensors(
tp_sd=mgt_sd[-1],
keys=["model", "output_layer.weight"],
original_tp=original_tp,
target_tp=target_tp,
current_tp=0,
slice_dim=0,
)
complete_state_dict["lm_head.weight"] = lm_head.clone()
complete_state_dict["model.language_model.norm.weight"] = mgt_sd[-1][rank]["model"][
"decoder.final_layernorm.weight"
].clone()
mgt_encoder_tp_0 = dict_access_multi(mgt_sd[0][0], keys)
# VLM
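# Vision tower: Megatron's vision_model.transformer.layers.* map onto HF's
# model.visual.blocks.*; all vision weights are read from the first pipeline
# stage (mgt_sd[0]).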
for layer_i in range(vision_num_layers):
complete_state_dict[f"model.visual.blocks.{layer_i}.norm1.weight"] = mgt_encoder_tp_0[
f"vision_model.transformer.layers.{layer_i}.input_layernorm.weight"
]
complete_state_dict[f"model.visual.blocks.{layer_i}.norm2.weight"] = mgt_encoder_tp_0[
f"vision_model.transformer.layers.{layer_i}.pre_mlp_layernorm.weight"
]
qkv_weight = merge_tensors_vit(
tp_sd=mgt_sd[0],
keys=keys + [f"vision_model.transformer.layers.{layer_i}.self_attention.linear_qkv.weight"],
original_tp=original_tp,
target_tp=target_tp,
merge_fn=merge_qkv_vit,
)
complete_state_dict[f"model.visual.blocks.{layer_i}.attn.qkv.weight"] = qkv_weight.clone()
proj_weight = merge_tensors_vit(
tp_sd=mgt_sd[0],
keys=keys + [f"vision_model.transformer.layers.{layer_i}.self_attention.linear_proj.weight"],
original_tp=original_tp,
target_tp=target_tp,
slice_dim=1,
)
complete_state_dict[f"model.visual.blocks.{layer_i}.attn.proj.weight"] = proj_weight.clone()
gate_proj_weight, up_proj_weight = merge_tensors_vit(
tp_sd=mgt_sd[0],
keys=keys + [f"vision_model.transformer.layers.{layer_i}.mlp.linear_fc1.weight"],
original_tp=original_tp,
target_tp=target_tp,
merge_fn=merge_glu_vit,
)
complete_state_dict[f"model.visual.blocks.{layer_i}.mlp.gate_proj.weight"] = gate_proj_weight.clone()
complete_state_dict[f"model.visual.blocks.{layer_i}.mlp.up_proj.weight"] = up_proj_weight.clone()
down_proj_weight = merge_tensors_vit(
tp_sd=mgt_sd[0],
keys=keys + [f"vision_model.transformer.layers.{layer_i}.mlp.linear_fc2.weight"],
original_tp=original_tp,
target_tp=target_tp,
slice_dim=1,
)
complete_state_dict[f"model.visual.blocks.{layer_i}.mlp.down_proj.weight"] = down_proj_weight.clone()
complete_state_dict["model.visual.downsample.weight"] = (
mgt_sd[0][0]["model"]["vision_model.downsample.weight"].clone().contiguous()
)
complete_state_dict["model.visual.downsample.bias"] = (
mgt_sd[0][0]["model"]["vision_model.downsample.bias"].clone().contiguous()
)
# Merger
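# Vision-to-language merger: Megatron's vision_projection.encoder maps onto
# model.visual.merger (gate/up/down projections, an extra linear for merger.proj,
# and a post-projection layer norm).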
gate_proj, up_proj = merge_tensors_vit(
tp_sd=mgt_sd[0],
keys=keys + ["vision_projection.encoder.linear_fc1.weight"],
original_tp=original_tp,
target_tp=target_tp,
merge_fn=merge_glu_vit,
)
down_proj = merge_tensors_vit(
tp_sd=mgt_sd[0],
keys=keys + ["vision_projection.encoder.linear_fc2.weight"],
original_tp=original_tp,
target_tp=target_tp,
slice_dim=1,
)
proj = merge_tensors_vit(
tp_sd=mgt_sd[0],
keys=keys + ["vision_projection.encoder.linear_fc_extra.weight"],
original_tp=original_tp,
target_tp=target_tp,
slice_dim=0,
)
complete_state_dict["model.visual.merger.gate_proj.weight"] = gate_proj.clone().contiguous()
complete_state_dict["model.visual.merger.up_proj.weight"] = up_proj.clone().contiguous()
complete_state_dict["model.visual.merger.down_proj.weight"] = down_proj.clone().contiguous()
complete_state_dict["model.visual.merger.proj.weight"] = proj.clone().contiguous()
complete_state_dict["model.visual.merger.post_projection_norm.weight"] = (
mgt_sd[0][0]["model"]["vision_projection.encoder.layer_norm.weight"].clone().contiguous()
)
complete_state_dict["model.visual.merger.post_projection_norm.bias"] = (
mgt_sd[0][0]["model"]["vision_projection.encoder.layer_norm.bias"].clone().contiguous()
)
complete_state_dict["model.visual.embeddings.position_embedding.weight"] = (
mgt_sd[0][0]["model"]["vision_model.position_embeddings.weight"].clone().contiguous()
)
complete_state_dict["model.visual.patch_embed.proj.weight"] = (
mgt_sd[0][0]["model"]["vision_model.conv3d.weight"].clone().contiguous()
)
complete_state_dict["model.visual.patch_embed.proj.bias"] = (
mgt_sd[0][0]["model"]["vision_model.conv3d.bias"].clone().contiguous()
)
# Optional vision-model norm layers (only present in some checkpoints)
if "vision_model.post_conv_layernorm.weight" in mgt_encoder_tp_0:
complete_state_dict["model.visual.post_conv_layernorm.weight"] = (
mgt_sd[0][0]["model"]["vision_model.post_conv_layernorm.weight"].clone().contiguous()
)
if "vision_model.post_layernorm.weight" in mgt_encoder_tp_0:
complete_state_dict["model.visual.post_layernorm.weight"] = (
mgt_sd[0][0]["model"]["vision_model.post_layernorm.weight"].clone().contiguous()
)
print(f"Total keys in state dict: {len(complete_state_dict)}")
for key, value in complete_state_dict.items():
if isinstance(value, torch.Tensor):
complete_state_dict[key] = value.to(torch.bfloat16)
print("Converted all tensors to bfloat16")
# Save Model weight
save_sharded_model(
complete_state_dict,
output_path=output_path,
max_shard_size_gb=5,
num_layers=num_layers,
vision_num_layers=vision_num_layers,
)
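# Build the HF config.json, translating Megatron/GLM field names to their HF
# equivalents (e.g. ffn_hidden_size -> intermediate_size, seq_length ->
# max_position_embeddings, multi_query_group_num -> num_key_value_heads,
# rotary_base -> rope_theta); the literal fallbacks are used when a field is absent.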
hf_config = {
"architectures": ["Glm4vForConditionalGeneration"],
"model_type": "glm4v",
"attention_bias": model_config.get("add_qkv_bias", True),
"attention_dropout": 0.0,
"pad_token_id": model_config.get("pad_token_id", 151329),
"eos_token_id": model_config.get("eos_token_id", [151329, 151336, 151338]),
"image_start_token_id": model_config.get("image_start_token_id", 151339),
"image_end_token_id": model_config.get("image_end_token_id", 151340),
"video_start_token_id": model_config.get("video_start_token_id", 151341),
"video_end_token_id": model_config.get("video_end_token_id", 151342),
"image_token_id": model_config.get("image_token_id", 151343),
"video_token_id": model_config.get("video_token_id", 151344),
"hidden_act": model_config.get("hidden_act", "silu"),
"hidden_size": model_config.get("hidden_size", 4096),
"initializer_range": 0.02,
"intermediate_size": model_config.get("ffn_hidden_size", 13696),
"max_position_embeddings": model_config.get("seq_length", 32768),
"num_attention_heads": model_config.get("num_attention_heads", 32),
"num_hidden_layers": model_config.get("num_layers", 40),
"num_key_value_heads": model_config.get("multi_query_group_num", 2),
"rms_norm_eps": model_config.get("layernorm_epsilon", 1e-05),
"rope_theta": model_config.get("rotary_base", 10000.0),
"tie_word_embeddings": False,
"torch_dtype": model_config.get("torch_dtype", "bfloat16"),
"transformers_version": "4.53.0dev",
"use_cache": model_config.get("use_cache", True),
"vocab_size": model_config.get("vocab_size", 151552),
"partial_rotary_factor": 0.5,
}
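# Optional vision sub-config, translated the same way (num_layers -> depth,
# downsample_ratio -> spatial_merge_size, t_patch -> temporal_patch_size).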
if "vision_config" in model_config:
vision_config = {
"hidden_size": model_config["vision_config"].get("hidden_size", 1536),
"depth": model_config["vision_config"].get("num_layers", 24),
"num_heads": model_config["vision_config"].get("num_attention_heads", 12),
"attention_bias": model_config["vision_config"].get("attention_bias", False),
"intermediate_size": model_config.get("ffn_hidden_size", 13696),
"hidden_act": model_config["vision_config"].get("hidden_act", "silu"),
"hidden_dropout_prob": model_config["vision_config"].get("hidden_dropout_prob", 0.0),
"initializer_range": 0.02,
"image_size": model_config["vision_config"].get("image_size", 336),
"patch_size": model_config["vision_config"].get("patch_size", 14),
"out_hidden_size": model_config.get("hidden_size", 4096),
"rms_norm_eps": model_config["vision_config"].get("layernorm_epsilon", 1e-05),
"spatial_merge_size": model_config["vision_config"].get("downsample_ratio", 2),
"temporal_patch_size": model_config["vision_config"].get("t_patch", 2),
}
hf_config["vision_config"] = vision_config
if "rope_scaling" in model_config:
hf_config["rope_scaling"] = model_config["rope_scaling"]
config_path = os.path.join(output_path, "config.json")
with open(config_path, "w") as f:
json.dump(hf_config, f, indent=2)
print(f"Conversion complete! Model saved to {output_path}")