def convert_qwen_state_dict_from_megatron_to_vllm()

in chatlearn/utils/vllm_utils.py [0:0]


def convert_qwen_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version=QwenVersion.v_1):
    # The converted output model.
    output_state_dict = {}

    # configuration for different versions of qwen
    if qwen_version == QwenVersion.v_1:
        prefix_name = "model.transformer."
        embed_name = "wte"
        layer_prefix = "h"
        final_norm = "ln_f"
        func_map = megatron_qwen_to_transformers
    elif qwen_version == QwenVersion.v_2:
        prefix_name = "model." if is_vllm_v2() else "model.model."
        embed_name = "embed_tokens"
        layer_prefix = "layers"
        final_norm = "norm"
        func_map = megatron_qwen2_to_transformers
    else:
        raise RuntimeError(f"Unsupported qwen version {qwen_version}, only 1.0 or 2.0 for now. while {qwen_version}.")

    tp_rank = mpu.get_tensor_model_parallel_rank()
    pp_rank = get_pipeline_model_parallel_rank()

    state_dict = load_rank0_state_dict(args)
    megatron_args = state_dict.get("args", None)
    if "checkpoint_version" in state_dict.keys():
        checkpoint_version = state_dict["checkpoint_version"]
    else:
        checkpoint_version = 0.0
    if megatron_args is None:
        raise ValueError(
            "Megatron-LM checkpoint does not contain arguments. This utility only supports Megatron-LM checkpoints"
            " containing all the megatron arguments. This is because it loads all config related to model"
            " architecture, the tensor and pipeline model parallel size from the checkpoint insead of user having to"
            " manually specify all the details. Please save Megatron-LM checkpoint along with all the megatron"
            " arguments to use this utility."
        )

    tp_size = megatron_args.tensor_model_parallel_size
    pp_size = megatron_args.pipeline_model_parallel_size
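    # hep_size (= tp_size * ep_size) is the number of shards each MoE expert tensor is partitioned across.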
    if hasattr(megatron_args, "moe_expert_model_parallel_size"):
        ep_size = megatron_args.moe_expert_model_parallel_size
        hep_size = tp_size * ep_size
    else:
        ep_size = 1
        hep_size = tp_size

    # The number of attention heads per tensor-parallel rank.
    heads = hf_config.num_attention_heads // tp_size
    # The hidden size per attention head.
    hidden_size_per_head = hf_config.hidden_size // hf_config.num_attention_heads

    # The regex to extract layer names.
    layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")

    # Convert.
    print("Start to convert...")

    # Embeddings
    print("Converting embeddings")
    tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, pp_rank)

    # Convert and store the word embeddings.
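    # Stage 0 owns the word embedding; the last stage also needs it when the LM head is tied to it,
    # and MoE checkpoints load it on every stage.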
    if pp_rank == 0 or (pp_rank == pp_size - 1 and not megatron_args.untie_embeddings_and_output_weights) or \
            (hasattr(megatron_args, "moe_num_experts") and megatron_args.moe_num_experts):
        embed_state_dict = tp_state_dicts if pp_rank == 0 else get_megatron_sharded_states(args, tp_size, pp_size, 0)
        word_embeddings = get_element_from_dict_by_path(
            embed_state_dict[tp_rank], "model.language_model.embedding.word_embeddings.weight"
        )
        if isinstance(word_embeddings, dict):
            assert not word_embeddings, (
                "expected the word embedding weight under 'model.word_embeddings_for_head.weight' "
                "or 'model.language_model.embedding.word_embeddings.weight'."
            )
        elif word_embeddings is not None:
            # Cast to the target dtype and drop any padded rows beyond the HF vocab size.
            word_embeddings = word_embeddings.to(hf_config.torch_dtype)
            word_embeddings = word_embeddings[: hf_config.vocab_size, :]
            output_state_dict[f"{prefix_name}{embed_name}.weight"] = word_embeddings
            # Reset the vocab size
            hf_config.vocab_size = word_embeddings.shape[0]

    # Transformer Layers
    print("Converting transformer layers")
    if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]:
        start_layer_idx, _ = get_pp_indices(
            hf_config.num_hidden_layers,
            pp_rank,
            pp_size
        )
        layer_offset = start_layer_idx
    else:
        assert pp_size == 1, f"expected pipeline model parallel size == 1 for vLLM {CURRENT_VLLM_VERSION}"
        layer_offset = hf_config.num_hidden_layers // pp_size * pp_rank

    # The transformer.
    path = (
        "model.language_model.transformer"
        if "transformer" in get_element_from_dict_by_path(tp_state_dicts[0], "model.language_model").keys()
        else "model.language_model.encoder"
    )

    # Extract the layers.
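    # Buffer for pairing mlp.w1 / mlp.w2 shards so they can be emitted as one fused gate/up weight.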
    gate_up_proj = {}
    for key, val in get_element_from_dict_by_path(tp_state_dicts[tp_rank], path).items():
        # Skip None values.
        # TODO(jiangle.jl): whether to process empty value.
        if val is None:
            continue
        # Match the name.
        m = layer_re.match(key)
        # Skip keys that do not belong to a transformer layer.
        if m is None:
            continue
        # The index of the layer.
        layer_idx = int(m.group(1)) + layer_offset
        # The name of the operation.
        op_name = m.group(2)
        # Is it a weight or a bias?
        weight_or_bias = m.group(3)
        # The name of the layer.
        layer_name = f"{prefix_name}{layer_prefix}.{layer_idx}"

        params = val.to(hf_config.torch_dtype)

        # For layernorm(s), simply store the layer norm.
        if op_name.endswith("layernorm"):

            if qwen_version == QwenVersion.v_1:
                if "attention." in op_name:
                    output_state_dict[
                        layer_name + ".attn.attention_layernorm." + weight_or_bias
                    ] = params
                if "mlp." in op_name:
                    output_state_dict[
                        layer_name + "." + op_name + "." + weight_or_bias
                    ] = params

            if op_name.startswith("input"):
                ln_name = "ln_1" if qwen_version == QwenVersion.v_1 else "input_layernorm"
                output_state_dict[
                    layer_name + "." + ln_name + "." + weight_or_bias
                ] = params
            elif op_name.startswith("post"):
                ln_name  = "ln_2" if qwen_version == QwenVersion.v_1 else "post_attention_layernorm"
                output_state_dict[
                    layer_name + "." + ln_name + "." + weight_or_bias
                ] = params
            elif qwen_version == QwenVersion.v_2:
                raise RuntimeError(f"unsupport layernorm {op_name}.")

        elif op_name == "self_attention.rotary_emb":
            output_state_dict[layer_name + ".attn.rotary_emb.inv_freq"] = params

        # Reorder the QKV matrix and the bias.
        elif op_name in ["attention.query_key_value", "self_attention.query_key_value"]:
            if qwen_version == QwenVersion.v_1:
                out_val = fix_qwen_query_key_value_ordering(
                    params, checkpoint_version, 3, heads, hidden_size_per_head
                )
                # Keep Megatron's packed (3*D) x D layout; just make the tensor contiguous before storing.
                if len(list(out_val.shape)) > 1:
                    out_val = out_val.contiguous()
                # Store.
                output_state_dict[layer_name + f".attn.c_attn.{weight_or_bias}"] = out_val
            else:
                num_query_groups = megatron_args.num_query_groups if megatron_args.group_query_attention else megatron_args.num_attention_heads
                params = split_attn_state(params, heads, num_query_groups // tp_size, hidden_size_per_head, hf_config.hidden_size)
                # Store. No change of shape.
                output_state_dict[layer_name + f".self_attn.qkv_proj.{weight_or_bias}"] = params

        elif op_name in ["mlp.dense_h_to_4h"]:
            offset = params.shape[0] // 2
            w1 = params[:offset,:]
            w2 = params[offset:,:]
            out_name = func_map[op_name]
            out_name = layer_name + out_name + "weight"
            output_state_dict[out_name] = torch.cat([w2, w1], dim=0)

        elif op_name in ["mlp.w1", "mlp.w2"]:
            gate_up_proj[op_name] = params

            if len(gate_up_proj) == 2:
                gate_up_proj = [gate_up_proj["mlp.w2"], gate_up_proj["mlp.w1"]]
                out_name = func_map[op_name]
                gate_up_proj_name = layer_name + out_name + "weight"
                output_state_dict[gate_up_proj_name] = torch.cat(gate_up_proj, dim=0)
                gate_up_proj = {}

        elif op_name in ["mlp.shared_experts.dense_h_to_4h"]:
            out_name = func_map[op_name]
            gate_up_proj_name = layer_name + out_name + "weight"
            w1, w2 = params.chunk(2, dim=0)
            output_state_dict[gate_up_proj_name] = torch.cat([w2, w1], dim=0).contiguous()

        elif "mlp.experts" in op_name:
            # For w13_weight and w2_weight, each TP slice contains part of the expert weights.
            # qwen w13_weight when tp = 4 (pp=1,ep=1):
            #       rank 0: [[0.1, 0.2], [0.3, 0.4]]
            #       rank 1: [[1.1, 1.2], [1.3, 1.4]]
            #       rank 2: [[2.1, 2.2], [2.3, 2.4]]
            #       rank 3: [[3.1, 3.2], [3.3, 3.4]]
            # vLLM w13_weight when tp = 4 (pp=1,ep=1):
            #       rank 0: [[0.1, 1.1], [2.1, 3.1]]
            #       rank 1: [[0.2, 1.2], [2.2, 3.2]]
            #       rank 2: [[0.3, 1.3], [2.3, 3.3]]
            #       rank 3: [[0.4, 1.4], [2.4, 3.4]]
            # The same re-sharding applies to w2_weight.
            out_name = func_map[op_name]
            moe_num_experts = megatron_args.moe_num_experts
            local_num_experts = moe_num_experts // hep_size
            if "dense_h_to_4h" in op_name:
                params_list = []
                for rank in range(tp_size):
                    if rank != tp_rank:
                        params = get_element_from_dict_by_path(tp_state_dicts[rank], path)[key]
                    params_list.append(params)

                val_list = []
                for params in params_list:
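                    # Group the experts local to this rank, take this rank's slice along the FFN
                    # dimension, then swap the packed halves within each expert.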
                    params = params.view((moe_num_experts, -1, hf_config.hidden_size)).contiguous()
                    params = params.reshape((local_num_experts * 2, -1, hf_config.hidden_size))
                    params = params.chunk(tp_size, dim=1)[tp_rank]
                    params = params.reshape(params.shape[0] // 2, -1, hf_config.hidden_size)
                    params_right, params_left = params.chunk(2, dim=1)
                    params = torch.cat([params_left, params_right], dim=1).contiguous()
                    val_list.append(params)
                val = torch.cat(val_list, dim=0).contiguous()
            elif "dense_4h_to_h" in op_name:
                params_list = []
                for rank in range(tp_size):
                    if rank != tp_rank:
                        params = get_element_from_dict_by_path(tp_state_dicts[rank], path)[key]
                    params = params.view((moe_num_experts, -1, hf_config.hidden_size)).contiguous()
                    params_list.append(params)
                val_list = []
                for params in params_list:
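                    # Keep only this rank's slice along the per-expert FFN dimension.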
                    params = params.reshape((local_num_experts, -1, hf_config.hidden_size))
                    params = params.chunk(tp_size, dim=1)[tp_rank]
                    val_list.append(params)
                val = torch.cat(val_list, dim=0).transpose(1, 2).contiguous()
            else:
                raise RuntimeError(f"only support routed weight name 'dense_h_to_4h' or 'dense_4h_to_h' for qwen2_moe. while {op_name}.")
            output_state_dict[layer_name + out_name] = val

        # Store the remaining weights under their mapped names.
        elif weight_or_bias == "weight":
            out_name = func_map[op_name]
            output_state_dict[layer_name + out_name + "weight"] = params

        # Copy the bias.
        elif weight_or_bias == "bias":
            out_name = func_map[op_name]
            output_state_dict[layer_name + out_name + "bias"] = params

    # The final layernorm.
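    # For MoE checkpoints the final norm is read from the last pipeline stage's shards.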
    if hasattr(megatron_args, "moe_num_experts") and megatron_args.moe_num_experts:
        final_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, pp_size - 1)
        params = get_element_from_dict_by_path(final_state_dicts[tp_rank], str(path))
    else:
        params = get_element_from_dict_by_path(tp_state_dicts[tp_rank], str(path))
    if "final_norm.weight" in params or "final_layernorm.weight" in params:
        final_norm_weight = params["final_norm.weight"] if "final_norm.weight" in params else params["final_layernorm.weight"]
        output_state_dict[f"{prefix_name}{final_norm}.weight"] = final_norm_weight.to(hf_config.torch_dtype)
    if "final_norm.bias" in params or "final_layernorm.bias" in params:
        final_norm_bias = params["final_norm.bias"] if "final_norm.bias" in params else params["final_layernorm.bias"]
        output_state_dict[f"{prefix_name}{final_norm}.bias"] = final_norm_bias.to(hf_config.torch_dtype)

    # The LM head: use the output layer when untied, otherwise reuse the word embeddings.
    print("Converting LM head")
    lm_head_name = "lm_head.weight" if is_vllm_v2() else "model.lm_head.weight"
    if megatron_args.untie_embeddings_and_output_weights:
        if hasattr(megatron_args, "moe_num_experts") and megatron_args.moe_num_experts:
            params = get_element_from_dict_by_path(final_state_dicts[tp_rank], 'model.language_model.output_layer.weight')
        else:
            params = get_element_from_dict_by_path(tp_state_dicts[tp_rank], 'model.language_model.output_layer.weight')
        if (isinstance(params, dict) and len(params.keys())) or (params is not None and not isinstance(params, dict)):
            output_state_dict[lm_head_name] = params.to(hf_config.torch_dtype)
    elif pp_rank == 0 or (pp_rank == pp_size - 1) or (hasattr(megatron_args, "moe_num_experts") and megatron_args.moe_num_experts):
        output_state_dict[lm_head_name] = word_embeddings

    # It should be done!
    print("Conversion from Megatron-LM to Transformers is done!")

    return output_state_dict
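
Below is a minimal standalone sketch (not part of chatlearn/utils/vllm_utils.py; the names and shapes are made up for illustration) of the half swap applied above to the packed dense_h_to_4h and shared-expert projections:

    import torch

    # Toy stand-in for a packed projection: two 2 x 4 halves stacked along dim 0.
    packed = torch.arange(16, dtype=torch.float32).reshape(4, 4)
    first_half, second_half = packed.chunk(2, dim=0)
    # The converter stores the halves in swapped order before writing the fused weight.
    fused = torch.cat([second_half, first_half], dim=0)
    assert fused.shape == packed.shape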