def split_and_save_weight()

in maga_transformer/utils/smooth_quant_convert/qwen/convert.py


def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals,
                          storage_type, act_range, config):
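    """Convert one checkpoint tensor for tensor-parallel rank `tp_rank`.

    Depending on the key, the tensor is replicated, split along rows or
    columns, or handled as a fused QKV tensor; the resulting values and, when
    `int8_outputs` is enabled, SmoothQuant int8 scaling factors are recorded
    into the returned dict via the save/write helpers.
    """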
    # The incoming `saved_dir` argument is shadowed here: converted tensors
    # are collected into a fresh dict that is returned to the caller.
    saved_dir = {}
    use_attention_nemo_shape = config.get("use_attention_nemo_shape", False)
    split_gated_activation = config.get("split_gated_activation", False)
    num_attention_heads = config.get("num_attention_heads", 0)
    tp_size = config.get("tp_size", 1)
    int8_outputs = config.get("int8_outputs", None)
    multi_query_mode = config.get("multi_query_mode", False)
    local_dim = config.get("local_dim", None)

    save_int8 = int8_outputs in ("all", "kv_cache_only")

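    # Smoother tensors are converted to numpy as-is; everything else is
    # optionally transposed, adjusted for layernorm-1p, cast to storage_type,
    # and converted to numpy.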
    if not key.endswith(".smoother"):
        if not isinstance(vals, list):
            vals = [vals]

        if config.get("transpose_weights", False) and vals[0].ndim == 2:
            vals = [val.T for val in vals]
        if "layernorm.weight" in key and config.get("apply_layernorm_1p",
                                                    False):
            vals = [val + 1.0 for val in vals]
        vals = [torch_to_numpy(val.cpu().to(storage_type)) for val in vals]
    else:
        vals = torch_to_numpy(vals.cpu())

    if "ln_1.weight" in key or "ln_1.bias" in key or \
            "attn.c_attn.bias" in key or \
            "ln_2.weight" in key or "ln_2.bias" in key or \
            "mlp.c_proj.bias" in key or "ln_f.weight" in key:
        # "final_layernorm.weight" in key or "final_layernorm.bias" in key:

        # shared (replicated) weights: only rank 0 needs to convert and save them
        if tp_rank == 0:
            save_val(vals[0], saved_dir, key)

    elif "attn.c_proj.weight" in key or "mlp.c_proj.weight" in key:
        cat_dim = 0
        val = np.concatenate(vals, axis=cat_dim)
        split_vals = np.split(val, split_factor, axis=cat_dim)
        # save_split(split_vals, saved_dir, key, tp_rank, split_factor)
        if act_range is not None and int8_outputs == "all":
            base_key = key.replace(".weight", "")
            vals_i8 = generate_int8(val,
                                    act_range,
                                    multi_query_mode=multi_query_mode)
            write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank,
                       split_factor)

    elif "mlp.w1.weight" in key or "mlp.w2.weight" in key or "mlp.w1.bias" in key or "mlp.w2.bias" in key:
        if split_gated_activation:
            splits = [np.split(val, 2, axis=-1) for val in vals]
            vals, gates = list(zip(*splits))
        cat_dim = -1
        val = np.concatenate(vals, axis=cat_dim)
        split_vals = np.split(val, split_factor, axis=cat_dim)
        # save_split(split_vals, saved_dir, key, tp_rank, split_factor)
        if act_range is not None and int8_outputs == "all":
            base_key = key.replace(".weight", "")
            vals_i8 = generate_int8(val,
                                    act_range,
                                    multi_query_mode=multi_query_mode)
            write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank,
                       split_factor)

        if split_gated_activation:
            assert not save_int8
            prefix, dot, suffix = key.rpartition(".")
            key = prefix + ".gate" + dot + suffix

            gate = np.concatenate(gates, axis=cat_dim)
            split_vals = np.split(gate, split_factor, axis=cat_dim)
            save_split(split_vals, saved_dir, key, tp_rank, split_factor)

    elif "attn.c_attn.bias" in key:
        if local_dim is None:
            local_dim = vals[0].shape[-1] // 3

        if multi_query_mode:
            val = vals[0]
            # out_feature = local_dim + 2 * head_size; assumes local_dim equals hidden_dim
            b_q, b_kv = np.split(val, [local_dim], axis=-1)
            b_q_split = np.split(b_q, split_factor, axis=-1)
            split_vals = [np.concatenate((i, b_kv), axis=-1) for i in b_q_split]
        else:
            if use_attention_nemo_shape:
                head_num = num_attention_heads // tp_size
                size_per_head = local_dim // num_attention_heads
                nemo_shape = (head_num, 3, size_per_head)
                vals = [val.reshape(nemo_shape) for val in vals]
                vals = [val.transpose(1, 0, 2) for val in vals]

            vals = [val.reshape(3, local_dim) for val in vals]
            val = np.concatenate(vals, axis=-1)
            split_vals = np.split(val, split_factor, axis=-1)
        save_split(split_vals, saved_dir, key, tp_rank, split_factor)

    elif "attn.c_attn.weight" in key:
        hidden_dim = vals[0].shape[0]
        if local_dim is None:
            local_dim = vals[0].shape[-1] // 3
        if multi_query_mode:
            val = vals[0]
            # out_feature = local_dim + 2 * head_size; assumes local_dim equals hidden_dim
            head_size = (val.shape[-1] - local_dim) // 2
            val = val.reshape(hidden_dim, local_dim + 2 * head_size)
            w_q, w_kv = np.split(val, [local_dim], axis=-1)
            w_q_split = np.split(w_q, split_factor, axis=-1)
            split_vals = [np.concatenate((i, w_kv), axis=-1) for i in w_q_split]
        else:
            if use_attention_nemo_shape:
                head_num = num_attention_heads // tp_size
                size_per_head = hidden_dim // num_attention_heads
                vals = [
                    val.reshape(hidden_dim, head_num, 3, size_per_head)
                    for val in vals
                ]
                vals = [val.transpose(0, 2, 1, 3) for val in vals]

            vals = [val.reshape(hidden_dim, 3, local_dim) for val in vals]
            cat_dim = -1
            val = np.concatenate(vals, axis=cat_dim)
            split_vals = np.split(val, split_factor, axis=cat_dim)
        # save_split(split_vals, saved_dir, key, tp_rank, split_factor)
        if save_int8:
            base_key = key.replace(".weight", "")
            vals_i8 = generate_int8(val,
                                    act_range,
                                    is_qkv=True,
                                    multi_query_mode=multi_query_mode)
            write_int8(vals_i8,
                       saved_dir,
                       base_key,
                       cat_dim,
                       tp_rank,
                       split_factor,
                       kv_cache_only=int8_outputs == "kv_cache_only")

    elif "attn.c_proj.smoother" in key or "mlp.c_proj.smoother" in key:
        split_vals = np.split(vals, split_factor, axis=0)
        save_split(split_vals, saved_dir, key, tp_rank, split_factor)
    else:
        print(f"[WARNING] {key} not handled by converter")
    return saved_dir
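

A minimal driver sketch, assuming this module is importable and that the state
dict and activation ranges come from an external checkpoint loader and
SmoothQuant calibration step. The inputs, shapes, and config values below are
illustrative placeholders, not part of this file, and passing tp_size as the
split_factor is also an assumption:

import torch

from maga_transformer.utils.smooth_quant_convert.qwen.convert import \
    split_and_save_weight

# Illustrative inputs; real values come from the Qwen checkpoint and the
# calibration pass.
state_dict = {
    "transformer.h.0.ln_1.weight": torch.ones(4096),
    "transformer.h.0.attn.c_attn.weight": torch.randn(4096, 3 * 4096),
}
act_ranges = {}  # per-tensor activation ranges; empty -> no int8 scales emitted

config = {
    "num_attention_heads": 32,
    "tp_size": 2,
    "int8_outputs": None,      # set to "all" to also emit int8 scaling factors
    "transpose_weights": False,
}

for tp_rank in range(config["tp_size"]):
    rank_weights = {}
    for key, tensor in state_dict.items():
        # the saved_dir argument is ignored; the function builds a fresh dict
        out = split_and_save_weight(tp_rank, {}, config["tp_size"], key, tensor,
                                    torch.float16, act_ranges.get(key), config)
        rank_weights.update(out)
    # rank_weights now holds the converted tensors for this tp_rank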