# maga_transformer/model_loader/ffn_weight.py
import functools
import logging
import traceback
import torch
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Optional, Union
from maga_transformer.utils.util import check_with_info
from maga_transformer.utils.model_weight import (W, CkptWeightInfo, identity)
from maga_transformer.model_loader.load_config import LoadConfig
from maga_transformer.utils.database import BaseDatabase
from maga_transformer.model_loader.weight_module import QuantWeight, WeightModule, AtomicWeight, CompositeWeight
class FfnConfig(BaseModel):
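    """Configuration flags for a dense FFN block (gating, inter-size padding, MoE, and w1/w3 merging)."""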
    is_gated_activation: bool = False
    inter_padding_size: int = -1
    is_moe: bool = False
    need_post_ln: bool = False
    enable_merge_w13: bool = False
    need_ffn_act_scale: bool = False
class FfnAtomicWeight(AtomicWeight):
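    """An AtomicWeight for a single FFN tensor that additionally carries its FfnConfig."""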
    def __init__(self, name: str, weights: List[CkptWeightInfo],
                 process_fun: Callable[[List[torch.Tensor]], torch.Tensor] = identity,
                 data_type: Optional[torch.dtype] = None,
                 config: Optional[FfnConfig] = None,
                 *args: Any, **kwargs: Any):
self.config = config
super().__init__(name, weights, process_fun, data_type, *args, **kwargs)
@property
def need_padding(self) -> bool:
        # padding is required only when process_fun is a partial of a padding helper
        return (isinstance(self.process_fun, functools.partial)
                and self.process_fun.func.__name__ in ['transpose_pad', 'pad'])
@property
def pad_dim(self) -> Optional[int]:
if not self.need_padding:
return None
return self.process_fun.keywords['dim']
def w13_func_wrap(ts: List[torch.Tensor], origin_w1: FfnAtomicWeight, origin_w3: FfnAtomicWeight) -> torch.Tensor:
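    """Process the w1 and w3 checkpoint tensors separately, then concatenate them along the last dim (conventionally the gate and up projections of a gated FFN)."""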
w1_size = len(origin_w1.weights)
w3_size = len(origin_w3.weights)
assert len(ts) == w1_size + w3_size
w1 = origin_w1.process_fun(ts[:w1_size])
w3 = origin_w3.process_fun(ts[w1_size:])
return torch.concat([w1, w3], dim=-1).contiguous()
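# The four wrappers below apply the same merge to LoRA tensors: the merged
# w13 tensor is chunked back into its w1/w3 halves, each half is run through
# the corresponding original weight's LoRA function, and the halves are
# re-concatenated along the last dim.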
def w13_lora_a_func_wrap(ts: torch.Tensor, origin_w1: FfnAtomicWeight, origin_w3: FfnAtomicWeight):
assert origin_w1.lora_a_process_func and origin_w3.lora_a_process_func
    w1, w3 = torch.chunk(ts, 2, dim=-1)
w1 = origin_w1.lora_a_process_func(w1)
w3 = origin_w3.lora_a_process_func(w3)
return torch.concat([w1, w3], dim=-1).contiguous()
def w13_lora_b_func_wrap(ts: torch.Tensor, origin_w1: FfnAtomicWeight, origin_w3: FfnAtomicWeight):
assert origin_w1.lora_b_process_func and origin_w3.lora_b_process_func
w1, w3 = torch.chunk(ts, 2, dim=-1)
w1 = origin_w1.lora_b_process_func(w1)
w3 = origin_w3.lora_b_process_func(w3)
return torch.concat([w1, w3], dim=-1).contiguous()
def w13_lora_a_split_func_wrap(ts: torch.Tensor, origin_w1: FfnAtomicWeight, origin_w3: FfnAtomicWeight):
assert origin_w1.lora_a_split_func and origin_w3.lora_a_split_func
w1, w3 = torch.chunk(ts, 2, dim=-1)
w1 = origin_w1.lora_a_split_func(w1)
w3 = origin_w3.lora_a_split_func(w3)
return torch.concat([w1, w3], dim=-1).contiguous()
def w13_lora_b_split_func_wrap(ts: torch.Tensor, origin_w1: FfnAtomicWeight, origin_w3: FfnAtomicWeight):
assert origin_w1.lora_b_split_func and origin_w3.lora_b_split_func
w1, w3 = torch.chunk(ts, 2, dim=-1)
w1 = origin_w1.lora_b_split_func(w1)
w3 = origin_w3.lora_b_split_func(w3)
return torch.concat([w1, w3], dim=-1).contiguous()
def fix_merge_w13(sub_weight_dict: Dict[str, FfnAtomicWeight]):
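    """Replace the separate ffn_w1/ffn_w3 entries in sub_weight_dict with a single merged ffn_w13 weight."""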
origin_w1 = sub_weight_dict[W.ffn_w1]
origin_w3 = sub_weight_dict[W.ffn_w3]
w_list = origin_w1.weights + origin_w3.weights
    lora_a_process_func = functools.partial(w13_lora_a_func_wrap, origin_w1=origin_w1, origin_w3=origin_w3) if origin_w1.lora_a_process_func else None
    lora_b_process_func = functools.partial(w13_lora_b_func_wrap, origin_w1=origin_w1, origin_w3=origin_w3) if origin_w1.lora_b_process_func else None
    lora_a_split_func = functools.partial(w13_lora_a_split_func_wrap, origin_w1=origin_w1, origin_w3=origin_w3) if origin_w1.lora_a_split_func else None
    lora_b_split_func = functools.partial(w13_lora_b_split_func_wrap, origin_w1=origin_w1, origin_w3=origin_w3) if origin_w1.lora_b_split_func else None
w13 = FfnAtomicWeight(name=W.ffn_w13,
weights=w_list,
process_fun=functools.partial(w13_func_wrap, origin_w1=origin_w1, origin_w3=origin_w3),
lora_a_process_func=lora_a_process_func,
lora_b_process_func=lora_b_process_func,
lora_a_split_func=lora_a_split_func,
lora_b_split_func=lora_b_split_func,
data_type=origin_w1.data_type, config=origin_w1.config)
sub_weight_dict.pop(W.ffn_w1)
sub_weight_dict.pop(W.ffn_w3)
sub_weight_dict[W.ffn_w13] = w13
return sub_weight_dict
def fix_merge_b13(sub_weight_dict: Dict[str, FfnAtomicWeight]):
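    """Replace the separate ffn_b1/ffn_b3 bias entries in sub_weight_dict with a single merged ffn_b13 weight."""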
origin_b1 = sub_weight_dict[W.ffn_b1]
origin_b3 = sub_weight_dict[W.ffn_b3]
w_list = origin_b1.weights + origin_b3.weights
    lora_a_process_func = functools.partial(w13_lora_a_func_wrap, origin_w1=origin_b1, origin_w3=origin_b3) if origin_b1.lora_a_process_func else None
    lora_b_process_func = functools.partial(w13_lora_b_func_wrap, origin_w1=origin_b1, origin_w3=origin_b3) if origin_b1.lora_b_process_func else None
    lora_a_split_func = functools.partial(w13_lora_a_split_func_wrap, origin_w1=origin_b1, origin_w3=origin_b3) if origin_b1.lora_a_split_func else None
    lora_b_split_func = functools.partial(w13_lora_b_split_func_wrap, origin_w1=origin_b1, origin_w3=origin_b3) if origin_b1.lora_b_split_func else None
    b13 = FfnAtomicWeight(name=W.ffn_b13,
                          weights=w_list,
                          process_fun=functools.partial(w13_func_wrap, origin_w1=origin_b1, origin_w3=origin_b3),
lora_a_process_func=lora_a_process_func,
lora_b_process_func=lora_b_process_func,
lora_a_split_func=lora_a_split_func,
lora_b_split_func=lora_b_split_func,
data_type=origin_b1.data_type, config=origin_b1.config)
sub_weight_dict.pop(W.ffn_b1)
sub_weight_dict.pop(W.ffn_b3)
sub_weight_dict[W.ffn_b13] = b13
return sub_weight_dict
class FfnWeight(CompositeWeight):
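    """Composite weight for a dense FFN block; when enable_merge_w13 is set, w1/w3 (and b1/b3) are merged into w13 (and b13)."""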
    def __init__(self, sub_weights: Union[Dict[str, FfnAtomicWeight], List[Union[FfnAtomicWeight, AtomicWeight]]], config: FfnConfig, *args: Any, **kwargs: Any):
        self.name = W.ffn
        # accept either a prebuilt name -> weight dict or a flat list of sub-weights
        sub_weight_dict = sub_weights if isinstance(sub_weights, dict) else {sub_weight.name: sub_weight for sub_weight in sub_weights}
self.config = config
if self.config.enable_merge_w13 and (W.ffn_w1 in sub_weight_dict and W.ffn_w3 in sub_weight_dict):
self.origin_w1 = sub_weight_dict[W.ffn_w1]
self.origin_w3 = sub_weight_dict[W.ffn_w3]
sub_weight_dict = fix_merge_w13(sub_weight_dict)
if self.config.enable_merge_w13 and (W.ffn_b1 in sub_weight_dict and W.ffn_b3 in sub_weight_dict):
self.origin_b1 = sub_weight_dict[W.ffn_b1]
self.origin_b3 = sub_weight_dict[W.ffn_b3]
sub_weight_dict = fix_merge_b13(sub_weight_dict)
kwargs['name'] = W.ffn
super().__init__(sub_weight_dict, *args, **kwargs)
self.w1 = self.sub_weights.get(W.ffn_w1)
self.w2 = self.sub_weights.get(W.ffn_w2)
self.w3 = self.sub_weights.get(W.ffn_w3)
self.w13 = self.sub_weights.get(W.ffn_w13)
self.b1 = self.sub_weights.get(W.ffn_b1)
self.b2 = self.sub_weights.get(W.ffn_b2)
self.b3 = self.sub_weights.get(W.ffn_b3)
self.b13 = self.sub_weights.get(W.ffn_b13)
@classmethod
def support(cls, quant_algo: Any, src_weight_info: WeightModule) -> bool:
return False
def _split(self, tensor: Union[torch.Tensor, Dict[str, torch.Tensor]], load_config: LoadConfig):
        if load_config.tp_size <= 1 and load_config.dp_size <= 1 and load_config.ep_size <= 1:
            # without tensor/data/expert parallelism, non-MoE weights need no splitting
            if self.name not in [W.moe_w1, W.moe_w2]:
                return tensor
return super()._split(tensor, load_config)
class MoeConfig(BaseModel):
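    """Configuration for an MoE FFN block: expert count, inter-size padding, routed scaling, and stacking/merge switches."""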
is_moe: bool = True
expert_num: int = -1
inter_padding_size: int = -1
routed_scaling_factor: float = 1.0
weight_stack: bool = False
enable_merge_w13: bool = False
class MoeAtomicWeight(AtomicWeight):
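    """An AtomicWeight for a single MoE tensor; unless the checkpoint already stores experts stacked, tensors are loaded one expert at a time and merged."""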
    def __init__(self, name: str,
                 weights: List[CkptWeightInfo],
                 process_fun: Callable[[List[torch.Tensor]], torch.Tensor] = identity,
                 data_type: Optional[torch.dtype] = None,
                 config: Optional[MoeConfig] = None,
                 *args: Any, **kwargs: Any):
self.config = config
super().__init__(name, weights, process_fun, data_type, *args, **kwargs)
def _load_raw_tensor(self, database: BaseDatabase, layer_id: Optional[int], device: str, load_config: LoadConfig):
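        """Load this weight's raw tensor(s); for non-stacked checkpoints, gather one tensor per selected expert and merge them with process_fun."""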
if self.config.weight_stack:
return super()._load_raw_tensor(database, layer_id, device, load_config)
        # the checkpoint stores one tensor per expert, so expand the load across the selected experts
before_merge_tensors = []
convert_type = self.data_type if self.data_type is not None else load_config.compute_dtype
        selected_experts = load_config.get_selected_experts(layer_id, self.config.expert_num)
        for ckpt_weight in self.weights:
for expert_id in selected_experts:
name = ckpt_weight.name.format(i=str(layer_id), i_1=str(layer_id + 1), expert_id=str(expert_id))
logging.debug(f"tensor name: {name}")
try:
before_merge_tensors.append(ckpt_weight.merge_fun([x.to(device) for x in database.load_tensor(name, convert_type)]))
                except Exception:
                    logging.error(f"failed to load {name}, full traceback:\n{traceback.format_exc()}")
                    raise
after_merge_tensor = self.process_fun(before_merge_tensors).to(convert_type)
logging.debug("load weight :%s, %s ", self.name, after_merge_tensor.shape)
return {self.name: after_merge_tensor}
class MoeWeight(CompositeWeight):
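    """Composite weight for a pure MoE FFN block: the expert w1/w2 stacks plus the router gate."""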
def __init__(self, sub_weights: List[MoeAtomicWeight], config: MoeConfig, **kwargs: Any):
self.config = config
        # every sub-weight must be a MoeAtomicWeight or a QuantWeight
        assert all(isinstance(sub_weight, (MoeAtomicWeight, QuantWeight)) for sub_weight in sub_weights)
kwargs['name'] = W.moe
super().__init__(sub_weights, **kwargs)
self.moe_w1 = self.sub_weights[W.moe_w1]
self.moe_w2 = self.sub_weights[W.moe_w2]
self.moe_gate = self.sub_weights[W.moe_gate]
@classmethod
def support(cls, quant_algo: Any, src_weight_info: WeightModule) -> bool:
return False
class SharedMoeConfig(FfnConfig, MoeConfig):
pass
class MoeWithSharedWeight(CompositeWeight):
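    """Composite weight for an MoE block with a shared (dense) expert path alongside the routed experts."""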
def __init__(self, sub_weights: List[Union[FfnAtomicWeight, MoeAtomicWeight]], config: SharedMoeConfig, **kwargs: Any):
self.config = config
        check_with_info(all(isinstance(sub_weight, (MoeAtomicWeight, FfnAtomicWeight, QuantWeight)) for sub_weight in sub_weights),
                        f"{[sub_weight.__class__ for sub_weight in sub_weights]}, {sub_weights}")
kwargs['name'] = W.moe
sub_weight_dict = {sub_weight.name: sub_weight for sub_weight in sub_weights}
if self.config.enable_merge_w13 and (W.ffn_w1 in sub_weight_dict and W.ffn_w3 in sub_weight_dict):
self.origin_w1 = sub_weight_dict[W.ffn_w1]
self.origin_w3 = sub_weight_dict[W.ffn_w3]
sub_weight_dict = fix_merge_w13(sub_weight_dict)
if self.config.enable_merge_w13 and (W.ffn_b1 in sub_weight_dict and W.ffn_b3 in sub_weight_dict):
self.origin_b1 = sub_weight_dict[W.ffn_b1]
self.origin_b3 = sub_weight_dict[W.ffn_b3]
sub_weight_dict = fix_merge_b13(sub_weight_dict)
super().__init__(sub_weight_dict, **kwargs)
self.moe_w1 = self.sub_weights.get(W.moe_w1)
self.moe_w2 = self.sub_weights.get(W.moe_w2)
self.moe_gate = self.sub_weights.get(W.moe_gate)
self.ffn_w1 = self.sub_weights.get(W.ffn_w1)
self.ffn_w2 = self.sub_weights.get(W.ffn_w2)
self.ffn_w3 = self.sub_weights.get(W.ffn_w3)
self.w13 = self.sub_weights.get(W.ffn_w13)
self.ffn_b1 = self.sub_weights.get(W.ffn_b1)
self.ffn_b2 = self.sub_weights.get(W.ffn_b2)
self.ffn_b3 = self.sub_weights.get(W.ffn_b3)
self.ffn_b13 = self.sub_weights.get(W.ffn_b13)
self.shared_expert_gate = self.sub_weights.get(W.shared_expert_gate)
@classmethod
def support(cls, quant_algo: Any, src_weight_info: WeightModule) -> bool:
return False
    def _shuffle_moe_weight(self, name: str, tensor: Union[torch.Tensor, Dict[str, torch.Tensor]], load_config: LoadConfig):
        # shuffle the named MoE tensor in place via the exported device, recursing into nested dicts
        w = tensor.get(name)
        if isinstance(w, torch.Tensor):
            w = load_config.exported_device.shuffle_moe_weight(w, load_config.compute_dtype, name)
            tensor[name] = w
        elif isinstance(w, dict):
            self._shuffle_moe_weight(name, w, load_config)
        else:
            raise ValueError(f"unsupported type for {name}: {type(w)}")
    def _split(self, tensor: Union[torch.Tensor, Dict[str, torch.Tensor]], load_config: LoadConfig):
        # no MoE-specific split logic; defer to CompositeWeight
        return super()._split(tensor, load_config)
def _postprocess(self, tensor: torch.Tensor, device: str, load_config: LoadConfig):
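        """Shuffle the stacked MoE w1/w2 tensors into the device layout, then run the default post-processing."""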
        self._shuffle_moe_weight(W.moe_w1, tensor, load_config)
        self._shuffle_moe_weight(W.moe_w2, tensor, load_config)
return super()._postprocess(tensor, device, load_config)