def convert_checkpoint_from_transformers_to_megatron(args)

in toolkits/model_checkpoints_convertor/qwen/hf2mcore_qwen1.5_dense_gqa.py


import os
import re
import types
from collections import OrderedDict

import torch
from transformers import AutoConfig, AutoModelForCausalLM

# NOTE: the imports above are inferred from usage in this function; the helpers
# referenced below (get_element_from_dict_by_path, internal_to_output_mapping,
# tensor_parallel_params, column_split_tensor_parallel_params,
# transformers_to_megatron_fix_query_key_value_ordering, clone_state_dict)
# are defined elsewhere in hf2mcore_qwen1.5_dense_gqa.py.


def convert_checkpoint_from_transformers_to_megatron(args):
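    """Convert a HuggingFace Qwen1.5 dense (GQA) checkpoint to Megatron-Core.

    Copies the config/tokenizer files, remaps HF parameter names through an
    intermediate naming scheme, re-interleaves Q/K/V for grouped-query
    attention, and writes one shard per (tensor, pipeline) parallel rank
    under ``save_path/release/``.
    """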

    os.makedirs(args.save_path, exist_ok=True)

    # Saving config and tokenizer files
    os.system("cp -rf " + args.load_path + "/*.json " + args.save_path)
    os.system("cp -rf " + args.load_path + "/tokeniz* " + args.save_path)

    # Saving the tracker file; the literal "release" tells Megatron to load
    # from the release/ directory created below
    tracker_filepath = os.path.join(args.save_path, "latest_checkpointed_iteration.txt")
    with open(tracker_filepath, "w") as f:
        f.write("release")

    # create `release` dir in args.save_path
    release_dir = os.path.join(args.save_path, "release")
    os.makedirs(release_dir, exist_ok=True)
    config = AutoConfig.from_pretrained(args.load_path)
    # megatron args
    megatron_args = {
        "orig_vocab_size": config.vocab_size,
        "hidden_size": config.hidden_size,
        "num_layers": config.num_hidden_layers,
        "num_attention_heads": config.num_attention_heads,
        "tensor_model_parallel_size": args.target_tensor_model_parallel_size,
        "pipeline_model_parallel_size": args.target_pipeline_model_parallel_size
    }

    margs = types.SimpleNamespace()
    for k, v in megatron_args.items():
        setattr(margs, k, v)
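    # `margs` is embedded in every saved shard under the "args" key (see the
    # save loop at the bottom of this function).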

    state_dict = AutoModelForCausalLM.from_pretrained(args.load_path).state_dict()
    internal_state_dict = {}
    for layer_id in range(config.num_hidden_layers):
        hf_prefix = f'model.layers.{layer_id}'
        out_prefix = f'transformer.layers.{layer_id}'

        q_weight = state_dict[f'{hf_prefix}.self_attn.q_proj.weight']
        k_weight = state_dict[f'{hf_prefix}.self_attn.k_proj.weight']
        v_weight = state_dict[f'{hf_prefix}.self_attn.v_proj.weight']

        q_bias = state_dict[f'{hf_prefix}.self_attn.q_proj.bias']
        k_bias = state_dict[f'{hf_prefix}.self_attn.k_proj.bias']
        v_bias = state_dict[f'{hf_prefix}.self_attn.v_proj.bias']

        # Qwen1.5 uses grouped-query attention: Q has num_attention_heads
        # heads while K/V carry only num_key_value_heads heads, so Q is kept
        # separate and K/V are concatenated for later regrouping.
        internal_state_dict[f'{out_prefix}.self_attn.query.weight'] = q_weight
        internal_state_dict[f'{out_prefix}.self_attn.key_value.weight'] = \
            torch.cat((k_weight, v_weight))
        internal_state_dict[f'{out_prefix}.self_attn.query.bias'] = q_bias
        internal_state_dict[f'{out_prefix}.self_attn.key_value.bias'] = \
            torch.cat((k_bias, v_bias))

        internal_state_dict[f'{out_prefix}.self_attn.dense.weight'] = \
            state_dict[f'{hf_prefix}.self_attn.o_proj.weight']

        # SwiGLU MLP: gate_proj and up_proj stay separate here and are fused
        # into a single linear_fc1 further down.
        internal_state_dict[f'{out_prefix}.mlp.dense_h_to_4h_1.weight'] = \
            state_dict[f'{hf_prefix}.mlp.gate_proj.weight']
        internal_state_dict[f'{out_prefix}.mlp.dense_h_to_4h_2.weight'] = \
            state_dict[f'{hf_prefix}.mlp.up_proj.weight']
        internal_state_dict[f'{out_prefix}.mlp.dense_4h_to_h.weight'] = \
            state_dict[f'{hf_prefix}.mlp.down_proj.weight']

        internal_state_dict[f'{out_prefix}.input_layernorm.weight'] = \
            state_dict[f'{hf_prefix}.input_layernorm.weight']
        internal_state_dict[f'{out_prefix}.post_attention_layernorm.weight'] = \
            state_dict[f'{hf_prefix}.post_attention_layernorm.weight']

    internal_state_dict["transformer.word_embeddings.weight"] = state_dict['model.embed_tokens.weight']
    internal_state_dict["transformer.final_layernorm.weight"] = state_dict['model.norm.weight']
    internal_state_dict["transformer.lm_head.weight"] = state_dict['lm_head.weight']

    output_state_dict = []
    for i in range(args.target_tensor_model_parallel_size):
        output_state_dict.append(OrderedDict())

    num_query_group = config.num_key_value_heads
    output_group_state_dict = []
    for i in range(num_query_group):
        output_group_state_dict.append({})

    if args.target_params_dtype == "fp16":
        dtype = torch.float16
    elif args.target_params_dtype == "bf16":
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
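    # Every converted parameter below is cast to this target dtype.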

    # Embedding layer
    print("converting embedding layer")
    word_embedding = internal_state_dict["transformer.word_embeddings.weight"].to(dtype)
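    # Shard the embedding along the vocab dimension (dim 0), one chunk per TP rank.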
    out_word_embed = torch.chunk(word_embedding, args.target_tensor_model_parallel_size, dim=0)
    for i in range(args.target_tensor_model_parallel_size):
        word_emb_dict = get_element_from_dict_by_path(output_state_dict[i], "model")
        word_emb_dict["embedding.word_embeddings.weight"] = out_word_embed[i]

    print("converting output layer")
    lm_head = internal_state_dict["transformer.lm_head.weight"].to(dtype)
    out_lm_head = torch.chunk(lm_head, args.target_tensor_model_parallel_size, dim=0)
    for i in range(args.target_tensor_model_parallel_size):
        lm_head_dict = get_element_from_dict_by_path(output_state_dict[i], "model")
        lm_head_dict["output_layer.weight"] = out_lm_head[i]

    print("converting transformer layers")
    if config.num_hidden_layers % args.target_pipeline_model_parallel_size != 0:
        raise ValueError(
            f"Number of layers ({config.num_hidden_layers}) must be divisible by the"
            f" pipeline-model-parallel size ({args.target_pipeline_model_parallel_size})"
        )

    # Each pipeline stage holds an equal contiguous block of layers.
    num_layers = config.num_hidden_layers // args.target_pipeline_model_parallel_size

    layer_re = re.compile(r"transformer\.layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")

    hidden_size = config.hidden_size
    num_groups = config.num_key_value_heads
    num_heads = config.num_attention_heads
    hidden_size_per_head = config.hidden_size // config.num_attention_heads

    for pp_rank in range(args.target_pipeline_model_parallel_size):
        layer_offset = pp_rank * num_layers
        if pp_rank > 0:
            output_state_dict = []
            for i in range(args.target_tensor_model_parallel_size):
                output_state_dict.append({})

            output_group_state_dict = []
            for i in range(num_query_group):
                output_group_state_dict.append({})

        for layer in range(num_layers):
            pp_layer_id = layer + layer_offset
            layers_to_copy = [
                layer_name
                for layer_name in internal_state_dict.keys()
                if layer_name.startswith(f"transformer.layers.{pp_layer_id}.")
            ]

            for layer_name in layers_to_copy:
                m = layer_re.match(layer_name)
                # Stop if that's not a layer
                if m is None:
                    break

                # The index of the layer.
                _ = int(m.group(1))
                # The name of the operation.
                op_name = m.group(2)
                # Is it a weight or a bias?
                weight_or_bias = m.group(3)

                params = internal_state_dict[layer_name].to(dtype)
                # handle layernorm
                if op_name.startswith("input_layernorm") and weight_or_bias == "weight":
                    out_name = "self_attention.linear_qkv"
                    layer_name = f"layers.{layer}.{out_name}.layer_norm_weight"

                elif op_name.startswith("post_attention_layernorm") and weight_or_bias == "weight":
                    out_name = "pre_mlp_layernorm"
                    layer_name = f"layers.{layer}.{out_name}.{weight_or_bias}"

                # handle attention K, V, Q weights
                elif op_name.startswith("self_attn.query"):
                    # Reorder the rows of Q from the HF layout to Megatron's
                    # per-head layout (checkpoint version 3.0; one split over
                    # num_heads query heads).
                    params = transformers_to_megatron_fix_query_key_value_ordering(
                        params,
                        3.0,
                        1,
                        num_heads,
                        hidden_size_per_head,
                    )
                    layer_name = f"layers.{layer}.{op_name}.{weight_or_bias}"

                elif op_name.startswith("self_attn.key_value"):
                    # Same reordering for the fused K/V tensor (two splits,
                    # K and V, over num_groups key/value heads).
                    params = transformers_to_megatron_fix_query_key_value_ordering(
                        params,
                        3.0,
                        2,
                        num_groups,
                        hidden_size_per_head,
                    )
                    layer_name = f"layers.{layer}.{op_name}.{weight_or_bias}"

                # handle attention and mlp weights
                elif weight_or_bias == "weight":
                    out_name = internal_to_output_mapping.get(op_name, None)
                    if out_name is None:
                        continue
                    layer_name = f"layers.{layer}.{out_name}.{weight_or_bias}"

                # skip
                else:
                    continue

                param_key = f"{op_name}.{weight_or_bias}"
                is_tensor_parallel = param_key in tensor_parallel_params
                if is_tensor_parallel:
                    # Column-split params shard along dim 1, the rest along dim 0.
                    dim = 1 if param_key in column_split_tensor_parallel_params else 0
                    params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=dim)

                for i in range(args.target_tensor_model_parallel_size):
                    params_dict = get_element_from_dict_by_path(output_state_dict[i], "model")
                    params_dict["decoder." + layer_name] = (
                        params[i].clone() if is_tensor_parallel else params.clone()
                    )

            for i in range(args.target_tensor_model_parallel_size):
                params_dict = get_element_from_dict_by_path(output_state_dict[i], "model")

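                # Fuse the gate_proj/up_proj shards back into Megatron-Core's
                # single linear_fc1 (SwiGLU) weight.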
                dense_h_to_4h_1_name = f'decoder.layers.{layer}.mlp.linear_fc1_1.weight'
                dense_h_to_4h_1_weight = params_dict[dense_h_to_4h_1_name]
                del params_dict[dense_h_to_4h_1_name]

                dense_h_to_4h_2_name = f'decoder.layers.{layer}.mlp.linear_fc1_2.weight'
                dense_h_to_4h_2_weight = params_dict[dense_h_to_4h_2_name]
                del params_dict[dense_h_to_4h_2_name]

                dense_h_to_4h_name = f'decoder.layers.{layer}.mlp.linear_fc1.weight'
                params_dict[dense_h_to_4h_name] = \
                    torch.cat([dense_h_to_4h_1_weight, dense_h_to_4h_2_weight], dim=0)

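                # Pop the per-rank Q and K/V shards so they can be
                # re-interleaved into the fused linear_qkv layout below.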
                self_attn_query_weight_name = f"decoder.layers.{layer}.self_attn.query.weight"
                query_weight = params_dict[self_attn_query_weight_name]
                self_attn_query_bias_name = f"decoder.layers.{layer}.self_attn.query.bias"
                query_bias = params_dict[self_attn_query_bias_name]
                del params_dict[self_attn_query_weight_name]
                del params_dict[self_attn_query_bias_name]
                self_attn_kv_weight_name = f"decoder.layers.{layer}.self_attn.key_value.weight"
                kv_weight = params_dict[self_attn_kv_weight_name]
                self_attn_kv_bias_name = f"decoder.layers.{layer}.self_attn.key_value.bias"
                kv_bias = params_dict[self_attn_kv_bias_name]
                del params_dict[self_attn_kv_weight_name]
                del params_dict[self_attn_kv_bias_name]

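                # Regroup per KV group: each of the num_groups groups carries
                # num_heads // num_groups query heads plus one K and one V
                # head, each of size hidden_size_per_head.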
                group_query_weight = query_weight.view(num_groups // args.target_tensor_model_parallel_size,
                                                       num_heads // num_groups * hidden_size_per_head, hidden_size)

                group_query_bias = query_bias.view(num_groups // args.target_tensor_model_parallel_size, -1)


                group_kv_weight = kv_weight.view(num_groups // args.target_tensor_model_parallel_size,
                                                 2 * hidden_size_per_head, hidden_size)

                group_kv_bias = kv_bias.view(num_groups // args.target_tensor_model_parallel_size, -1)

                group_qkv_weight = torch.cat([group_query_weight, group_kv_weight], dim=1)
                params_dict["decoder." + f"layers.{layer}.self_attention.linear_qkv.weight"] = \
                    group_qkv_weight.view(-1, hidden_size)

                group_qkv_bias = torch.cat([group_query_bias, group_kv_bias], dim=1)
                params_dict["decoder." + f"layers.{layer}.self_attention.linear_qkv.bias"] = \
                    group_qkv_bias.view(-1)
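                # The result matches Megatron-Core's fused linear_qkv layout:
                # per group [q_0 .. q_{h/g-1}, k, v], flattened to
                # ((num_heads + 2 * num_groups) * hidden_size_per_head // tp_size,
                # hidden_size). Illustrative shapes (not Qwen1.5's real config):
                # with num_heads=8, num_groups=2, head_dim=4, tp=1 the Q weight
                # (32, h) views to (2, 16, h), the fused KV weight (16, h)
                # views to (2, 8, h); cat gives (2, 24, h) -> (48, h).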


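        # The last pipeline stage additionally holds the final layernorm, the
        # LM head, and a copy of the word embeddings.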
        if pp_rank == args.target_pipeline_model_parallel_size - 1:
            # handle final layernorm
            for weight_or_bias in ["weight"]:
                params = internal_state_dict[f"transformer.final_layernorm.{weight_or_bias}"].to(dtype)
                layer_name = "decoder." + f"final_layernorm.{weight_or_bias}"
                for i in range(args.target_tensor_model_parallel_size):
                    params_dict = get_element_from_dict_by_path(output_state_dict[i], "model")
                    params_dict[layer_name] = params.clone()

            # add the embedding
            for i in range(args.target_tensor_model_parallel_size):
                params_dict = get_element_from_dict_by_path(output_state_dict[i], "model")
                params_dict["embedding.word_embeddings.weight"] = out_word_embed[i].clone()

            # add the LM head
            for i in range(args.target_tensor_model_parallel_size):
                params_dict = get_element_from_dict_by_path(output_state_dict[i], "model")
                params_dict["output_layer.weight"] = out_lm_head[i].clone()


        # Save one shard per (tp_rank, pp_rank) under release/, mirroring
        # Megatron's on-disk checkpoint layout.
        for tp_rank in range(args.target_tensor_model_parallel_size):
            output_state_dict[tp_rank]["checkpoint_version"] = 3.0
            output_state_dict[tp_rank]["args"] = margs
            checkpoint_dir = (
                f"mp_rank_{tp_rank:02d}"
                if args.target_pipeline_model_parallel_size == 1
                else f"mp_rank_{tp_rank:02d}_{pp_rank:03d}"
            )

            checkpoint_name = "model_optim_rng.pt"
            checkpoint_dir = os.path.join(release_dir, checkpoint_dir)
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
            torch.save(clone_state_dict(output_state_dict[tp_rank]), checkpoint_path)
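
# Usage sketch (illustrative only, not part of the original file): the real
# script builds `args` via argparse; the attribute names below mirror the ones
# accessed above, while paths and sizes are placeholders.
#
#   args = types.SimpleNamespace(
#       load_path="/path/to/Qwen1.5-hf",
#       save_path="/path/to/mcore-out",
#       target_tensor_model_parallel_size=2,
#       target_pipeline_model_parallel_size=1,
#       target_params_dtype="bf16",
#   )
#   convert_checkpoint_from_transformers_to_megatron(args)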