# maga_transformer/model_loader/model_weight_info.py
import functools
import gc
import logging
import os
import torch
from maga_transformer.utils.ckpt_file_info import CkptFileInfo
from typing import List, Tuple, Union, Optional, Dict, Any
from maga_transformer.utils.database import BaseDatabase, CkptDatabase
from maga_transformer.utils.model_weight import W, CkptWeightInfo, WeightStyle, choose_available, identity, tolerate_failed
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
from maga_transformer.model_loader.load_config import LoadConfig
from maga_transformer.model_loader.weight_module import WeightModule, AtomicWeight, CompositeWeight
from maga_transformer.model_loader.ffn_weight import FfnConfig, FfnWeight, MoeWithSharedWeight
from maga_transformer.model_loader.attn_weight import AttnConfig
class ModelWeightInfo:
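    """Describes one model's weights: global (non-layer) weights plus per-layer weight definitions."""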
layer_weights: Union[List[WeightModule], List[List[WeightModule]]]
weights: List[WeightModule]
def __init__(self, weights: List[WeightModule],
layer_weights: Union[List[WeightModule], List[List[WeightModule]]]) -> None:
self.weights = weights
self.layer_weights = layer_weights
def set_weight_dtype(self, dtype: torch.dtype):
if self.layer_weights:
for weight in self.layer_weights:
weight.data_type = dtype
def get_layer_weight_info(self, layer_id: int, name: str) -> Optional[WeightModule]:
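        """Breadth-first search for the weight called `name` in layer `layer_id`, descending into CompositeWeight sub-weights."""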
from collections import deque
queue = deque(self.layer_weights[layer_id])
while queue:
weight = queue.popleft()
if weight.name == name:
return weight
if isinstance(weight, CompositeWeight):
                queue.extend(weight.sub_weights.values())
return None
def to_quant_weight_info(self, quant_algo: Any):
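        """Rebuild this ModelWeightInfo with quantization-aware variants of every weight module."""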
        if quant_algo is None or not quant_algo.isQuant():
            raise ValueError("to_quant_weight_info requires a non-None quant_algo with isQuant() == True")
weights = []
if self.weights:
for weight in self.weights:
weights.append(weight.create(weight, quant_algo))
        layer_weights: Optional[Union[List[WeightModule], List[List[WeightModule]]]] = [] if self.layer_weights else None
if self.layer_weights:
for weight in self.layer_weights:
if isinstance(weight, list):
layer_weight = []
for w in weight:
layer_weight.append(w.create(w, quant_algo))
layer_weights.append(layer_weight)
else:
layer_weights.append(weight.create(weight, quant_algo))
return ModelWeightInfo(weights, layer_weights)
class ModelDeployWeightInfo:
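    """Describes how a model's checkpoint tensors map onto the deployed runtime layout for a given
    tensor/expert/data parallel configuration. Subclasses implement _get_weight_info()."""

    # Maps internal weight names to the tensor names used by WeightStyle.TRT_ENGINE checkpoints
    # (plain MLP layout: fc + proj).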
TRT_ENGINE_LAYER_WEIGHT_MAP = {
W.pre_ln_beta : 'transformer.layers.{i}.input_layernorm.bias',
W.pre_ln_gamma : 'transformer.layers.{i}.input_layernorm.weight',
W.attn_qkv_w : 'transformer.layers.{i}.attention.qkv.weight',
W.attn_qkv_b : 'transformer.layers.{i}.attention.qkv.bias',
W.attn_qkv_s : 'transformer.layers.{i}.attention.qkv.weights_scaling_factor',
W.attn_o_w : 'transformer.layers.{i}.attention.dense.weight',
W.attn_o_b : 'transformer.layers.{i}.attention.dense.bias',
W.attn_o_s : 'transformer.layers.{i}.attention.dense.weights_scaling_factor',
W.ffn_w3 : 'transformer.layers.{i}.mlp.fc.weight',
W.ffn_b3 : 'transformer.layers.{i}.mlp.fc.bias',
W.ffn_s3 : 'transformer.layers.{i}.mlp.fc.weights_scaling_factor',
W.ffn_w2 : 'transformer.layers.{i}.mlp.proj.weight',
W.ffn_b2 : 'transformer.layers.{i}.mlp.proj.bias',
W.ffn_s2 : 'transformer.layers.{i}.mlp.proj.weights_scaling_factor',
W.post_ln_gamma : 'transformer.layers.{i}.post_layernorm.weight',
W.post_ln_beta : 'transformer.layers.{i}.post_layernorm.bias',
}
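    # Variant of the map above for FFNs with a separate gate projection (w1 present, gated activation):
    # fc + gate + proj.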
TRT_ENGINE_LAYER_WEIGHT_MAP2 = {
W.pre_ln_beta : 'transformer.layers.{i}.input_layernorm.bias',
W.pre_ln_gamma : 'transformer.layers.{i}.input_layernorm.weight',
W.attn_qkv_w : 'transformer.layers.{i}.attention.qkv.weight',
W.attn_qkv_b : 'transformer.layers.{i}.attention.qkv.bias',
W.attn_qkv_s : 'transformer.layers.{i}.attention.qkv.weights_scaling_factor',
W.attn_o_w : 'transformer.layers.{i}.attention.dense.weight',
W.attn_o_b : 'transformer.layers.{i}.attention.dense.bias',
W.attn_o_s : 'transformer.layers.{i}.attention.dense.weights_scaling_factor',
W.ffn_w1 : 'transformer.layers.{i}.mlp.fc.weight',
W.ffn_b1 : 'transformer.layers.{i}.mlp.fc.bias',
W.ffn_s1 : 'transformer.layers.{i}.mlp.fc.weights_scaling_factor',
W.ffn_w2 : 'transformer.layers.{i}.mlp.proj.weight',
W.ffn_b2 : 'transformer.layers.{i}.mlp.proj.bias',
W.ffn_s2 : 'transformer.layers.{i}.mlp.proj.weights_scaling_factor',
W.ffn_w3 : 'transformer.layers.{i}.mlp.gate.weight',
W.ffn_b3 : 'transformer.layers.{i}.mlp.gate.bias',
W.ffn_s3 : 'transformer.layers.{i}.mlp.gate.weights_scaling_factor',
W.post_ln_gamma : 'transformer.layers.{i}.post_layernorm.weight',
W.post_ln_beta : 'transformer.layers.{i}.post_layernorm.bias',
}
def __init__(self, config: GptInitModelParameters, tp_size: int, tp_rank: int):
self.config = config
self._use_qk_norm = config.use_qk_norm
self._hidden_size = config.hidden_size
self._inter_size = config.inter_size
self._inter_padding_size = config.inter_padding_size
self._moe_inter_padding_size = config.moe_inter_padding_size
self._head_num = config.head_num
self._head_num_kv = config.head_num_kv
self.tp_size = tp_size
self.tp_rank = tp_rank
self.ep_size = config.ep_size
self.ep_rank = config.ep_rank
self.dp_size = config.dp_size
self.dp_rank = config.dp_rank
self.num_nodes: int = config.num_nodes
self.ffn_tp_rank = config.ffn_tp_rank
self.ffn_tp_size = config.ffn_tp_size
self._size_per_head = config.size_per_head
self.phy2log_ = config.phy2log
if self._head_num_kv == -1:
self._head_num_kv = self._head_num
self._quant_algo = config.quant_algo
self._num_layers = config.num_layers
self._layer_head_num = config.layer_head_num
self._layer_inter_padding_size = config.layer_inter_padding_size
self._has_prefix_encoder = False
self._is_sparse_head = config.is_sparse_head
self._src_quantization_bit = config.src_quantization_bit
self.tp_split_emb_and_lm_head = config.tp_split_emb_and_lm_head
self._is_gated_activation = config.gpt_init_params.isGatedActivation()
self.expert_num_ = config.gpt_init_params.expert_num
self.moe_n_group_ = config.moe_n_group
self.enable_eplb_ = config.enable_eplb
self.phy_exp_num_ = config.phy_exp_num
self.enable_merge_w13_ = config.enable_merge_w13
self.moe_k_ = config.gpt_init_params.moe_k
self.moe_layer_index_ = config.gpt_init_params.moe_layer_index
self.moe_style_ = config.gpt_init_params.moe_style
self.tie_word_embeddings = config.tie_word_embeddings
self.need_ffn_act_scale = config.need_ffn_act_scale
self.use_expert_attention = config.use_expert_attention
self.weight_style = WeightStyle.RTP_LLM_STYLE if config.is_ft_style_weight else WeightStyle.NONE
# for mla
self.kv_lora_rank = config.kv_lora_rank
self.nope_head_dim = config.nope_head_dim
self.rope_head_dim = config.rope_head_dim
self.v_head_dim = config.v_head_dim
self.routed_scaling_factor = config.routed_scaling_factor
self.is_ft_style_weight = config.is_ft_style_weight
# for vit sep
self.vit_separation = config.vit_separation
# for eplb
self.phy2log = config.phy2log
# for moe
self._use_stack_weight = False
@property
def support_lora(self):
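        """Whether this weight layout supports LoRA; subclasses that do support it override this."""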
return False
@property
def attn_config(self):
attn_config = AttnConfig(
hidden_size=self._hidden_size,
size_per_head=self._size_per_head,
head_num=self._head_num,
head_num_kv=self._head_num_kv
)
return attn_config
@property
def ffn_config(self):
ffn_config = FfnConfig(
is_gated_activation=self._is_gated_activation,
inter_padding_size=self._inter_padding_size,
is_moe=False
)
return ffn_config
def get_weight_info(self) -> ModelWeightInfo:
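        """Build the ModelWeightInfo and apply the configured fixups
        (weight style, merged w1/w3, quantization, tied lm_head, sparse heads)."""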
weight_info = self._get_weight_info()
        use_fp32 = os.environ.get("USE_FLOAT32", None) is not None
        if use_fp32:
            # set_weight_dtype mutates the weight infos in place and returns None, so do not reassign
            weight_info.set_weight_dtype(torch.float32)
        if weight_info.layer_weights and not isinstance(weight_info.layer_weights[0], list):
            # a flat per-layer template was given: reuse it as the weight list for every layer
            layer_weights = []
            for _ in range(self._num_layers):
                layer_weights.append(weight_info.layer_weights)
            weight_info.layer_weights = layer_weights
if self.weight_style != WeightStyle.NONE:
logging.info("fix weight style")
weight_info = self._fix_weight_style_layer_weight(weight_info)
if self.enable_merge_w13_:
logging.info("fix merge_w13")
weight_info = self._fix_merge_w1_w3(weight_info)
if self._quant_algo is not None and self._quant_algo.isQuant():
weight_info = weight_info.to_quant_weight_info(self._quant_algo)
if self.tie_word_embeddings:
logging.info("fix tie_word_embeddings")
weight_info = self._fix_tie_lm_head(weight_info)
if self._is_sparse_head:
            logging.info("Skipping empty weights for layers with head_num == 0")
weight_info = self._process_sparse_weight(weight_info)
return weight_info
def _fix_weight_style_layer_weight(self, origin_weight_info: ModelWeightInfo):
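        """Rewrite checkpoint tensor names for non-default weight styles (currently the TRT engine layout)."""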
global_weights = []
m1 = {
W.embedding : 'transformer.vocab_embedding.weight',
W.lm_head : 'lm_head.weight',
W.final_ln_gamma : 'transformer.ln_f.weight'
} if self.weight_style == WeightStyle.TRT_ENGINE else {}
def __update_weight_style(weight: WeightModule, name_map: Dict[str, str]):
            if isinstance(weight, AtomicWeight):
weight.weight_style = self.weight_style
if weight.name in name_map:
if len(weight.weights) == 1:
weight.weights[0].name = name_map[weight.name]
                    elif weight.name in [W.attn_qkv_b, W.attn_qkv_w]:
                        logging.error(f"{weight.name} has multiple source weights {weight.weights}, "
                                      f"collapsing them into {name_map[weight.name]} may cause bugs")
                        weight.weights = [CkptWeightInfo(name_map[weight.name])]
                        weight.process_fun = identity
                    elif len(weight.weights) >= 2:
                        raise ValueError(f"{weight.name} should have at most one weight, got {weight.weights}")
                    if weight.weights:
                        logging.info(f"update weight style for {weight.name}: {weight.weights[0].name}")
elif isinstance(weight, CompositeWeight):
weight.weight_style = self.weight_style
for _, sub_weight in weight.sub_weights.items():
__update_weight_style(sub_weight, name_map)
        for weight in origin_weight_info.weights:
__update_weight_style(weight, m1)
global_weights.append(weight)
origin_weight_info.weights = global_weights
layer_weights = []
for weights in origin_weight_info.layer_weights:
ffn_weight = [weight for weight in weights if weight.name == W.ffn]
assert len(ffn_weight) == 1
if ffn_weight[0].w1 is not None and self.weight_style == WeightStyle.TRT_ENGINE:
m2 = self.TRT_ENGINE_LAYER_WEIGHT_MAP2
elif self.weight_style == WeightStyle.TRT_ENGINE:
m2 = self.TRT_ENGINE_LAYER_WEIGHT_MAP
else:
m2 = {}
fix_weight = []
for weight in weights:
__update_weight_style(weight, m2)
fix_weight.append(weight)
layer_weights.append(fix_weight)
origin_weight_info.layer_weights = layer_weights
logging.info(f"fix weight style {origin_weight_info.layer_weights[0]}")
return origin_weight_info
def _fix_merge_w1_w3(self, origin_weight_info: ModelWeightInfo):
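        """Recreate FFN/MoE weight modules with enable_merge_w13 set so w1 and w3 are handled as a merged weight."""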
def __update_weight_config(weight: WeightModule):
            if isinstance(weight, (FfnWeight, MoeWithSharedWeight)):
weight.config.enable_merge_w13 = True
params = weight.extract_params(weight.__class__, weight, None)
return weight.__class__(**params)
else:
return weight
layer_weights = []
for weights in origin_weight_info.layer_weights:
fix_weight = []
for weight in weights:
fix_weight.append(__update_weight_config(weight))
layer_weights.append(fix_weight)
origin_weight_info.layer_weights = layer_weights
logging.info(f"fix weight config when need_merge_w13 {origin_weight_info.layer_weights[0]}")
return origin_weight_info
def _fix_tie_lm_head(self, origin_weight_info: ModelWeightInfo) -> ModelWeightInfo:
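        """For tied word embeddings, register both the lm_head and the embedding checkpoint names for
        lm_head and use whichever one is actually present in the checkpoint."""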
word_emb_idx = -1
word_emb = None
lm_head_idx = -1
lm_head = None
for idx, weight in enumerate(origin_weight_info.weights):
if weight.name == W.embedding:
word_emb_idx = idx
word_emb = weight
elif weight.name == W.lm_head:
lm_head = weight
lm_head_idx = idx
        if lm_head is None or word_emb is None:
return origin_weight_info
assert len(lm_head.weights) == 1 and len(word_emb.weights) == 1
        lm_head_ckpt_weight_infos = [CkptWeightInfo(w.name, functools.partial(tolerate_failed, origin_func=w.merge_fun))
                                     for w in lm_head.weights]
        lm_head_ckpt_weight_infos.extend([CkptWeightInfo(w.name, functools.partial(tolerate_failed, origin_func=w.merge_fun))
                                          for w in word_emb.weights])
        lm_head_merge_funcs = [lm_head.process_fun, word_emb.process_fun]
        lm_head = AtomicWeight(W.lm_head, lm_head_ckpt_weight_infos,
                               functools.partial(choose_available, origin_func_list=lm_head_merge_funcs))
origin_weight_info.weights[lm_head_idx] = lm_head
return origin_weight_info
def _process_sparse_weight(self, origin_weight_info: ModelWeightInfo) -> ModelWeightInfo:
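        """For sparse-head models, drop the weights listed in W.skip_weights_list from layers whose head_num is 0."""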
        if not isinstance(origin_weight_info.layer_weights[0], list):
            raise Exception("sparse-head models require layer_weights to be a list of per-layer lists")
new_layer_weights = []
for i, layer_weight in enumerate(origin_weight_info.layer_weights):
if self._layer_head_num[i] == 0:
new_weights = [weight for weight in layer_weight if weight.name not in W.skip_weights_list]
else:
new_weights = layer_weight
new_layer_weights.append(new_weights)
return ModelWeightInfo(origin_weight_info.weights, new_layer_weights)
def _get_weight_info(self) -> ModelWeightInfo:
raise NotImplementedError()
    def create_model_weight_info(self, database: BaseDatabase) -> Optional[ModelWeightInfo]:
        if isinstance(database, CkptDatabase):
            self.process_meta_from_ckpt(database.PretrainFileList)
            self.process_meta_from_ckpt(database.FinetuneFileList)
            if not self.is_ft_style_weight:
                return self.get_weight_info()
            # ft-style (RTP_LLM) weights do not need a ModelWeightInfo built here
            return None
        else:
            raise Exception("Unknown database class")
def process_meta_from_ckpt(self, ckpt_metas: List[CkptFileInfo]):
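        """Collect metadata from the checkpoint files and let the subclass inspect it via _process_meta."""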
if len(ckpt_metas) == 0:
return
if not self.is_ft_style_weight:
# call subclass process_meta
meta_dicts = [ckpt_file.get_metadata() for ckpt_file in ckpt_metas]
weight_keys = set(functools.reduce(lambda x,y:x+y, [list(meta.keys()) for meta in meta_dicts], []))
self._process_meta(meta_dicts, weight_keys)
def _process_meta(self, meta_dict, weight_keys):
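        """Hook for subclasses to inspect checkpoint metadata keys; the base implementation does nothing."""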
pass
def _get_layer_start_end_id(self) -> Tuple[int, int]:
raise NotImplementedError()
@staticmethod
def _contains(keys: List[str], val: str):
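        """Return True if `val` occurs as a substring of any entry in `keys`."""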
for key in keys:
if val in key:
return True
return False
def create_load_config(self,
compute_dtype: torch.dtype,
database: BaseDatabase,
exported_device: Optional[Any] = None):
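        """Assemble the LoadConfig that drives weight loading for this rank
        (parallel layout, padding sizes, quantization, LoRA merging, etc.)."""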
merge_lora = False
if not self.is_ft_style_weight:
            # parse MERGE_LORA as an integer flag so that MERGE_LORA=0 actually disables merging
            merge_lora = database.has_lora() and bool(int(os.environ.get("MERGE_LORA", "1")))
if database.has_lora() and not self.support_lora:
raise Exception(f"current weights_info: {self.__class__} not support lora, but database has lora")
load_config = LoadConfig(
database = database,
num_layers = self._num_layers,
hidden_size = self._hidden_size,
head_num = self._head_num,
head_num_kv = self._head_num_kv,
size_per_head = self._size_per_head,
use_stack_weight = self._use_stack_weight,
need_ffn_act_scale = self.need_ffn_act_scale,
inter_size = self._inter_size,
moe_layer_index = self.moe_layer_index_,
moe_n_group = self.moe_n_group_,
inter_padding_size = self._inter_padding_size,
moe_inter_padding_size = self._moe_inter_padding_size,
expert_num = self.expert_num_,
enable_eplb = self.enable_eplb_,
phy_exp_num = self.phy_exp_num_,
enable_merge_w13 = self.enable_merge_w13_,
tp_size = self.tp_size,
tp_rank = self.tp_rank,
ep_size = self.ep_size,
ep_rank = self.ep_rank,
dp_size = self.dp_size,
dp_rank = self.dp_rank,
num_nodes = self.num_nodes,
ffn_tp_rank = self.ffn_tp_rank,
ffn_tp_size = self.ffn_tp_size,
tp_split_emb_and_lm_head = self.tp_split_emb_and_lm_head,
merge_lora = merge_lora,
vit_separation = self.vit_separation,
compute_dtype = compute_dtype,
quant_algo = self._quant_algo,
bit = self._quant_algo.getWeightBits(),
is_ft_style_weight = self.is_ft_style_weight,
phy2log=self.phy2log_,
exported_device = exported_device
)
return load_config
class ModelWeights:
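    """Holds the loaded tensors for one rank: a per-layer dict of weights plus global (non-layer) weights.

    Minimal usage sketch (illustrative only; real loading is driven by the model loader):

        weights = ModelWeights(num_layers=2, device="cuda:0", dtype=torch.float16)
        weights.set_layer_weight(0, W.attn_qkv_w, qkv_tensor)
        weights.set_global_weight(W.embedding, emb_tensor)
        emb = weights.steal_global_weight(W.embedding)  # removes and returns the tensor
    """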
def __init__(self, num_layers: int, device: str, dtype: torch.dtype):
self.device = device
self.weights: List[Dict[str, torch.Tensor]] = []
self.global_weights: Dict[str, torch.Tensor] = {}
self._dtype = dtype
self.is_ft_style_weight: bool = False
for _ in range(num_layers):
self.weights.append({})
def set_layer_weight(self, layer_id: int, name: str, tensor: torch.Tensor):
self.weights[layer_id][name] = tensor
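        # note: gc.collect() after every assignment is presumably meant to release source tensors promptly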
gc.collect()
def set_global_weight(self, name: str, tensor: torch.Tensor):
self.global_weights[name] = tensor
def get_global_weight(self, name: str):
return self.global_weights.get(name, None)
    def steal_global_weight(self, name: str):
        # pop transfers ownership: the tensor is removed from the map and returned (None if absent)
        return self.global_weights.pop(name, None)
@property
def dtype(self):
return self._dtype
@staticmethod
def layer_weight_prefix(tp_rank:int, dp_rank: int, ep_rank: int):
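        # e.g. layer_weight_prefix(0, 0, 0) -> "rank_00_00_00.layers."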
return f"rank_{tp_rank:02d}_{dp_rank:02d}_{ep_rank:02d}.layers."
@staticmethod
def global_weight_prefix(tp_rank:int, dp_rank: int, ep_rank: int):
return f"rank_{tp_rank:02d}_{dp_rank:02d}_{ep_rank:02d}.global."