def merge_tp_weights()

in src/transformers/models/glm4v/convert_glm4v_mgt_weights_to_hf.py [0:0]


def merge_tp_weights(model_path, output_path, vllm_config_path=None):
    """Merge tensor-parallel Megatron shards into a single Hugging Face GLM-4V checkpoint.

    The tensor-parallel degree is inferred from the ``mp_rank_XX`` sub-directories of
    ``model_path``. Each shard's ``model_optim_rng.pt`` is loaded once, the sharded
    tensors are merged to TP=1, everything is cast to bfloat16, and the result is handed
    to ``save_sharded_model``. A ``config.json`` is then written to ``output_path``.

    Args:
        model_path: Directory containing the ``mp_rank_XX`` (or ``mp_rank_XX_000``)
            shard directories.
        output_path: Directory that receives the sharded weights and ``config.json``.
        vllm_config_path: Path to the JSON file describing the source model. Required
            whenever more than one shard is detected.

    Returns:
        None. If at most one shard is detected nothing is converted or written.

    Raises:
        ValueError: If ``vllm_config_path`` is ``None`` while shards need merging, or if
            a shard checkpoint has no ``"model"`` entry.
    """
    tp_size = _detect_tp_size(model_path)
    print(f"Detected tensor parallel degree TP={tp_size}")

    if tp_size <= 1:
        print("Model is already at TP=1, no need to merge")
        return

    # Fail with an actionable message instead of the TypeError raised by open(None).
    if vllm_config_path is None:
        raise ValueError("vllm_config_path is required to merge tensor parallel weights, but None was given")

    print(f"Loading vLLM configuration file: {vllm_config_path}")
    with open(vllm_config_path, "r", encoding="utf-8") as f:
        model_config = json.load(f)

    vision_settings = model_config.get("vision_config", {})
    num_layers = model_config.get("num_layers", 40)
    # Resolve the vision depth and KV-group count exactly once, accepting either spelling
    # of the key, so that the tensors merged below and the config written at the end are
    # always derived from the same numbers.
    vision_num_layers = vision_settings.get("num_hidden_layers", vision_settings.get("num_layers", 24))
    num_heads = model_config.get("num_attention_heads", 32)
    num_kv_heads = model_config.get("num_query_groups", model_config.get("multi_query_group_num", 2))
    hidden_size = model_config.get("hidden_size", 4096)
    head_dim = model_config.get("attention_dim", hidden_size // num_heads)

    print(
        f"Model parameters: num_layers={num_layers}, vision_num_layers={vision_num_layers}, "
        f"num_heads={num_heads}, multi_query_group_num={num_kv_heads}, hidden_size={hidden_size}"
    )

    original_pp_enabled = os.path.exists(Path(model_path) / "mp_rank_00_000")
    original_tp, original_pp = tp_size, 1
    target_tp = 1
    print(f"TP and PP INFO: original_tp: {original_tp}, original_pp:{original_pp}, target_tp: {target_tp}")

    # Load every shard exactly once; mgt_sd[pp][tp] is the full checkpoint dict.
    # NOTE(security): torch.load with an explicit pickle module unpickles arbitrary
    # objects. Only run this conversion on checkpoints from a trusted source.
    mgt_sd = []
    for pp_rank in range(original_pp):
        stage_shards = []
        for tp_rank in range(original_tp):
            print(f"Loading TP shard {tp_rank}...")
            shard_dir = f"mp_rank_{tp_rank:02d}_{pp_rank:03d}" if original_pp_enabled else f"mp_rank_{tp_rank:02d}"
            weight_path = Path(model_path) / shard_dir / "model_optim_rng.pt"
            checkpoint = torch.load(weight_path, map_location="cpu", pickle_module=pickle)
            if "model" not in checkpoint:
                raise ValueError(f"'model' key not found in {weight_path}")
            stage_shards.append(checkpoint)
        mgt_sd.append(stage_shards)

    print("Merging tensor parallel weights...")
    interleaved_qkv = False
    multi_query_attention = True
    complete_state_dict = {}
    keys = ["model"]
    rank = 0

    def merge_lm(stage_shards, name, **kwargs):
        # Shared argument list for every language-model merge_tensors call.
        return merge_tensors(
            tp_sd=stage_shards,
            keys=keys + [name],
            original_tp=original_tp,
            target_tp=target_tp,
            current_tp=0,
            **kwargs,
        )

    def merge_vit(name, **kwargs):
        # Shared argument list for every vision / projector merge_tensors_vit call.
        return merge_tensors_vit(
            tp_sd=mgt_sd[0],
            keys=keys + [name],
            original_tp=original_tp,
            target_tp=target_tp,
            **kwargs,
        )

    def merge_lm_qkv(sd_list):
        # Splits the fused QKV tensor into separate q, k, v pieces.
        return merge_qkv(
            sd_list,
            original_tp,
            num_heads,
            num_kv_heads,
            head_dim,
            multi_query_attention,
            interleaved_qkv,
        )

    # LLM
    for pp in range(original_pp):
        stage_shards = mgt_sd[pp]
        mgt_encoder_tp_0 = dict_access_multi(stage_shards[rank], keys)

        # The number of decoder layers is discovered from the checkpoint itself: keep
        # going while the next layer's fused layer-norm weight exists.
        layer_i = 0
        while f"decoder.layers.{layer_i}.self_attention.linear_qkv.layer_norm_weight" in mgt_encoder_tp_0:
            src = f"decoder.layers.{layer_i}"
            dst = f"model.language_model.layers.{layer_i}"

            # Norm weights are taken from TP rank 0 without merging.
            complete_state_dict.update(
                {
                    f"{dst}.input_layernorm.weight": mgt_encoder_tp_0[
                        f"{src}.self_attention.linear_qkv.layer_norm_weight"
                    ],
                    f"{dst}.post_attention_layernorm.weight": mgt_encoder_tp_0[
                        f"{src}.mlp.linear_fc1.layer_norm_weight"
                    ],
                    f"{dst}.post_self_attn_layernorm.weight": mgt_encoder_tp_0[
                        f"{src}.post_self_attn_layernorm.weight"
                    ],
                    f"{dst}.post_mlp_layernorm.weight": mgt_encoder_tp_0[f"{src}.post_mlp_layernorm.weight"],
                }
            )

            q, k, v = merge_lm(stage_shards, f"{src}.self_attention.linear_qkv.weight", merge_fn=merge_lm_qkv)
            complete_state_dict[f"{dst}.self_attn.q_proj.weight"] = q.clone()
            complete_state_dict[f"{dst}.self_attn.k_proj.weight"] = k.clone()
            complete_state_dict[f"{dst}.self_attn.v_proj.weight"] = v.clone()

            if f"{src}.self_attention.linear_qkv.bias" in mgt_encoder_tp_0:
                q_bias, k_bias, v_bias = merge_lm(
                    stage_shards, f"{src}.self_attention.linear_qkv.bias", merge_fn=merge_lm_qkv
                )
                complete_state_dict[f"{dst}.self_attn.q_proj.bias"] = q_bias.clone()
                complete_state_dict[f"{dst}.self_attn.k_proj.bias"] = k_bias.clone()
                complete_state_dict[f"{dst}.self_attn.v_proj.bias"] = v_bias.clone()

            o_proj = merge_lm(stage_shards, f"{src}.self_attention.linear_proj.weight", slice_dim=1)
            complete_state_dict[f"{dst}.self_attn.o_proj.weight"] = o_proj.clone()

            # MLP - Use gate_up_proj
            complete_state_dict[f"{dst}.mlp.gate_up_proj.weight"] = merge_lm(
                stage_shards, f"{src}.mlp.linear_fc1.weight", merge_fn=merge_glu
            ).clone()
            complete_state_dict[f"{dst}.mlp.down_proj.weight"] = merge_lm(
                stage_shards, f"{src}.mlp.linear_fc2.weight", slice_dim=1
            )
            layer_i += 1

    # Embed Model, LM Head, and Norm
    embed_tokens = merge_lm(mgt_sd[0], "embedding.word_embeddings.weight", slice_dim=0)
    complete_state_dict["model.language_model.embed_tokens.weight"] = embed_tokens.clone()
    lm_head = merge_lm(mgt_sd[-1], "output_layer.weight", slice_dim=0)
    complete_state_dict["lm_head.weight"] = lm_head.clone()
    complete_state_dict["model.language_model.norm.weight"] = mgt_sd[-1][rank]["model"][
        "decoder.final_layernorm.weight"
    ].clone()

    mgt_encoder_tp_0 = dict_access_multi(mgt_sd[0][0], keys)
    first_shard_model = mgt_sd[0][0]["model"]

    # VLM
    for layer_i in range(vision_num_layers):
        src = f"vision_model.transformer.layers.{layer_i}"
        dst = f"model.visual.blocks.{layer_i}"

        complete_state_dict[f"{dst}.norm1.weight"] = mgt_encoder_tp_0[f"{src}.input_layernorm.weight"]
        complete_state_dict[f"{dst}.norm2.weight"] = mgt_encoder_tp_0[f"{src}.pre_mlp_layernorm.weight"]

        qkv_weight = merge_vit(f"{src}.self_attention.linear_qkv.weight", merge_fn=merge_qkv_vit)
        complete_state_dict[f"{dst}.attn.qkv.weight"] = qkv_weight.clone()

        proj_weight = merge_vit(f"{src}.self_attention.linear_proj.weight", slice_dim=1)
        complete_state_dict[f"{dst}.attn.proj.weight"] = proj_weight.clone()

        gate_proj_weight, up_proj_weight = merge_vit(f"{src}.mlp.linear_fc1.weight", merge_fn=merge_glu_vit)
        complete_state_dict[f"{dst}.mlp.gate_proj.weight"] = gate_proj_weight.clone()
        complete_state_dict[f"{dst}.mlp.up_proj.weight"] = up_proj_weight.clone()

        down_proj_weight = merge_vit(f"{src}.mlp.linear_fc2.weight", slice_dim=1)
        complete_state_dict[f"{dst}.mlp.down_proj.weight"] = down_proj_weight.clone()

    complete_state_dict["model.visual.downsample.weight"] = (
        first_shard_model["vision_model.downsample.weight"].clone().contiguous()
    )
    complete_state_dict["model.visual.downsample.bias"] = (
        first_shard_model["vision_model.downsample.bias"].clone().contiguous()
    )

    # Merger
    gate_proj, up_proj = merge_vit("vision_projection.encoder.linear_fc1.weight", merge_fn=merge_glu_vit)
    down_proj = merge_vit("vision_projection.encoder.linear_fc2.weight", slice_dim=1)
    proj = merge_vit("vision_projection.encoder.linear_fc_extra.weight", slice_dim=0)

    complete_state_dict["model.visual.merger.gate_proj.weight"] = gate_proj.clone().contiguous()
    complete_state_dict["model.visual.merger.up_proj.weight"] = up_proj.clone().contiguous()
    complete_state_dict["model.visual.merger.down_proj.weight"] = down_proj.clone().contiguous()
    complete_state_dict["model.visual.merger.proj.weight"] = proj.clone().contiguous()

    # Tensors copied from TP rank 0 without merging.
    unsharded_vision_tensors = (
        ("model.visual.merger.post_projection_norm.weight", "vision_projection.encoder.layer_norm.weight"),
        ("model.visual.merger.post_projection_norm.bias", "vision_projection.encoder.layer_norm.bias"),
        ("model.visual.embeddings.position_embedding.weight", "vision_model.position_embeddings.weight"),
        ("model.visual.patch_embed.proj.weight", "vision_model.conv3d.weight"),
        ("model.visual.patch_embed.proj.bias", "vision_model.conv3d.bias"),
    )
    for hf_name, mgt_name in unsharded_vision_tensors:
        complete_state_dict[hf_name] = first_shard_model[mgt_name].clone().contiguous()

    # Optional vision norm layers: converted only when present in the checkpoint.
    optional_vision_tensors = (
        ("model.visual.post_conv_layernorm.weight", "vision_model.post_conv_layernorm.weight"),
        ("model.visual.post_layernorm.weight", "vision_model.post_layernorm.weight"),
    )
    for hf_name, mgt_name in optional_vision_tensors:
        if mgt_name in mgt_encoder_tp_0:
            complete_state_dict[hf_name] = first_shard_model[mgt_name].clone().contiguous()

    print(f"Total keys in state dict: {len(complete_state_dict)}")

    # Only values of existing keys are replaced, so mutating during iteration is safe.
    for key, value in complete_state_dict.items():
        if isinstance(value, torch.Tensor):
            complete_state_dict[key] = value.to(torch.bfloat16)
    print("Converted all tensors to bfloat16")

    # Save Model weight
    save_sharded_model(
        complete_state_dict,
        output_path=output_path,
        max_shard_size_gb=5,
        num_layers=num_layers,
        vision_num_layers=vision_num_layers,
    )

    hf_config = _build_hf_config(model_config, num_kv_heads=num_kv_heads, vision_num_layers=vision_num_layers)
    config_path = os.path.join(output_path, "config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(hf_config, f, indent=2)

    print(f"Conversion complete! Model saved to {output_path}")


def _detect_tp_size(model_path):
    """Return the tensor-parallel degree implied by the ``mp_rank_XX`` directories in *model_path*."""
    shard_dir_pattern = re.compile(r"mp_rank_(\d{2})")
    tp_size = 0
    for item in Path(model_path).iterdir():
        if not item.is_dir():
            continue
        match = shard_dir_pattern.match(item.name)
        if match:
            # Ranks are zero-based, so the degree is the highest rank plus one.
            tp_size = max(tp_size, int(match.group(1)) + 1)
    return tp_size


def _build_hf_config(model_config, num_kv_heads, vision_num_layers):
    """Build the Hugging Face ``config.json`` dict from the source model configuration.

    ``num_kv_heads`` and ``vision_num_layers`` must be the values that were actually used
    to convert the weights, so the written config matches the saved tensors.
    """
    hf_config = {
        "architectures": ["Glm4vForConditionalGeneration"],
        "model_type": "glm4v",
        "attention_bias": model_config.get("add_qkv_bias", True),
        "attention_dropout": 0.0,
        "pad_token_id": model_config.get("pad_token_id", 151329),
        "eos_token_id": model_config.get("eos_token_id", [151329, 151336, 151338]),
        "image_start_token_id": model_config.get("image_start_token_id", 151339),
        "image_end_token_id": model_config.get("image_end_token_id", 151340),
        "video_start_token_id": model_config.get("video_start_token_id", 151341),
        "video_end_token_id": model_config.get("video_end_token_id", 151342),
        "image_token_id": model_config.get("image_token_id", 151343),
        "video_token_id": model_config.get("video_token_id", 151344),
        "hidden_act": model_config.get("hidden_act", "silu"),
        "hidden_size": model_config.get("hidden_size", 4096),
        "initializer_range": 0.02,
        "intermediate_size": model_config.get("ffn_hidden_size", 13696),
        "max_position_embeddings": model_config.get("seq_length", 32768),
        "num_attention_heads": model_config.get("num_attention_heads", 32),
        "num_hidden_layers": model_config.get("num_layers", 40),
        "num_key_value_heads": num_kv_heads,
        "rms_norm_eps": model_config.get("layernorm_epsilon", 1e-05),
        "rope_theta": model_config.get("rotary_base", 10000.0),
        "tie_word_embeddings": False,
        "torch_dtype": model_config.get("torch_dtype", "bfloat16"),
        "transformers_version": "4.53.0dev",
        "use_cache": model_config.get("use_cache", True),
        "vocab_size": model_config.get("vocab_size", 151552),
        "partial_rotary_factor": 0.5,
    }

    if "vision_config" in model_config:
        source_vision = model_config["vision_config"]
        hf_config["vision_config"] = {
            "hidden_size": source_vision.get("hidden_size", 1536),
            "depth": vision_num_layers,
            "num_heads": source_vision.get("num_attention_heads", 12),
            "attention_bias": source_vision.get("attention_bias", False),
            # NOTE(review): read from the top-level config, not from vision_config;
            # confirm the vision MLP width really equals the text ffn_hidden_size.
            "intermediate_size": model_config.get("ffn_hidden_size", 13696),
            "hidden_act": source_vision.get("hidden_act", "silu"),
            "hidden_dropout_prob": source_vision.get("hidden_dropout_prob", 0.0),
            "initializer_range": 0.02,
            "image_size": source_vision.get("image_size", 336),
            "patch_size": source_vision.get("patch_size", 14),
            "out_hidden_size": model_config.get("hidden_size", 4096),
            "rms_norm_eps": source_vision.get("layernorm_epsilon", 1e-05),
            "spatial_merge_size": source_vision.get("downsample_ratio", 2),
            "temporal_patch_size": source_vision.get("t_patch", 2),
        }

    if "rope_scaling" in model_config:
        hf_config["rope_scaling"] = model_config["rope_scaling"]

    return hf_config