def split_and_save_weight()

in maga_transformer/utils/smooth_quant_convert/llama/convert.py [0:0]


def split_and_save_weight(i, saved_dir, factor, key, val, act_range, config):
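    """Convert one checkpoint tensor `val` (named `key`) for tensor-parallel
    rank `i` out of `factor` ranks and return the converted tensors as a dict.

    `act_range` carries the activation calibration statistics used to derive
    the SmoothQuant int8 scales; `config` holds conversion options such as
    `int8_outputs`, `multi_query_mode` and `local_dim`.
    """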
    # The incoming saved_dir argument is ignored; converted tensors are
    # collected into this dict and returned to the caller.
    saved_dir = {}
    # `factor` is the number of ranks across which the distributed GEMMs are
    # split. With Tensor Parallelism, each rank/GPU works on
    # hidden_dim // factor channels.

    int8_outputs = config.get("int8_outputs", None)
    multi_query_mode = config.get("multi_query_mode", False)
    local_dim = config.get("local_dim", None)

    save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only"

    if "input_layernorm.weight" in key or "input_layernorm.bias" in key or \
        "attention.dense.bias" in key or "post_layernorm.weight" in key or \
        "post_attention_layernorm.bias" in key or "mlp.dense_4h_to_h.bias" in key or \
        "final_layernorm.weight" in key or "final_layernorm.bias" in key:

        # these weights are not tensor-parallel split, so only rank 0 saves them
        if i == 0:
            save_val(val, saved_dir, key)

    elif "attention.dense.weight" in key or "mlp.proj.weight" in key:
        split_dim = 0
        split_vals = np.split(val, factor, axis=split_dim)
        # save_split(split_vals, saved_dir, key, i, factor)
        if act_range is not None and int8_outputs == "all":
            base_key = key.replace(".weight", "")
            vals_i8 = generate_int8(val, act_range)
            write_int8(vals_i8, saved_dir, base_key, split_dim, i, factor)

    elif "mlp.fc.weight" in key or "mlp.gate.weight" in key:
        split_dim = -1
        split_vals = np.split(val, factor, axis=split_dim)
        # save_split(split_vals, saved_dir, key, i, factor)
        if act_range is not None and int8_outputs == "all":
            base_key = key.replace(".weight", "")
            vals_i8 = generate_int8(val, act_range)
            write_int8(vals_i8, saved_dir, base_key, split_dim, i, factor)

    elif "attention.query_key_value.weight" in key:
        hidden_dim = val.shape[0]
        if local_dim is None:
            local_dim = val.shape[-1] // 3
        if multi_query_mode:
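            # Multi-query layout: the fused weight is [Q | K | V] with K and V
            # each `head_size` columns wide. Q, K and V are split across ranks
            # separately and re-concatenated so every rank keeps a [q | k | v]
            # slice.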
            head_size = (val.shape[-1] - local_dim) // 2
            val = val.reshape(hidden_dim, local_dim + 2 * head_size)
            w_q, w_k, w_v = np.split(val, [local_dim, local_dim + head_size],
                                     axis=-1)
            w_q_split = np.split(w_q, factor, axis=-1)
            w_k_split = np.split(w_k, factor, axis=-1)
            w_v_split = np.split(w_v, factor, axis=-1)
            split_vals = [
                np.concatenate((w_q_split[ii], w_k_split[ii], w_v_split[ii]),
                               axis=-1) for ii in range(factor)
            ]
            split_dim = -1
        else:
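            # Standard multi-head layout: reshape to (hidden_dim, 3, local_dim)
            # so the Q, K and V blocks are split jointly along the last axis.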
            val = val.reshape(hidden_dim, 3, local_dim)
            split_dim = -1
            split_vals = np.split(val, factor, axis=split_dim)
        # save_split(split_vals, saved_dir, key, i, factor)
        if save_int8:
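            # The QKV int8 tensors are written for both "all" and
            # "kv_cache_only", since the KV-cache quantization scales are
            # derived from this weight's activation ranges.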
            base_key = key.replace(".weight", "")
            vals_i8 = generate_int8(val,
                                    act_range,
                                    is_qkv=True,
                                    multi_query_mode=multi_query_mode)
            write_int8(vals_i8,
                       saved_dir,
                       base_key,
                       split_dim,
                       i,
                       factor,
                       is_qkv=True,
                       multi_query_mode=multi_query_mode)
    elif "attention.dense.smoother" in key or "mlp.proj.smoother" in key:
        split_vals = np.split(val, factor, axis=0)
        save_split(split_vals, saved_dir, key, i, factor)

    else:
        print(f"[WARNING] {key} not handled by converter")
    return saved_dir
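
A minimal, hedged usage sketch follows; the `weights`, `act_ranges` and `tp_size` names are illustrative assumptions and do not come from convert.py. An empty dict is passed for `saved_dir` because the function re-initializes it and returns the populated result.

# Illustrative driver sketch; `weights` (checkpoint name -> numpy array),
# `act_ranges` (checkpoint name -> calibration stats) and `tp_size` are assumed
# to be provided by the surrounding conversion script.
config = {"int8_outputs": "all", "multi_query_mode": False, "local_dim": None}
converted_per_rank = []
for rank in range(tp_size):
    rank_out = {}
    for key, val in weights.items():
        rank_out.update(
            split_and_save_weight(rank, {}, tp_size, key, val,
                                  act_ranges.get(key), config))
    converted_per_rank.append(rank_out)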