def convert_checkpoint_from_megatron_to_transformers()

in toolkits/model_checkpoints_convertor/mistral/hf2mcore_mixtral.py [0:0]


def convert_checkpoint_from_megatron_to_transformers(args):
    """
    Convert an NVIDIA Megatron-LM checkpoint to a HuggingFace Transformers checkpoint. This handles Megatron
    checkpoints with different tensor parallelism and pipeline parallelism sizes, and saves the converted checkpoint
    into shards using HuggingFace Transformers' checkpoint sharding functionality. This greatly extends the
    functionality of `convert_megatron_gpt2_checkpoint.py`.

    Args:
        args (argparse.Namespace): the arguments to the script
    """
    os.makedirs(args.save_path, exist_ok=True)

    # Copy the config and tokenizer files to the output directory
    os.system("cp -rf " + args.load_path + "/*.json " + args.save_path)
    os.system("cp -rf " + args.load_path + "/tokenizer.model " + args.save_path)

    tracker_filepath = os.path.join(args.load_path, "latest_checkpointed_iteration.txt")
    with open(tracker_filepath, "r") as f:
        tag = f.readline().strip()

    args.load_path = os.path.join(args.load_path, tag)
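    # args.load_path now points at the latest checkpoint directory (the tag is typically an
    # iteration folder such as iter_0007000, or "release").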
    import glob
    if glob.glob(args.load_path + "/mp_rank*/distrib*"):
        user_input = input(
            "Optimizer states detected. Will remove distrib* files. yes (remove and continue) / no (stop the program): ")
        if user_input == 'yes':
            os.system("rm -rf " + args.load_path + "/mp_rank*/distrib*")
        else:
            raise RuntimeError("Optimizer states were not removed. Copy the checkpoint to another folder and re-run.")

    # params dtype
    if args.target_params_dtype == "fp16":
        dtype = torch.float16
    elif args.target_params_dtype == "bf16":
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    config = MixtralConfig()

    output_state_dict = {}

    # checkpoint_version = state_dict.get("checkpoint_version", 3.0)
    tp_size = args.target_tensor_model_parallel_size
    pp_size = args.target_pipeline_model_parallel_size
    ep_size = args.target_expert_model_parallel_size

    # The regex to extract layer names.
    layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
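    # e.g. "layers.0.self_attention.linear_qkv.weight" -> ("0", "self_attention.linear_qkv", "weight")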

    # Convert.
    print("Converting")

    # Embeddings
    print("Converting embeddings")
    tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, ep_size, 0)


    # Convert and store the word embeddings.
    word_embeddings = []
    word_embeddings_layernorm_weight = []
    word_embeddings_layernorm_bias = []

    for tp_rank in range(tp_size):
        embeddings = tp_state_dicts[tp_rank]["model"]["embedding.word_embeddings.weight"]
        word_embeddings.append(embeddings)

    word_embeddings = torch.cat(word_embeddings, dim=0)
    word_embeddings = word_embeddings.to(dtype)
    output_state_dict["model.embed_tokens.weight"] = word_embeddings.clone()
    # Reset the vocab size
    config.vocab_size = word_embeddings.shape[0]
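    # Note: Megatron may pad the embedding table (e.g. to a multiple of the TP size), so the
    # resulting vocab_size can be larger than the tokenizer vocabulary.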

    # Transformer Layers
    print("Converting transformer layers")
    # The number of heads.
    heads = config.num_attention_heads
    # The hidden_size per head.
    hidden_size_per_head = config.hidden_size // config.num_attention_heads
    num_layers = config.num_hidden_layers // pp_size

    hidden_size = config.hidden_size
    num_groups = config.num_key_value_heads
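    # With grouped-query attention, each of the num_groups KV heads serves heads // num_groups
    # query heads; this grouping is undone when splitting the fused QKV weight below.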

    for pp_rank in range(pp_size):
        print(f"Converting pipeline parallel rank {pp_rank}")
        tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, ep_size, pp_rank)

        # The transformer.

        path = 'model'

        # Extract the layers.
        for key, val in get_element_from_dict_by_path(tp_state_dicts[0], path).items():
            if key.endswith('_extra_state'):
                continue
            # deal with experts
            if 'linear_fc' in key:
                print(key)
                key_list = key.split('.')
                layer_id = int(key_list[2]) + pp_rank * num_layers
                expert_id = key_list[-3]
                dim = 1 if 'linear_fc2' in key else 0
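                # Re-assemble the expert weight from its TP shards: linear_fc2 is sharded along
                # its input dim (1), linear_fc1 along its output dim (0).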
                params = torch.cat(
                    [val]
                    + [
                        get_element_from_dict_by_path(tp_state_dicts[tp_rank], f"{path}")[key]
                        for tp_rank in range(1, tp_size)
                    ],
                    dim=dim,
                ).to(dtype)

                if 'linear_fc2' in key:
                    output_state_dict[
                        f'model.layers.{layer_id}.block_sparse_moe.experts.{expert_id}.w2.weight'] = params
                else:
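                    # linear_fc1 fuses the gate (w1) and up (w3) projections, so each TP shard
                    # is split in half to recover the two Mixtral matrices.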
                    params_split = [torch.chunk(i, 2, 0) for i in torch.chunk(params, tp_size, 0)]
                    output_state_dict[
                        f'model.layers.{layer_id}.block_sparse_moe.experts.{expert_id}.w1.weight'] = torch.cat(
                        [i[0] for i in params_split])
                    output_state_dict[
                        f'model.layers.{layer_id}.block_sparse_moe.experts.{expert_id}.w3.weight'] = torch.cat(
                        [i[1] for i in params_split])

                continue

            new_key = key.replace('decoder.', '')
            if 'layer_norm_weight' in new_key:
                new_key += '.weight'
            # Match the name.
            m = layer_re.match(new_key)
            # Stop if that's not a layer
            if m is None:
                continue

            # The index of the layer.
            layer_idx = int(m.group(1)) + pp_rank * num_layers
            # The name of the operation.
            op_name = m.group(2)
            # Is it a weight or a bias?
            weight_or_bias = m.group(3)

            # The name of the layer.
            layer_name = f"model.layers.{layer_idx}"

            print(layer_name, op_name, weight_or_bias)

            if op_name + "." + weight_or_bias not in tensor_parallel_params_mg:
                params = val.to(dtype)
            else:
                dim = 1 if op_name in column_split_tensor_parallel_params_mg else 0
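                # Concatenate TP shards: parameters listed in column_split_tensor_parallel_params_mg
                # are re-assembled along dim 1, all other TP-sharded weights along dim 0.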
                params = torch.cat(
                    [val]
                    + [
                        get_element_from_dict_by_path(tp_state_dicts[tp_rank], f"{path}")[key]
                        for tp_rank in range(1, tp_size)
                    ],
                    dim=dim,
                ).to(dtype)

            # For layernorm(s), simply store the layer norm.
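            # Mixtral uses RMSNorm, so only a weight is stored: keys ending in "layer_norm_weight"
            # map to input_layernorm, the remaining "*layernorm" keys to post_attention_layernorm.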
            if op_name.endswith("layer_norm_weight") or op_name.endswith("layernorm"):
                ln_name = "input_layernorm" if op_name.endswith("layer_norm_weight") else "post_attention_layernorm"
                output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = params.clone()
                continue

            # Transpose the QKV matrix.
            elif (
                    op_name == "attention.linear_qkv" or op_name == "self_attention.linear_qkv"
            ) and weight_or_bias == "weight":

                all_qkvs = [
                    i.reshape(
                        num_groups // tp_size,
                        heads // num_groups * hidden_size_per_head + 2 * hidden_size_per_head,
                        hidden_size,
                    )
                    for i in torch.chunk(params, tp_size, 0)
                ]
                split_size = heads // num_groups * hidden_size_per_head
                all_qs = torch.cat([i[:, :split_size, :].reshape(-1, hidden_size) for i in all_qkvs])
                all_kvs = torch.cat([i[:, split_size:, :].reshape(-1, hidden_size) for i in all_qkvs])
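                # Per TP rank, the fused QKV weight stores, for each KV group, the rows
                # [heads // num_groups Q heads | one K head | one V head]; the slicing above keeps
                # the first split_size rows of each group as Q and the remaining rows as K/V.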

                checkpoint_version = 3.0
                out_q = megatron_to_transformers_fix_query_key_value_ordering(
                    all_qs,
                    checkpoint_version,
                    1,
                    heads,
                    hidden_size_per_head,
                )

                out_kv = megatron_to_transformers_fix_query_key_value_ordering(
                    all_kvs,
                    checkpoint_version,
                    2,
                    num_groups,
                    hidden_size_per_head,
                )
                out_kv = torch.chunk(out_kv, 2)

                output_state_dict[layer_name + f".self_attn.q_proj.weight"] = out_q.clone()
                output_state_dict[layer_name + f".self_attn.k_proj.weight"] = out_kv[0].clone()
                output_state_dict[layer_name + f".self_attn.v_proj.weight"] = out_kv[1].clone()

            # Transpose the weights.
            elif weight_or_bias == "weight":
                out_name = megatron_to_transformers[op_name]
                output_state_dict[layer_name + '.' + out_name + '.' + "weight"] = params.clone()

    if config.num_hidden_layers != (layer_idx + 1):
        raise ValueError(f"Expected {config.num_hidden_layers} layers but found {layer_idx + 1}")

    # The final layernorm.
    print("Converting final layernorm")
    params = get_element_from_dict_by_path(tp_state_dicts[0], path)
    try:
        output_state_dict["model.norm.weight"] = params["decoder.final_layernorm.weight"].to(dtype).clone()
    except KeyError:
        output_state_dict["model.norm.weight"] = params["decoder.final_norm.weight"].to(dtype).clone()

    # The LM head: Transformers expects this matrix in the same layout as the word embeddings.
    print("Converting LM head")
    params = torch.cat([
        get_element_from_dict_by_path(tp_state_dicts[i]['model'], 'output_layer.weight')
        for i in range(tp_size)]
    )
    output_state_dict["lm_head.weight"] = params.to(dtype).clone()

    # It should be done!
    print("Conversion from Megatron-LM to Transformers is done!")

    # Print the structure of converted state dict.
    if args.print_checkpoint_structure:
        recursive_print(None, output_state_dict)

    max_shard_size = int(args.max_shard_size) if args.max_shard_size.isdigit() else args.max_shard_size
    args.save_safetensors = False
    save_hfmodel(args, output_state_dict, max_shard_size)
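
A minimal driver sketch, assuming `args` is an `argparse.Namespace` exposing the attributes read
above (helpers such as `get_megatron_sharded_states` and `save_hfmodel` may read additional ones);
paths and values here are illustrative placeholders:

    import argparse

    args = argparse.Namespace(
        load_path="/path/to/megatron_checkpoint",   # must contain latest_checkpointed_iteration.txt
        save_path="/path/to/hf_checkpoint",
        target_params_dtype="bf16",                 # "fp16" / "bf16"; anything else falls back to fp32
        target_tensor_model_parallel_size=1,
        target_pipeline_model_parallel_size=1,
        target_expert_model_parallel_size=1,
        print_checkpoint_structure=True,
        max_shard_size="10GB",                      # forwarded to save_hfmodel
    )
    convert_checkpoint_from_megatron_to_transformers(args)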