def convert_llama_state_dict_from_mcore_to_vllm()

in chatlearn/utils/vllm_utils.py


def convert_llama_state_dict_from_mcore_to_vllm(args, hf_config, qwen_version=None):
    """Convert NVIDIA Megatron-Core state_dict to vLLM llama state_dict.

        Args:
            args (argparse.Namespace): the arguments to the script
    """
    assert qwen_version is None, f"Expected qwen_version to be None for Llama, got {qwen_version}"
    tp_rank = mpu.get_tensor_model_parallel_rank()
    pp_rank = get_pipeline_model_parallel_rank()
    assert pp_rank == 0, "pipeline parallelism for mcore inference not supported for now."

    state_dict = load_rank0_state_dict(args)
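    # This rank-0 state dict is used only to recover the saved Megatron arguments and
    # checkpoint version; the actual weights are read per tensor-parallel rank below.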

    megatron_args = state_dict.get("args", None)
    if "checkpoint_version" in state_dict.keys():
        checkpoint_version = state_dict["checkpoint_version"]
    else:
        checkpoint_version = 0.0
    if megatron_args is None:
        raise ValueError(
            "Megatron-LM checkpoint does not contain arguments. This utility only supports Megatron-LM checkpoints"
            " containing all the megatron arguments, because it loads the model architecture and the tensor and"
            " pipeline model parallel sizes from the checkpoint instead of requiring the user to specify them"
            " manually. Please save the Megatron-LM checkpoint along with all the megatron arguments to use this"
            " utility."
        )

    output_state_dict = {}

    tp_size = megatron_args.tensor_model_parallel_size
    pp_size = megatron_args.pipeline_model_parallel_size
    assert pp_size == 1, "pipeline parallelism for mcore inference not supported for now."
    # The number of attention heads per tensor-parallel rank.
    heads = hf_config.num_attention_heads // tp_size
    # The hidden_size per head.
    hidden_size_per_head = hf_config.hidden_size // hf_config.num_attention_heads
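    # e.g. for Llama-2-7B (32 attention heads, hidden_size 4096) with tp_size=2:
    #      heads = 16 per rank, hidden_size_per_head = 128.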

    # The regex to extract layer names.
    layer_re = re.compile(r"decoder.layers\.(\d+)\.([a-z0-9_.]+)[\._]([a-z]+)")
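    # e.g. "decoder.layers.0.mlp.linear_fc2.weight" splits into layer index "0",
    #      op name "mlp.linear_fc2", and suffix "weight".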

    # Convert.
    print("Start to convert...")
    prefix_name = "model" if is_vllm_v2() else "model.model"

    # Embeddings
    print("Converting embeddings")
    tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, 0)
    # tp_state_dicts: list of state dict for each tp rank
    # tp_state_dicts[0]: a state dict for tp rank 0
    # |-keys: dict_keys(['args', 'checkpoint_version', 'iteration', 'model', ...])
    # |-tp_state_dicts[0]['model']
    #    |-keys: ['embedding.word_embeddings.weight',
    #             'decoder.layers.0.self_attention.core_attention.fused_attention._extra_state',
    #             'decoder.layers.0.self_attention.linear_proj.weight',
    #             'decoder.layers.0.self_attention.linear_proj._extra_state',
    #             'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight',
    #             'decoder.layers.0.self_attention.linear_qkv.weight',
    #             'decoder.layers.0.self_attention.linear_qkv._extra_state',
    #             'decoder.layers.0.mlp.linear_fc1.layer_norm_weight',
    #             'decoder.layers.0.mlp.linear_fc1.weight',
    #             'decoder.layers.0.mlp.linear_fc1._extra_state',
    #             'decoder.layers.0.mlp.linear_fc2.weight',
    #             'decoder.layers.0.mlp.linear_fc2._extra_state',
    #             ...
    #             'decoder.final_layernorm.weight',
    #             'output_layer.weight',
    #             'output_layer._extra_state']
    # Convert and store the position embeddings.
    position_embeddings = tp_state_dicts[0]['model'].get("embedding.position_embeddings.weight", None)
    if position_embeddings is not None:
        output_state_dict["transformer.position_embeddings.weight"] = position_embeddings.to(hf_config.torch_dtype)

    # Convert and store the word embeddings.
    word_embeddings = tp_state_dicts[tp_rank]['model'].get("embedding.word_embeddings.weight", None)
    word_embeddings = word_embeddings.to(hf_config.torch_dtype)
    output_state_dict[f"{prefix_name}.embed_tokens.weight"] = word_embeddings
    # Reset the vocab size
    hf_config.vocab_size = word_embeddings.shape[0]
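    # Take the vocab size from the embedding shard: Megatron may pad the vocabulary,
    # so the checkpoint shape is authoritative rather than the original HF config.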

    # Transformer Layers
    print("Converting transformer layers")

    for key, val in tp_state_dicts[tp_rank]['model'].items():
        if val is None:
            assert 'extra_state' in key, "weight/bias shouldn't be None except for extra_state in mcore"
            continue
        if "_extra_state" in key:
            continue

        # Match the name
        layer_match_res = layer_re.match(key)
        # Skip if that's not a layer
        if layer_match_res is None:
            continue
        # The index of the layer
        layer_idx = int(layer_match_res.group(1))
        # The name of the operation.
        op_name = layer_match_res.group(2)
        # Is it a weight or a bias?
        weight_or_bias = layer_match_res.group(3)
        # The name of the layer
        layer_name = f"{prefix_name}.layers.{layer_idx}"
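        # e.g. layer_idx=0 maps to layer_name "model.layers.0" under vLLM v2.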

        params = val.to(hf_config.torch_dtype)

        # For layernorm(s), simply store the layer norm.
        if op_name.endswith("layer_norm") and weight_or_bias == 'weight':
            if op_name == "self_attention.linear_qkv.layer_norm":
                ln_name = "input_layernorm"
            elif op_name == "mlp.linear_fc1.layer_norm":
                ln_name = "post_attention_layernorm"
            else:
                assert False, f"Unrecognized op_name {op_name} for layer norm"
            output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = params

        # Fix the ordering of the fused QKV weight.
        elif op_name == "self_attention.linear_qkv" and weight_or_bias == 'weight':
            input_shape = params.size()
            shape = (heads, hidden_size_per_head, 3) + input_shape[1:]
            division = reduce(operator.mul, shape, 1)
            num_elements = params.numel()
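            # `division` is the element count a fused MHA QKV weight would have on this
            # rank (heads * hidden_size_per_head * 3 * hidden_size); a mismatch means GQA.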
            if num_elements != division:
                # GQA models do not need the QKV ordering fix; store the weight as-is.
                output_state_dict[layer_name + ".self_attn.qkv_proj.weight"] = params
            else:
                out_val = fix_qwen_query_key_value_ordering(
                    params, checkpoint_version, 3, heads, hidden_size_per_head
                )
                # Reorder Megatron's per-head-interleaved QKV layout into the contiguous
                # Q/K/V blocks expected by vLLM's qkv_proj.
                out_val = out_val.contiguous()
                # Store.
                output_state_dict[layer_name + ".self_attn.qkv_proj.weight"] = out_val

        # Apply the same ordering fix to the fused QKV bias.
        elif op_name == "self_attention.linear_qkv" and weight_or_bias == "bias":
            out_val = fix_qwen_query_key_value_ordering(
                params, checkpoint_version, 3, heads, hidden_size_per_head
            )
            # Store. No change of shape.
            output_state_dict[layer_name + ".self_attn.qkv_proj.bias"] = out_val

        # Map the remaining weights to their transformers parameter names.
        elif weight_or_bias == "weight":
            out_name = mcore_to_transformers[op_name]
            output_state_dict[layer_name + out_name + "weight"] = params

        # Ignore all other biases.
        elif weight_or_bias == "bias":
            pass

        # Copy the Rotary Embedding
        else:
            out_name = mcore_to_transformers[op_name]
            output_state_dict[layer_name + out_name] = params

    if hf_config.num_hidden_layers != (layer_idx + 1):
        raise ValueError(f"Expected {hf_config.num_hidden_layers} layers but found {layer_idx + 1}")

    # The final layernorm.
    print("Converting final layernorm")
    final_norm_weight = tp_state_dicts[0]['model'].get("decoder.final_layernorm.weight", None)
    output_state_dict[f"{prefix_name}.norm.weight"] = final_norm_weight.to(hf_config.torch_dtype)

    # The LM head: the weight matrix projecting hidden states to vocabulary logits.
    print("Converting LM head")
    params = tp_state_dicts[tp_rank]['model'].get('output_layer.weight', None)
    output_state_dict["lm_head.weight" if is_vllm_v2() else "model.lm_head.weight"] = params.to(hf_config.torch_dtype)

    # It should be done!
    print("Conversion from Megatron-Core to Transformers is done!")

    return output_state_dict
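
Illustrative usage (a minimal sketch, not taken from the repository): the converter is
called once per tensor-parallel rank and returns a dict mapping vLLM parameter names to
that rank's weight shards. Assuming the caller already holds a module whose parameter
names match the returned keys (here called `vllm_model`, a hypothetical name), the result
could be loaded roughly like this:

    # Hypothetical caller-side code; `args`, `hf_config`, and `vllm_model` are assumed
    # to be set up elsewhere by the runtime.
    state_dict = convert_llama_state_dict_from_mcore_to_vllm(args, hf_config)
    missing, unexpected = vllm_model.load_state_dict(state_dict, strict=False)
    assert not unexpected, f"unexpected keys in converted state dict: {unexpected}"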