maga_transformer/model_loader/weight_module.py
import functools
import inspect
import logging
import traceback
import weakref
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Type, Union
import torch
from maga_transformer.model_loader.load_config import LoadConfig
from maga_transformer.utils.database import BaseDatabase
from maga_transformer.utils.model_weight import CkptWeightInfo, W, WeightStyle, identity, sp_0, sp_head_lora, sp_id, sp_neg1
class WeightModule(ABC):
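    """Base class for all loadable weights; subclasses auto-register in _registry so quantized variants can be dispatched at load time."""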
_registry = weakref.WeakValueDictionary()
_cache = weakref.WeakKeyDictionary()
lora_base_name = "base_model.model.{}.{}.weight"
lora_A_suffix = 'lora_A'
lora_B_suffix = 'lora_B'
def __init__(self,
name: str,
lora_a_process_func: Optional[Callable]=None,
lora_b_process_func: Optional[Callable]=None,
lora_a_split_func: Optional[Callable]=None,
lora_b_split_func: Optional[Callable]=None,
**kwargs: Any):
self.name = name
self.weight_style = kwargs.pop("weight_style", WeightStyle.NONE)
self.lora_a_process_func: Optional[Callable] = lora_a_process_func
self.lora_b_process_func: Optional[Callable] = lora_b_process_func
self.lora_a_split_func: Optional[Callable] = lora_a_split_func
self.lora_b_split_func: Optional[Callable] = lora_b_split_func
self.lora_a: Optional['WeightModule'] = None
self.lora_b: Optional['WeightModule'] = None
self.is_lora = kwargs.pop("is_lora", False)
def __init_subclass__(cls, **kwargs: Any):
super().__init_subclass__(**kwargs)
cls._registry[cls.__name__] = cls
@property
def lora_a_name(self):
return f"{self.name}.{self.lora_A_suffix}"
@property
def lora_b_name(self):
return f"{self.name}.{self.lora_B_suffix}"
@classmethod
def create(
cls,
weight_info: "WeightModule",
quant_algo: Optional[Any] = None
) -> "WeightModule":
if quant_algo is None or not quant_algo.isQuant():
return weight_info
if isinstance(weight_info, QuantWeight):
return weight_info
if isinstance(weight_info, AtomicWeight):
valid_classes = [
c for _, c in cls._registry.items()
if c.support(quant_algo, weight_info)
]
if not valid_classes:
return weight_info
if len(valid_classes) > 1:
raise ValueError(f"{weight_info.name} fit too many valid_classes:{valid_classes} with quant={quant_algo} for weight: {weight_info}")
target_cls = valid_classes[0]
params = cls.extract_params(target_cls, weight_info, quant_algo)
return target_cls(**params)
elif isinstance(weight_info, CompositeWeight):
target_cls = weight_info.__class__
params = target_cls.extract_params(target_cls, weight_info, quant_algo)
return target_cls(**params)
else:
raise ValueError(f"Invalid weight_info type: {type(weight_info)}")
@classmethod
def from_params(cls, params):
return cls(**params)
@classmethod
def extract_params(
cls,
target_cls: Type["WeightModule"],
weight_info: "WeightModule",
quant_algo: Any
) -> Dict[str, Any]:
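        """Build constructor kwargs for target_cls by mirroring matching attributes of weight_info (sub_weights are converted recursively)."""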
params = {}
signature = inspect.signature(target_cls.__init__)
need_var_key = False
for param in list(signature.parameters.values())[1:]: # Skip self
if param.kind == inspect.Parameter.VAR_KEYWORD or param.kind == inspect.Parameter.VAR_POSITIONAL:
need_var_key = True
continue
if param.name == "quant_algo":
params[param.name] = quant_algo
continue
if param.name == 'src_weight_info':
params['src_weight_info'] = weight_info
continue
if hasattr(weight_info, param.name):
value = getattr(weight_info, param.name)
                # Recursively create sub-weights with the same quantization algorithm
if param.name == "sub_weights" and isinstance(value, dict):
value = [
cls.create(v, quant_algo)
for _, v in value.items()
]
params[param.name] = value
            elif param.default is not inspect.Parameter.empty:
params[param.name] = param.default
else:
raise ValueError(f"target_cls: {target_cls} Missing required parameter: {param.name}")
if need_var_key:
for k, v in weight_info.__dict__.items():
if isinstance(v, WeightModule):
continue
if k in params:
continue
params[k] = v
return params
@classmethod
@abstractmethod
def support(cls, quant_algo: Any, src_weight_info: 'WeightModule') -> bool:
pass
@torch.inference_mode()
def load(self, database: BaseDatabase, layer_id: Optional[int], device: str, load_config: LoadConfig):
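        """Load raw tensors, optionally merge LoRA, split for the parallel layout, postprocess, and return a flat name -> tensor dict on device."""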
raw_tensors = self._load_raw_tensor(database, layer_id, device, load_config)
logging.debug(f"load weight: {self.name} layer_id: {layer_id}, res:{raw_tensors}")
if load_config.merge_lora:
logging.debug(f"merge lora weight: {self.name} layer_id: {layer_id}")
merged_tensors = self._merge_lora(raw_tensors, database, layer_id, load_config)
else:
merged_tensors = raw_tensors
split_tensors = self._split(merged_tensors, load_config)
logging.debug(f"split weight: {self.name} layer_id: {layer_id}, res:{split_tensors}")
processed_tensors = self._postprocess(split_tensors, device, load_config)
logging.debug(f"postprocess weight: {self.name} layer_id: {layer_id}, res:{processed_tensors}")
flat_res = {}
def __extract_tensor(tensors):
for k,v in tensors.items():
if isinstance(v, dict):
__extract_tensor(v)
else:
flat_res.update({k: v.to(device)})
__extract_tensor(processed_tensors)
shape_info = {k: (v.shape, v.dtype) for k, v in flat_res.items()}
logging.debug(f"extract weight: {self.name} layer_id: {layer_id}, res:{shape_info}")
return flat_res
@torch.inference_mode()
def load_lora(self, database: BaseDatabase, layer_id: Optional[int], device: str, load_config: LoadConfig, lora_name: str):
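        """Load LoRA A/B tensors for lora_name and split them for the current parallel layout; returns {} on failure."""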
try:
raw_loras = self._load_raw_lora(database, layer_id, device, load_config, lora_name)
        except Exception:
            logging.warning(f"load layer: {layer_id} lora tensor {self.lora_a_name} or {self.lora_b_name} failed, traceback: {traceback.format_exc()}")
            return {}
if raw_loras is None:
return {}
        if load_config.tp_size <= 1 and load_config.dp_size <= 1 and load_config.ep_size <= 1:
res = raw_loras
else:
res = self._split_lora(raw_loras, load_config)
flat_res = {}
def __extract_tensor(tensors):
for k,v in tensors.items():
if isinstance(v, dict):
__extract_tensor(v)
elif v is not None:
flat_res.update({k: v.contiguous().clone().to(device)})
__extract_tensor(res)
return flat_res
@abstractmethod
def _load_raw_tensor(self, database: BaseDatabase, layer_id: Optional[int], device: str, load_config: LoadConfig):
pass
@abstractmethod
def _split(self, tensor: torch.Tensor, load_config: LoadConfig):
pass
@abstractmethod
def _postprocess(self, tensor: torch.Tensor, device: str, load_config: LoadConfig):
return tensor
@abstractmethod
def _merge_lora(self, tensor: Dict[str, torch.Tensor], database: BaseDatabase, layer_id: Optional[int], load_config: LoadConfig):
pass
@abstractmethod
def _load_raw_lora(self, database: BaseDatabase, layer_id: Optional[int], device: str, load_config: LoadConfig, lora_name: str):
pass
@abstractmethod
def _split_lora(self, tensor: Dict[str, torch.Tensor], load_config: LoadConfig):
pass
class AtomicWeight(WeightModule):
    """Atomic weight: a single, indivisible weight tensor."""
    weights: List[CkptWeightInfo]
    process_fun: Callable[[List[torch.Tensor]], torch.Tensor]
    data_type: Optional[torch.dtype] = None
    split_func = None
def __init__(
self,
name: str,
weights: List[CkptWeightInfo],
process_fun: Callable[[List[torch.Tensor]], torch.Tensor] = identity,
data_type: Optional[torch.dtype] = None,
**kwargs
) -> None:
self.name = name
self.weights = weights
self.process_fun = process_fun
self.data_type = data_type
super().__init__(name=name, **kwargs)
def create_from(self, *args: Any, **kwargs: Any) -> 'AtomicWeight':
return self.__class__(*args, **kwargs)
@property
def need_transpose(self) -> bool:
if isinstance(self.process_fun, functools.partial) and self.process_fun.func.__name__ in ['transpose_pad', 'transpose']:
return True
else:
return False
def _load_raw_tensor(self, database: BaseDatabase, layer_id: Optional[int], device: str, load_config: LoadConfig):
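        """Load and merge the checkpoint tensors listed in self.weights, then apply process_fun and cast to the target dtype."""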
before_merge_tensors = []
convert_type = self.data_type if self.data_type is not None else load_config.compute_dtype
for ckpt_weight in self.weights:
name = ckpt_weight.tensor_name(layer_id)
try:
before_merge_tensors.append(ckpt_weight.merge_fun([x.to(device) for x in database.load_tensor(name, convert_type)]))
            except Exception:
                logging.error(f"failed to load {self.name}: {name}, full traceback:\n{traceback.format_exc()}")
                raise
after_merge_tensor = self.process_fun(before_merge_tensors).to(device).to(convert_type)
return {self.name: after_merge_tensor}
def lora_tensor_name(self, layer_id: Optional[int], name: str):
if layer_id is not None:
return name.format(i=str(layer_id), i_1=str(layer_id + 1))
return name
def _load_raw_lora(self, database: BaseDatabase, layer_id: Optional[int], device: str, load_config: LoadConfig, lora_name: str):
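        """Load both LoRA A and B tensors; returns {} when this weight has no LoRA processing functions."""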
if self.lora_a_process_func is None or self.lora_b_process_func is None:
return {}
a_res = self._load_lora_a(database, layer_id, device, load_config, lora_name)
b_res = self._load_lora_b(database, layer_id, device, load_config, lora_name)
a_res.update(b_res)
return a_res
def _split_lora(self, tensor: Dict[str, torch.Tensor], load_config: LoadConfig) -> Dict[str, torch.Tensor]:
if self.lora_a_split_func is None or self.lora_b_split_func is None or not tensor:
return tensor
lora_a_name: str = self.lora_a_name
lora_b_name: str = self.lora_b_name
return {
lora_a_name: self.__split_tensor(self.lora_a_split_func, tensor.get(lora_a_name), load_config),
lora_b_name: self.__split_tensor(self.lora_b_split_func, tensor.get(lora_b_name), load_config)
}
    def __split_tensor(self, split_func: Callable, tensor: torch.Tensor, load_config: LoadConfig) -> torch.Tensor:
return split_func(t=tensor,
tp=load_config.tp_size,
tp_rank=load_config.tp_rank,
ep=load_config.ep_size,
ep_rank=load_config.ep_rank,
dp=load_config.dp_size,
dp_rank=load_config.dp_rank,
ffn_tp_rank=load_config.ffn_tp_rank,
ffn_tp_size=load_config.ffn_tp_size,
hidden_size=load_config.hidden_size,
head_num=load_config.head_num,
head_num_kv=load_config.head_num_kv,
size_per_head=load_config.size_per_head,
use_stack_weight=load_config.use_stack_weight,
bits=load_config.bit
)
def _load_lora_a(self, database: BaseDatabase, layer_id: Optional[int], device: str, load_config: LoadConfig, lora_name: str):
assert self.lora_a_process_func is not None
before_merge_tensors = []
for ckpt_weight in self.weights:
ckpt_name = self.lora_base_name.format(ckpt_weight.name[:-len(".weight")], self.lora_A_suffix)
tensor_name = self.lora_tensor_name(layer_id, ckpt_name)
try:
before_merge_tensors.append(ckpt_weight.merge_fun([x for x in database.load_lora_tensor(lora_name, tensor_name)]))
            except Exception:
                logging.warning(f"load {self.name} lora A failed: {tensor_name}, {traceback.format_exc()}")
                return {}
after_merge_tensor = self.lora_a_process_func(before_merge_tensors)
return {self.lora_a_name : after_merge_tensor}
def _load_lora_b(self, database: BaseDatabase, layer_id: Optional[int], device: str, load_config: LoadConfig, lora_name: str):
assert self.lora_b_process_func is not None
before_merge_tensors = []
for ckpt_weight in self.weights:
ckpt_name = self.lora_base_name.format(ckpt_weight.name[:-len(".weight")], self.lora_B_suffix)
tensor_name = self.lora_tensor_name(layer_id, ckpt_name)
try:
before_merge_tensors.append(ckpt_weight.merge_fun([x for x in database.load_lora_tensor(lora_name, tensor_name)]))
            except Exception:
                logging.warning(f"load {self.name} lora B failed: {tensor_name}, {traceback.format_exc()}")
                return {}
after_merge_tensor = self.lora_b_process_func(before_merge_tensors)
return {self.lora_b_name : after_merge_tensor}
def _merge_lora(self, tensor: Union[torch.Tensor, Dict[str, torch.Tensor]], database: BaseDatabase, layer_id: Optional[int], load_config: LoadConfig, lora_name:Optional[str] = None):
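        """Add lora_a @ lora_b * scale (computed in float32) onto the base tensor and return it under self.name."""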
if self.lora_a_process_func is None or self.lora_b_process_func is None:
return tensor
        lora_name = database.get_first_lora_name() if lora_name is None else lora_name
        if lora_name is None:
            raise Exception("invalid empty lora name")
try:
raw_loras = self._load_raw_lora(database, layer_id, device=load_config.exported_device, load_config=load_config, lora_name=lora_name)
lora_a_tensor = raw_loras[self.lora_a_name]
lora_b_tensor = raw_loras[self.lora_b_name]
        except Exception:
            logging.warning(f"load layer: {layer_id} lora tensor {self.lora_a_name} or {self.lora_b_name} failed, traceback: {traceback.format_exc()}")
            return tensor
if lora_a_tensor is None or lora_b_tensor is None:
return tensor
raw_tensor = tensor if isinstance(tensor, torch.Tensor) else tensor[self.name]
scale = database.get_lora_config(lora_name).get_scale()
# "addmm_impl_cpu_" not implemented for 'Half'
if lora_b_tensor.dim() == 3 and lora_a_tensor.dim() == 2:
lora_b_tensor = lora_b_tensor.reshape(lora_b_tensor.shape[0], lora_b_tensor.shape[1] * lora_b_tensor.shape[2])
merge_tensor = (lora_a_tensor.type(torch.float32) @ lora_b_tensor.type(torch.float32) * scale).type(raw_tensor.dtype).to(raw_tensor.device)
        # MoE case: both LoRA tensors are batched per expert, so use a batched matmul
elif lora_b_tensor.dim() == 3 and lora_a_tensor.dim() == 3:
merge_tensor = torch.bmm(lora_a_tensor.type(torch.float32), lora_b_tensor.type(torch.float32) * scale).type(raw_tensor.dtype).to(raw_tensor.device)
else:
merge_tensor = (lora_a_tensor.type(torch.float32) @ lora_b_tensor.type(torch.float32) * scale).type(raw_tensor.dtype).to(raw_tensor.device)
shape = raw_tensor.shape
raw_tensor = raw_tensor.reshape(raw_tensor.nelement()) + merge_tensor.reshape(raw_tensor.nelement())
raw_tensor = raw_tensor.reshape(shape)
del lora_a_tensor
del lora_b_tensor
return {self.name : raw_tensor}
def _split(self, tensor: Union[torch.Tensor, Dict[str, torch.Tensor]], load_config: LoadConfig):
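        """Shard the tensor for the current TP/EP/DP rank; embeddings and lm_head stay unsplit unless tp_split_emb_and_lm_head is set."""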
raw_tensor = tensor if isinstance(tensor, torch.Tensor) else tensor[self.name]
        if load_config.tp_size <= 1 and load_config.dp_size <= 1 and load_config.ep_size <= 1:
return {self.name : raw_tensor}
tp_split_emb_and_lm_head = load_config.tp_split_emb_and_lm_head
if (not tp_split_emb_and_lm_head and
self.name in [W.lm_head, W.lm_head_b, W.embedding, W.positional_embedding, W.token_type_embedding]):
return {self.name : raw_tensor}
split_func = self._get_split_func()
ts = self.__split_tensor(split_func, raw_tensor, load_config).contiguous().clone()
return {self.name: ts}
def _postprocess(self, tensor: Union[torch.Tensor, Dict[str, torch.Tensor]], device:str, load_config: LoadConfig):
raw_tensor = tensor.get(self.name) if isinstance(tensor, dict) else tensor
return {self.name: load_config.exported_device.maybe_rewrite_weight_by_key(self.name, raw_tensor)}
def _get_split_func(self):
return W.gpt_style_tp_strategy[self.name]
def get_components(self):
return [self]
@classmethod
def support(cls, quant_algo: Any, src_weight_info: WeightModule) -> bool:
return quant_algo is None or not quant_algo.isQuant()
def get_ckpt_tensor_names(self) -> List[str]:
        if not self.weights:
            return []
return [ckpt.name for ckpt in self.weights]
def __str__(self) -> str:
return f"AtomicWeight[{self.name}]-{self.weight_style}-{self.weights}"
def __repr__(self) -> str:
return self.__str__()
class QuantWeight(WeightModule):
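    """Base class for quantized weights; stores the quantization algorithm alongside the weight."""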
def __init__(self, name: str, quant_algo, *args, **kwargs):
super().__init__(name)
self.quant_algo = quant_algo
class MMAtomicWeight(AtomicWeight):
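    """AtomicWeight whose split function is supplied by the caller instead of looked up from W.gpt_style_tp_strategy."""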
def __init__(
self,
name: str,
weights: List[CkptWeightInfo],
process_fun: Callable[[List[torch.Tensor]], torch.Tensor] = identity,
data_type: Optional[torch.dtype] = None,
split_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
**kwargs
) -> None:
super().__init__(name, weights, process_fun, data_type, **kwargs)
self.split_func = split_func
def _get_split_func(self):
return self.split_func
class CompositeWeight(WeightModule):
"""复合权重组件(如MoE、FFN)"""
def __init__(self, sub_weights: Dict[str, WeightModule], *args, **kwargs):
super().__init__(*args, **kwargs)
self.sub_weights = self._init_sub_weights(sub_weights) if isinstance(sub_weights, list) else sub_weights
def get_components(self):
res = []
for sub_weight in self.sub_weights.values():
res.extend(sub_weight.get_components())
return res
def _init_sub_weights(self, sub_weights: List[WeightModule]):
inited_sub_weights = {}
for sub_weight in sub_weights:
inited_sub_weights.update({sub_weight.name: sub_weight})
return inited_sub_weights
def __str__(self) -> str:
return f"{self.__class__}[{self.name}]{self.sub_weights}"
def __repr__(self) -> str:
return self.__str__()
def _load_raw_tensor(self, database: BaseDatabase, layer_id: Optional[int], device: str, load_config: LoadConfig):
raw_tensors = {}
for name, sub_weight in self.sub_weights.items():
sub_tensors = sub_weight._load_raw_tensor(database, layer_id, device, load_config)
if isinstance(sub_weight, AtomicWeight) and isinstance(sub_tensors, dict):
raw_tensors.update(sub_tensors)
else:
raw_tensors.update({name:sub_tensors})
return raw_tensors
def _merge_lora(self, tensor: Union[torch.Tensor, Dict[str, torch.Tensor]], database: BaseDatabase, layer_id: Optional[int], load_config: LoadConfig):
merged_tensors = {}
for name, sub_weight in self.sub_weights.items():
sub_tensors = tensor.get(name)
sub_tensors = sub_weight._merge_lora(sub_tensors, database, layer_id, load_config)
if isinstance(sub_weight, AtomicWeight) and isinstance(sub_tensors, dict):
merged_tensors.update(sub_tensors)
else:
merged_tensors.update({name:sub_tensors})
return merged_tensors
def _load_raw_lora(self, database: BaseDatabase, layer_id: Optional[int], device: str, load_config: LoadConfig, lora_name: str):
raw_tensors = {}
for name, sub_weight in self.sub_weights.items():
sub_tensors = sub_weight._load_raw_lora(database, layer_id, device, load_config, lora_name=lora_name)
raw_tensors.update({name:sub_tensors})
return raw_tensors
def _split_lora(self, tensor: Union[torch.Tensor, Dict[str, torch.Tensor]], load_config: LoadConfig):
split_tensors = {}
for name, sub_weight in self.sub_weights.items():
sub_tensors = tensor.get(name)
sub_tensors = sub_weight._split_lora(sub_tensors, load_config)
split_tensors.update({name:sub_tensors})
return split_tensors
def _split(self, tensor: Union[torch.Tensor, Dict[str, torch.Tensor]], load_config: LoadConfig):
split_tensors = {}
for name, sub_weight in self.sub_weights.items():
sub_tensors = tensor.get(name)
sub_tensors = sub_weight._split(sub_tensors, load_config)
if isinstance(sub_weight, AtomicWeight) and isinstance(sub_tensors, dict):
split_tensors.update(sub_tensors)
else:
split_tensors.update({name:sub_tensors})
return split_tensors
    def _postprocess(self, tensor: Union[torch.Tensor, Dict[str, torch.Tensor]], device: str, load_config: LoadConfig) -> Dict[str, torch.Tensor]:
processed_tensors = {}
for name, sub_weight in self.sub_weights.items():
sub_tensors = tensor.get(name)
sub_tensors = sub_weight._postprocess(sub_tensors, device, load_config)
if isinstance(sub_weight, AtomicWeight) and isinstance(sub_tensors, dict):
processed_tensors.update(sub_tensors)
else:
processed_tensors.update({name:sub_tensors})
return processed_tensors