# maga_transformer/models/llama_weight.py
import functools
import logging
from typing import List
import torch
from einops import rearrange
from maga_transformer.utils.model_weight import (W, CkptWeightInfo, WeightStyle, concat_1,
concat_0, identity, sp_0, sp_head_lora, sp_id, sp_neg1, zeros, transpose, merge_qkv_lora_A,
merge_qkv_lora_B, shift_one, merge_qkv_b)
from maga_transformer.model_loader.weight_module import AtomicWeight, WeightModule
from maga_transformer.model_loader.attn_weight import AttnAtomicWeight, AttnConfig
from maga_transformer.model_loader.model_weight_info import ModelWeightInfo, ModelDeployWeightInfo
from maga_transformer.model_loader.ffn_weight import FfnAtomicWeight, FfnWeight, FfnConfig
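# Weight-name tables and QKV merge/permute helpers for LLaMA-family checkpoints
# (default/Meta, HuggingFace, SmoothQuant, Yi, Baichuan, InternLM, InternLM2, Gemma, Cohere).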
# permute q/k weight rows into the sliced rotary-embedding layout
def permute(w, head_num, dim1, dim2):
return w.view(head_num, dim1 // head_num // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
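# default (original LLaMA) checkpoints: permute q/k into the sliced rotary layout, then
# concatenate the transposed q, k, v matrices into a single [hidden_size, q+k+v] tensor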
def merge_qkv(ts, hidden_size, head_num_kv, head_num):
q, k, v = ts
q = permute(q, head_num, hidden_size, hidden_size)
k = permute(k, head_num_kv, head_num_kv * hidden_size // head_num, hidden_size)
qkv_weight = torch.concat([q.T, k.T, v.T], dim=1).contiguous()
return qkv_weight
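# HF-style checkpoints already store q/k in the sliced rotary layout, so no permute is needed;
# head_num/head_num_kv are unused but kept so all merge functions share one signature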
def merge_qkv_hf(ts: List[torch.Tensor], hidden_size, head_num_kv, head_num):
q, k, v = ts
qkv_weight = torch.concat([q.T, k.T, v.T], dim=1).contiguous()
return qkv_weight
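# InternLM2 fuses q/k/v into one wqkv tensor grouped per KV head as [q heads of the group..., k, v];
# split those groups apart and re-concatenate the result as [all q | all k | all v]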
def qkv_rerange(ts, hidden_size, head_num_kv, head_num):
num_key_value_groups = int(head_num // head_num_kv)
size_per_head = int(hidden_size / head_num)
w = rearrange(ts[0].T, "q (h gs d) -> q h gs d",
gs=2 + num_key_value_groups,
d=size_per_head)
wq = w[..., : num_key_value_groups, :].reshape(w.shape[0], -1)
wk = w[..., -2, :].reshape(w.shape[0], -1)
wv = w[..., -1, :].reshape(w.shape[0], -1)
return torch.concat([wq, wk, wv], dim=1)
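# SmoothQuant checkpoints store a pre-fused QKV tensor; just flatten it back to [hidden_size, -1]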
def qkv_transpose(ts, hidden_size):
return ts[0].reshape(hidden_size, -1)
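# key names used by the original (non-HuggingFace) LLaMA checkpoints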
class DefaultWeightNames:
WQ = 'layers.{i}.attention.wq.weight'
WK = 'layers.{i}.attention.wk.weight'
WV = 'layers.{i}.attention.wv.weight'
WO = 'layers.{i}.attention.wo.weight'
FFW1 = 'layers.{i}.feed_forward.w1.weight'
FFW2 = 'layers.{i}.feed_forward.w2.weight'
FFW3 = 'layers.{i}.feed_forward.w3.weight'
ATTEN_NORM = 'layers.{i}.attention_norm.weight'
FFN_NORM = 'layers.{i}.ffn_norm.weight'
TOKEN_EMBEDDING = 'tok_embeddings.weight'
NORM = 'norm.weight'
OUTPUT = 'output.weight'
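# key names used by HuggingFace LlamaForCausalLM checkpoints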
class HfWeightNames:
WQ = 'model.layers.{i}.self_attn.q_proj.weight'
WK = 'model.layers.{i}.self_attn.k_proj.weight'
WV = 'model.layers.{i}.self_attn.v_proj.weight'
WO = 'model.layers.{i}.self_attn.o_proj.weight'
FFW1 = 'model.layers.{i}.mlp.gate_proj.weight'
FFW2 = 'model.layers.{i}.mlp.down_proj.weight'
FFW3 = 'model.layers.{i}.mlp.up_proj.weight'
ATTEN_NORM = 'model.layers.{i}.input_layernorm.weight'
FFN_NORM = 'model.layers.{i}.post_attention_layernorm.weight'
TOKEN_EMBEDDING = 'model.embed_tokens.weight'
NORM = 'model.norm.weight'
OUTPUT = 'lm_head.weight'
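# key names produced by the SmoothQuant export: int8 weights ('.int8.col') plus their
# quantization scales and smoother tensors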
class SQWeightNames(HfWeightNames):
W_QKV = 'model.layers.{i}.attention.query_key_value.weight.int8.col'
W_QKV_S = 'model.layers.{i}.attention.query_key_value.scale_w_quant_orig.col'
WO = 'model.layers.{i}.attention.dense.weight.int8.col'
WO_S = 'model.layers.{i}.attention.dense.scale_w_quant_orig.col'
FFW1 = 'model.layers.{i}.mlp.fc.weight.int8.col'
FFW1_S = 'model.layers.{i}.mlp.fc.scale_w_quant_orig.col'
FFW2 = 'model.layers.{i}.mlp.proj.weight.int8.col'
FFW2_S = 'model.layers.{i}.mlp.proj.scale_w_quant_orig.col'
FFW3 = 'model.layers.{i}.mlp.gate.weight.int8.col'
FFW3_S = 'model.layers.{i}.mlp.gate.scale_w_quant_orig.col'
FFNW2_Smoother = 'model.layers.{i}.mlp.proj.smoother'
WO_Smoother = 'model.layers.{i}.attention.dense.smoother'
FFN_NORM = 'model.layers.{i}.post_layernorm.weight'
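# Yi follows the HF layout but renames the two layer norms to ln1/ln2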
class YiWeightNames(HfWeightNames):
ATTEN_NORM = 'model.layers.{i}.ln1.weight'
FFN_NORM = 'model.layers.{i}.ln2.weight'
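# Baichuan packs q/k/v into a single W_pack tensor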
class BaichuanWeightNames(HfWeightNames):
W_QKV = 'model.layers.{i}.self_attn.W_pack.weight'
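# InternLM follows the HF layout but adds attention projection biases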
class InternlmWeightNames(HfWeightNames):
BQ = 'model.layers.{i}.self_attn.q_proj.bias'
BK = 'model.layers.{i}.self_attn.k_proj.bias'
BV = 'model.layers.{i}.self_attn.v_proj.bias'
BO = 'model.layers.{i}.self_attn.o_proj.bias'
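# InternLM2 uses its own key names with a fused wqkv tensor (see qkv_rerange)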
class Internlm2WeightNames:
W_QKV = 'model.layers.{i}.attention.wqkv.weight'
WO = 'model.layers.{i}.attention.wo.weight'
FFW1 = 'model.layers.{i}.feed_forward.w1.weight'
FFW2 = 'model.layers.{i}.feed_forward.w2.weight'
FFW3 = 'model.layers.{i}.feed_forward.w3.weight'
ATTEN_NORM = 'model.layers.{i}.attention_norm.weight'
FFN_NORM = 'model.layers.{i}.ffn_norm.weight'
TOKEN_EMBEDDING = 'model.tok_embeddings.weight'
NORM = 'model.norm.weight'
OUTPUT = 'output.weight'
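# Gemma ties lm_head to the token embedding, so OUTPUT points at the embedding weight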
class GemmaWeightNames(HfWeightNames):
OUTPUT = 'model.embed_tokens.weight'
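# Cohere-style checkpoints: per-head q/k norms, no separate FFN norm, lm_head tied to the embedding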
class CohereWeightNames:
WQ = 'model.layers.{i}.self_attn.q_proj.weight'
WK = 'model.layers.{i}.self_attn.k_proj.weight'
WV = 'model.layers.{i}.self_attn.v_proj.weight'
WO = 'model.layers.{i}.self_attn.o_proj.weight'
FFW1 = 'model.layers.{i}.mlp.gate_proj.weight'
FFW2 = 'model.layers.{i}.mlp.down_proj.weight'
FFW3 = 'model.layers.{i}.mlp.up_proj.weight'
ATTEN_NORM = 'model.layers.{i}.input_layernorm.weight'
Q_NORM = 'model.layers.{i}.self_attn.q_norm.weight'
K_NORM = 'model.layers.{i}.self_attn.k_norm.weight'
NORM = 'model.norm.weight'
TOKEN_EMBEDDING = 'model.embed_tokens.weight'
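# Detects which checkpoint flavor is being loaded from its key names and builds the
# corresponding weight-loading description (global weights plus per-layer weights).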
class LlamaWeightInfo(ModelDeployWeightInfo):
def __init__(self, config, tp_size, tp_rank, prefix=''):
super().__init__(config, tp_size, tp_rank)
self._names = None
self._merge_qkv = None
self._merge_qkv_b = None
self._prefix = prefix
@property
def support_lora(self):
return True
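    # pick the name table and QKV merge strategy by probing which keys exist in the checkpoint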
def _process_meta(self, meta_dicts, weight_keys):
if self._quant_algo.isSmoothQuant() and SQWeightNames.W_QKV.format(i='0') in weight_keys:
logging.info('load hf llama smooth quant weight')
self._names = SQWeightNames
self.weight_style = WeightStyle.RTP_SMOOTH_LLM_STYLE
elif Internlm2WeightNames.W_QKV.format(i='0') in weight_keys:
logging.info('load internlm2 style weight')
self._names = Internlm2WeightNames
self._merge_qkv = qkv_rerange
elif YiWeightNames.FFN_NORM.format(i='0') in weight_keys:
logging.info('load Yi style weight')
self._names = YiWeightNames
self._merge_qkv = merge_qkv_hf
elif BaichuanWeightNames.W_QKV.format(i='0') in weight_keys:
logging.info('load baichuan style weight')
self._names = BaichuanWeightNames
self._merge_qkv = None
elif InternlmWeightNames.BQ.format(i='0') in weight_keys:
logging.info('load internlm style weight')
self._names = InternlmWeightNames
self._merge_qkv = merge_qkv_hf
self._merge_qkv_b = merge_qkv_b
elif self._prefix + DefaultWeightNames.OUTPUT in weight_keys:
logging.info('load default llama1 style weight')
self._names = DefaultWeightNames
self._merge_qkv = merge_qkv
        # for Llama 3.2 1B the lm_head is tied to the token embedding and absent from the checkpoint,
        # so the HF layout is detected by its FFN norm key rather than by lm_head/output
elif self._prefix + HfWeightNames.FFN_NORM.format(i='0') in weight_keys:
logging.info('load hf llama1 style weight')
self._names = HfWeightNames
self._merge_qkv = merge_qkv_hf
        # no FFN norm key found: fall back to the Cohere-style layout, which has no separate FFN norm
        elif self._prefix + HfWeightNames.FFN_NORM.format(i='0') not in weight_keys:
logging.info('load cohere style weight')
self._names = CohereWeightNames
self._merge_qkv = merge_qkv_hf
else:
raise Exception('unknown weights format')
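    # build global weights (embedding, final norm, lm_head) and the per-layer attention/FFN weights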
def _get_weight_info(self):
weights = [
AtomicWeight(W.embedding, [CkptWeightInfo(self._prefix + self._names.TOKEN_EMBEDDING, concat_1)], identity),
AtomicWeight(W.final_ln_gamma, [CkptWeightInfo(self._prefix + self._names.NORM, identity)], identity),
AtomicWeight(W.final_ln_beta, [], functools.partial(zeros, shape=[self._hidden_size])),
]
attn_config = AttnConfig(
hidden_size=self._hidden_size,
size_per_head=self._size_per_head,
head_num=self._head_num,
head_num_kv=self._head_num_kv)
ffn_config = FfnConfig(
is_gated_activation=self._is_gated_activation,
inter_padding_size=self._inter_padding_size,
is_moe=False
)
if self._names == CohereWeightNames:
weights.append(AtomicWeight(W.lm_head, [CkptWeightInfo(self._prefix + self._names.TOKEN_EMBEDDING, identity)], identity))
else:
weights.append(AtomicWeight(W.lm_head, [CkptWeightInfo(self._prefix + self._names.OUTPUT, concat_0)], identity))
        layer_weights: list[WeightModule] = [
            AtomicWeight(W.pre_ln_gamma, [CkptWeightInfo(self._prefix + self._names.ATTEN_NORM, identity)], identity),
        ]
        # the second norm (post_ln_gamma, or Cohere's qk_ln_gamma) is appended below per checkpoint flavor
if self.weight_style == WeightStyle.RTP_SMOOTH_LLM_STYLE:
layer_weights.extend([
AttnAtomicWeight(W.attn_o_w, [CkptWeightInfo(self._prefix + self._names.WO.removesuffix(".int8.col"), identity)], transpose,
config=attn_config),
FfnWeight(sub_weights=[
FfnAtomicWeight(W.ffn_w1, [CkptWeightInfo(self._prefix + self._names.FFW1.removesuffix(".int8.col"), identity)], transpose,
config=ffn_config),
FfnAtomicWeight(W.ffn_w3, [CkptWeightInfo(self._prefix + self._names.FFW3.removesuffix(".int8.col"), identity)], transpose,
config=ffn_config),
FfnAtomicWeight(W.ffn_w2, [CkptWeightInfo(self._prefix + self._names.FFW2.removesuffix(".int8.col"), identity)], transpose,
config=ffn_config),
], config=ffn_config),
AttnAtomicWeight(W.attn_qkv_w, [CkptWeightInfo(self._prefix + self._names.W_QKV.removesuffix(".int8.col"),
functools.partial(qkv_transpose, hidden_size=self._hidden_size))],
transpose, config=attn_config),
]
)
else:
layer_weights.extend([
AttnAtomicWeight(W.attn_o_w, [CkptWeightInfo(self._prefix + self._names.WO, concat_1)], transpose,
config=attn_config,
lora_a_process_func=transpose, lora_b_process_func=transpose,
lora_a_split_func=sp_0, lora_b_split_func=sp_id),
FfnWeight(sub_weights=[
FfnAtomicWeight(W.ffn_w1, [CkptWeightInfo(self._prefix + self._names.FFW1, concat_0)], transpose,
config=ffn_config,
lora_a_process_func=transpose, lora_b_process_func=transpose,
lora_a_split_func=sp_id, lora_b_split_func=sp_neg1),
FfnAtomicWeight(W.ffn_w3, [CkptWeightInfo(self._prefix + self._names.FFW3, concat_0)], transpose,
config=ffn_config,
lora_a_process_func=transpose, lora_b_process_func=transpose,
lora_a_split_func=sp_id, lora_b_split_func=sp_neg1),
FfnAtomicWeight(W.ffn_w2, [CkptWeightInfo(self._prefix + self._names.FFW2, concat_1)], transpose,
config=ffn_config, lora_a_process_func=transpose, lora_b_process_func=transpose,
lora_a_split_func=sp_0, lora_b_split_func=sp_id)
], config=ffn_config)]
)
if self._names == CohereWeightNames:
layer_weights.append(AtomicWeight(W.qk_ln_gamma,
[CkptWeightInfo(self._prefix + self._names.Q_NORM, identity),
CkptWeightInfo(self._prefix + self._names.K_NORM, identity)], concat_0))
else:
layer_weights.append(AtomicWeight(W.post_ln_gamma, [CkptWeightInfo(self._prefix + self._names.FFN_NORM, identity)], identity))
if self._names == InternlmWeightNames:
layer_weights.append(
AttnAtomicWeight(W.attn_qkv_b,
[CkptWeightInfo(self._prefix + self._names.BQ, identity),
CkptWeightInfo(self._prefix + self._names.BK, identity),
CkptWeightInfo(self._prefix + self._names.BV, identity)],
functools.partial(self._merge_qkv_b), config=attn_config))
layer_weights.append(AttnAtomicWeight(W.attn_o_b, [CkptWeightInfo(self._prefix + self._names.BO, identity)], identity, config=attn_config))
if self._merge_qkv is not None:
if hasattr(self._names, 'W_QKV'):
infos = [CkptWeightInfo(self._prefix + self._names.W_QKV, identity)]
lora_a_process_func = identity
lora_b_process_func = identity
else:
infos = [CkptWeightInfo(self._prefix + self._names.WQ, concat_0),
CkptWeightInfo(self._prefix + self._names.WK, concat_0),
CkptWeightInfo(self._prefix + self._names.WV, concat_0)]
lora_a_process_func = functools.partial(merge_qkv_lora_A, allow_empty=True, hidden_size=self._hidden_size, head_num=self._head_num, head_num_kv=self._head_num_kv, size_per_head=self._size_per_head)
lora_b_process_func = functools.partial(merge_qkv_lora_B, allow_empty=True, hidden_size=self._hidden_size, head_num=self._head_num, head_num_kv=self._head_num_kv, size_per_head=self._size_per_head)
layer_weights.append(
AttnAtomicWeight(W.attn_qkv_w, infos,
functools.partial(self._merge_qkv,
hidden_size=self._hidden_size,
head_num_kv=self._head_num_kv,
head_num=self._head_num),
config=attn_config,
lora_a_process_func=lora_a_process_func, lora_b_process_func=lora_b_process_func,
lora_a_split_func=sp_id, lora_b_split_func=sp_head_lora))
else:
layer_weights.append(
AttnAtomicWeight(W.attn_qkv_w, [CkptWeightInfo(self._prefix + self._names.W_QKV, identity)], transpose,
config=attn_config,
lora_a_process_func=transpose, lora_b_process_func=transpose,
lora_a_split_func=sp_id, lora_b_split_func=sp_head_lora))
return ModelWeightInfo(layer_weights=layer_weights, weights=weights)
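# Gemma reuses the LLaMA mapping but forces GemmaWeightNames and applies shift_one to
# layernorm gammas (Gemma checkpoints store the gamma offset by one).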
class GemmaWeightInfo(LlamaWeightInfo):
def __init__(self, config, tp_size, tp_rank):
super().__init__(config, tp_size, tp_rank)
def _process_meta(self, meta_dicts, weight_keys):
logging.info('load gemma style weight')
self._names = GemmaWeightNames
self._merge_qkv = merge_qkv_hf
def _check_layernorm(self, weight):
if isinstance(weight, list):
return
if ("layernorm" in weight.name) and ("gamma" in weight.name):
logging.info(f"gemma adds shift 1 to {weight.name}")
weight.process_fun = shift_one
def _get_weight_info(self):
weight_info = super()._get_weight_info()
for layer_weight in weight_info.layer_weights:
self._check_layernorm(layer_weight)
for weight in weight_info.weights:
self._check_layernorm(weight)
return weight_info