maga_transformer/utils/smooth_quant_convert/llama/convert.py
def split_and_save_weight(i, saved_dir, factor, key, val, act_range, config):
    # Converted tensors are collected into a dict (keyed like the on-disk
    # layout) and returned; the incoming saved_dir argument is not used.
    saved_dir = {}
    # `factor` is the number of ranks used for the distributed GEMMs.
    # With tensor parallelism, each rank/GPU works on
    # hidden_dim // factor channels of the split dimension.
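    # Example (illustrative shapes only): with factor = 2 and a
    # [4096, 11008] mlp.fc weight, the column split below yields two
    # [4096, 5504] shards, one per tensor-parallel rank.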
    int8_outputs = config.get("int8_outputs", None)
    multi_query_mode = config.get("multi_query_mode", False)
    local_dim = config.get("local_dim", None)
    save_int8 = int8_outputs in ("all", "kv_cache_only")

    if "input_layernorm.weight" in key or "input_layernorm.bias" in key or \
            "attention.dense.bias" in key or "post_layernorm.weight" in key or \
            "post_attention_layernorm.bias" in key or "mlp.dense_4h_to_h.bias" in key or \
            "final_layernorm.weight" in key or "final_layernorm.bias" in key:
        # Shared (non-split) tensors: only rank 0 needs to convert them.
        if i == 0:
            save_val(val, saved_dir, key)
elif "attention.dense.weight" in key or "mlp.proj.weight" in key:
split_dim = 0
split_vals = np.split(val, factor, axis=split_dim)
# save_split(split_vals, saved_dir, key, i, factor)
if act_range is not None and int8_outputs == "all":
base_key = key.replace(".weight", "")
vals_i8 = generate_int8(val, act_range)
write_int8(vals_i8, saved_dir, base_key, split_dim, i, factor)
elif "mlp.fc.weight" in key or "mlp.gate.weight" in key:
split_dim = -1
split_vals = np.split(val, factor, axis=split_dim)
# save_split(split_vals, saved_dir, key, i, factor)
if act_range is not None and int8_outputs == "all":
base_key = key.replace(".weight", "")
vals_i8 = generate_int8(val, act_range)
write_int8(vals_i8, saved_dir, base_key, split_dim, i, factor)
elif "attention.query_key_value.weight" in key:
hidden_dim = val.shape[0]
if local_dim is None:
local_dim = val.shape[-1] // 3
if multi_query_mode:
head_size = (val.shape[-1] - local_dim) // 2
val = val.reshape(hidden_dim, local_dim + 2 * head_size)
w_q, w_k, w_v = np.split(val, [local_dim, local_dim + head_size],
axis=-1)
w_q_split = np.split(w_q, factor, axis=-1)
w_k_split = np.split(w_k, factor, axis=-1)
w_v_split = np.split(w_v, factor, axis=-1)
split_vals = [
np.concatenate((w_q_split[ii], w_k_split[ii], w_v_split[ii]),
axis=-1) for ii in range(factor)
]
split_dim = -1
else:
val = val.reshape(hidden_dim, 3, local_dim)
split_dim = -1
split_vals = np.split(val, factor, axis=split_dim)
# save_split(split_vals, saved_dir, key, i, factor)
if save_int8:
base_key = key.replace(".weight", "")
vals_i8 = generate_int8(val,
act_range,
is_qkv=True,
multi_query_mode=multi_query_mode)
write_int8(vals_i8,
saved_dir,
base_key,
split_dim,
i,
factor,
is_qkv=True,
multi_query_mode=multi_query_mode)
elif "attention.dense.smoother" in key or "mlp.proj.smoother" in key:
split_vals = np.split(val, factor, axis=0)
save_split(split_vals, saved_dir, key, i, factor)
else:
print(f"[WARNING] {key} not handled by converter")
return saved_dir
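

# --- Usage sketch (illustration only, not part of convert.py) ---------------
# A minimal sketch of how split_and_save_weight could be driven per
# tensor-parallel rank. model_weights, act_ranges and num_ranks are
# hypothetical placeholders; only the call signature and the config keys
# ("int8_outputs", "multi_query_mode", "local_dim") come from the code above.
import numpy as np

model_weights = {  # toy tensors keyed with the names the converter matches on
    "model.layers.0.attention.dense.weight": np.zeros((4096, 4096), np.float32),
    "model.layers.0.mlp.proj.weight": np.zeros((11008, 4096), np.float32),
}
act_ranges = {}  # no activation statistics here, so the INT8 branches are skipped
config = {
    "int8_outputs": "all",      # emit INT8 weights/scales where act ranges exist
    "multi_query_mode": False,  # True for multi-query attention checkpoints
    "local_dim": None,          # let the converter infer it from the QKV weight
}
num_ranks = 2                   # tensor-parallel world size

converted = []                  # one dict of converted tensors per rank
for rank in range(num_ranks):
    rank_weights = {}
    for key, val in model_weights.items():
        rank_weights.update(
            split_and_save_weight(rank, {}, num_ranks, key, val,
                                  act_ranges.get(key), config))
    converted.append(rank_weights)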