# convert_llama_state_dict_from_megatron_to_vllm()
#
# Defined in chatlearn/utils/vllm_utils.py.


def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version=None):
    """Convert an NVIDIA Megatron-LM state_dict to a vLLM Llama state_dict.

    Per-layer weights, the word embedding and the LM head are taken from the
    sharded state of the current tensor-parallel rank (``tp_state_dicts[tp_rank]``);
    layer indices are shifted by the first layer owned by the current
    pipeline-parallel rank.

    Args:
        args (argparse.Namespace): the arguments to the script; forwarded to
            ``load_rank0_state_dict`` and ``get_megatron_sharded_states``.
        hf_config: HuggingFace-style model config. Reads ``num_attention_heads``,
            ``hidden_size``, ``num_hidden_layers`` and ``torch_dtype``.
            NOTE: ``hf_config.vocab_size`` is overwritten with the number of rows
            of the loaded word-embedding tensor.
        qwen_version: must be ``None`` for Llama checkpoints (asserted).

    Returns:
        dict: vLLM parameter names mapped to tensors cast to ``hf_config.torch_dtype``.

    Raises:
        AssertionError: if ``qwen_version`` is not ``None``, if pipeline parallelism
            is used with an unsupported vLLM version, or if the embedding / LM-head
            weights are stored under unexpected names.
        ValueError: if the checkpoint does not contain the Megatron arguments.
    """
    assert qwen_version is None, f"Expect qwen_version is None for Llama, while {qwen_version}"
    tp_rank = mpu.get_tensor_model_parallel_rank()
    pp_rank = get_pipeline_model_parallel_rank()

    state_dict = load_rank0_state_dict(args)

    megatron_args = state_dict.get("args", None)
    checkpoint_version = state_dict.get("checkpoint_version", 0.0)
    if megatron_args is None:
        raise ValueError(
            "Megatron-LM checkpoint does not contain arguments. This utility only supports Megatron-LM checkpoints"
            " containing all the megatron arguments. This is because it loads all config related to model"
            " architecture, the tensor and pipeline model parallel size from the checkpoint insead of user having to"
            " manually specify all the details. Please save Megatron-LM checkpoint along with all the megatron"
            " arguments to use this utility."
        )

    output_state_dict = {}

    tp_size = megatron_args.tensor_model_parallel_size
    pp_size = megatron_args.pipeline_model_parallel_size
    # Attention heads held by this tensor-parallel rank.
    heads = hf_config.num_attention_heads // tp_size
    # The hidden size of a single attention head.
    hidden_size_per_head = hf_config.hidden_size // hf_config.num_attention_heads

    # Extracts (layer index, op name, trailing component) from keys such as
    # "layers.3.self_attention.query_key_value.weight".
    layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
    qkv_op_names = ("attention.query_key_value", "self_attention.query_key_value")

    # Convert.
    print("Start to convert...")
    prefix_name = "model" if is_vllm_v2() else "model.model"

    # Embeddings
    print("Converting embeddings")
    tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, pp_rank)

    # Convert and store the position embeddings, if the checkpoint has any.
    position_embeddings = get_element_from_dict_by_path(
        tp_state_dicts[0], "model.language_model.embedding.position_embeddings.weight"
    )
    # A missing path comes back as a dict (see the word-embedding / LM-head
    # handling below). Test the type explicitly: truthiness of a multi-element
    # tensor raises instead of answering "is it present".
    if position_embeddings is not None and not isinstance(position_embeddings, dict):
        output_state_dict["transformer.position_embeddings.weight"] = position_embeddings.to(hf_config.torch_dtype)

    # Convert and store the word embeddings (either of two checkpoint layouts).
    word_embeddings = get_element_from_dict_by_path(tp_state_dicts[tp_rank], "model.word_embeddings_for_head.weight")
    if isinstance(word_embeddings, dict):
        word_embeddings = get_element_from_dict_by_path(
            tp_state_dicts[tp_rank], "model.language_model.embedding.word_embeddings.weight"
        )
    if isinstance(word_embeddings, dict):
        assert not word_embeddings, \
            "weight name of word_embed expect 'model.word_embeddings_for_head.weight' \
            or 'model.language_model.embedding.word_embeddings.weight'."
    elif word_embeddings is not None:
        # After training with megatron, word_embeddings is stored differently
        word_embeddings = word_embeddings.to(hf_config.torch_dtype)
        output_state_dict[f"{prefix_name}.embed_tokens.weight"] = word_embeddings
        # Reset the vocab size.
        # NOTE(review): this tensor comes from tp_state_dicts[tp_rank], so with
        # tp_size > 1 its row count is presumably the per-partition vocab size —
        # confirm that is what consumers of hf_config.vocab_size expect.
        hf_config.vocab_size = word_embeddings.shape[0]

    # Transformer Layers
    print("Converting transformer layers")
    if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]:
        # Global index of the first layer owned by this pipeline rank.
        layer_offset, _ = get_pp_indices(hf_config.num_hidden_layers, pp_rank, pp_size)
    else:
        assert pp_size == 1, f"expect pipeline parallel size eq 1 for vllm {CURRENT_VLLM_VERSION}"
        layer_offset = hf_config.num_hidden_layers // pp_size * pp_rank

    # Older checkpoints name the layer stack "transformer", newer ones "encoder".
    path = (
        "model.language_model.transformer"
        if "transformer" in get_element_from_dict_by_path(tp_state_dicts[0], "model.language_model").keys()
        else "model.language_model.encoder"
    )
    # Fetched once; reused for the final layernorm below.
    transformer_states = get_element_from_dict_by_path(tp_state_dicts[tp_rank], path)

    for key, val in transformer_states.items():
        # skip None value.
        # TODO(jiangle.jl): whether to process empty value.
        if val is None:
            continue
        match = layer_re.match(key)
        # Skip anything that is not a per-layer parameter (e.g. the final norm,
        # handled below). `continue` rather than `break`, so layer entries that
        # follow a non-layer key are not silently dropped.
        if match is None:
            continue
        layer_idx = int(match.group(1)) + layer_offset
        op_name = match.group(2)
        weight_or_bias = match.group(3)
        layer_name = f"{prefix_name}.layers.{layer_idx}"
        is_qkv = op_name in qkv_op_names

        # Biases other than the fused QKV bias are ignored; skip them before
        # paying for the dtype conversion.
        if weight_or_bias == "bias" and not is_qkv:
            continue

        params = val.to(hf_config.torch_dtype)

        if op_name.endswith(("_norm", "_layernorm")) and weight_or_bias == "weight":
            # Layernorms are stored as-is under the vLLM name.
            ln_name = "input_layernorm" if op_name.startswith("input") else "post_attention_layernorm"
            output_state_dict[f"{layer_name}.{ln_name}.{weight_or_bias}"] = params

        elif is_qkv and weight_or_bias in ("weight", "bias"):
            # Fused QKV weight and bias. Reordering via fix_qwen_query_key_value_ordering
            # (3 splits, per-rank heads) is only applied when the tensor has exactly
            # heads * head_dim * 3 rows, i.e. the MHA layout. With GQA the tensor is
            # stored unchanged. The same guard now covers the bias, which was
            # previously reordered unconditionally.
            mha_shape = (heads, hidden_size_per_head, 3) + params.size()[1:]
            if params.numel() != reduce(operator.mul, mha_shape, 1):
                # model with gqa dont need to fix qkv ordering.
                out_val = params
            else:
                out_val = fix_qwen_query_key_value_ordering(
                    params, checkpoint_version, 3, heads, hidden_size_per_head
                ).contiguous()
            output_state_dict[f"{layer_name}.self_attn.qkv_proj.{weight_or_bias}"] = out_val

        elif weight_or_bias == "weight":
            out_name = megatron_to_transformers[op_name]
            output_state_dict[layer_name + out_name + "weight"] = params

        else:
            # Neither weight nor bias (e.g. the rotary embedding): copy as-is.
            out_name = megatron_to_transformers[op_name]
            output_state_dict[layer_name + out_name] = params

    # The final layernorm; "final_norm.weight" takes precedence when both exist.
    for norm_key in ("final_norm.weight", "final_layernorm.weight"):
        if norm_key in transformer_states:
            print("Converting final layernorm")
            output_state_dict[f"{prefix_name}.norm.weight"] = transformer_states[norm_key].to(hf_config.torch_dtype)
            break

    # For LM head, transformers' wants the matrix to weight embeddings.
    params = get_element_from_dict_by_path(tp_state_dicts[tp_rank], 'model.language_model.output_layer.weight')
    if isinstance(params, dict):
        assert not params, "weight name of lm_head expect 'model.language_model.output_layer.weight'."
    elif params is not None:
        print("Converting LM head")
        output_state_dict["lm_head.weight" if is_vllm_v2() else "model.lm_head.weight"] = params.to(hf_config.torch_dtype)

    # It should be done!
    print("Conversion from Megatron-LM to Transformers is done!")

    return output_state_dict