# maga_transformer/model_loader/group_wise_quant_weight.py
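# NOTE: minimal import sketch for this excerpt. The standard-library / third-party
# imports below are required by the code shown; the project-internal names
# (W, W_SUFFIX, QW_SUFFIX, QZ_SUFFIX, QS_SUFFIX, FfnAtomicWeight, MoeAtomicWeight,
# CkptWeightInfo, identity, transpose, pad, pad_w13, stack_, stack_moe_w1) come from
# the package's model-loader / weight utilities, whose exact module paths are not
# shown in this excerpt.
import functools
from typing import Any, Callable, List, Optional, Union

import torch
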
def get_ffn_quant_weight_info(src_weight: Union[FfnAtomicWeight, MoeAtomicWeight], quant_algo: Any) -> List[Union[FfnAtomicWeight, MoeAtomicWeight]]:
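    """Expand an FFN/MoE weight entry into its group-wise quantized components.

    In a group-wise quantized checkpoint (GPTQ/AWQ style), each FFN or MoE
    projection is stored as a packed int32 qweight plus qzeros and scales
    tensors. This helper takes the original atomic weight description and
    returns the corresponding qweight/qzeros/scales atomic weights (plus an
    optional activation-scale weight for ffn_w2), each with the padding or
    stacking post-processing it needs.
    """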
    weights = src_weight.weights
    ffn_w_name = src_weight.name
    assert weights[0].name.endswith(W_SUFFIX)
    assert ffn_w_name in [W.ffn_w1, W.ffn_w2, W.ffn_w3, W.ffn_w13, W.moe_w1, W.moe_w2]
    inter_padding_size = src_weight.config.inter_padding_size
    if ffn_w_name in [W.ffn_w1, W.ffn_w2, W.ffn_w3]:
        assert len(weights) == 1
        w_name = weights[0].name[:-len(W_SUFFIX)]
    group_size = quant_algo.getGroupSize()
    pad_div = 32 // quant_algo.getWeightBits()
    is_awq = quant_algo.isAwq()
    is_gptq = quant_algo.isGptq()
    w: Optional[str] = None
    s: Optional[str] = None
    z: Optional[str] = None
    stack: Optional[Callable] = None
    act_w = None
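    # The padded inter size is adjusted per tensor layout in the branches below:
    # packed int32 tensors hold 32 // weight_bits values per element (pad_div),
    # qweight is packed along the input dim for GPTQ and the output dim for AWQ,
    # qzeros is packed along the output dim, and scales/zeros carry one entry per
    # quantization group along the input dim (hence the group_size divisions).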
    if ffn_w_name == W.ffn_w2:
        if src_weight.config.need_ffn_act_scale:
            act_w_name = w_name.rsplit('.', 1)[0] + '.act.scales'
            act_w = FfnAtomicWeight(
                W.ffn_act_s, [CkptWeightInfo(act_w_name, identity)],
                identity, config=src_weight.config)
        return [
            FfnAtomicWeight(
                W.ffn_w2, [CkptWeightInfo(w_name + QW_SUFFIX, identity)],
                functools.partial(
                    pad,
                    inter_padding_size=(inter_padding_size // pad_div)
                    if is_gptq else inter_padding_size,
                    dim=0),
                data_type=torch.int32,
                config=src_weight.config),
            FfnAtomicWeight(
                W.ffn_z2, [CkptWeightInfo(w_name + QZ_SUFFIX, identity)],
                functools.partial(
                    pad,
                    inter_padding_size=inter_padding_size // group_size,
                    dim=0),
                data_type=torch.int32,
                config=src_weight.config),
            FfnAtomicWeight(
                W.ffn_s2, [CkptWeightInfo(w_name + QS_SUFFIX, identity)],
                functools.partial(
                    pad,
                    inter_padding_size=inter_padding_size // group_size,
                    dim=0),
                config=src_weight.config),
            act_w
        ]
    elif ffn_w_name in [W.moe_w2, W.moe_w1]:
        if ffn_w_name == W.moe_w1:
            w, z, s = (W.moe_w1, W.moe_z1, W.moe_s1)
            stack = stack_moe_w1
        elif ffn_w_name == W.moe_w2:
            w, z, s = (W.moe_w2, W.moe_z2, W.moe_s2)
            stack = stack_
        w_name = [weight.name[:-len(W_SUFFIX)] for weight in weights]
        return [
            MoeAtomicWeight(
                w, [CkptWeightInfo(name + QW_SUFFIX, transpose) for name in w_name],
                stack, data_type=torch.int32,
                config=src_weight.config),
            MoeAtomicWeight(
                z, [CkptWeightInfo(name + QZ_SUFFIX, transpose) for name in w_name],
                stack, data_type=torch.int32,
                config=src_weight.config),
            MoeAtomicWeight(
                s, [CkptWeightInfo(name + QS_SUFFIX, transpose) for name in w_name],
                stack,
                config=src_weight.config),
            act_w
        ]
    elif ffn_w_name == W.ffn_w13:
        w, z, s = (W.ffn_w13, W.ffn_z13, W.ffn_s13)
        w1_name = weights[0].name[:-len(W_SUFFIX)]
        w3_name = weights[1].name[:-len(W_SUFFIX)]
        return [
            FfnAtomicWeight(
                w, [CkptWeightInfo(w1_name + QW_SUFFIX, identity),
                    CkptWeightInfo(w3_name + QW_SUFFIX, identity)],
                functools.partial(
                    pad_w13,
                    inter_padding_size=(inter_padding_size // pad_div)
                    if is_awq else inter_padding_size,
                    dim=1),
                data_type=torch.int32,
                config=src_weight.config),
            FfnAtomicWeight(
                z, [CkptWeightInfo(w1_name + QZ_SUFFIX, identity),
                    CkptWeightInfo(w3_name + QZ_SUFFIX, identity)],
                functools.partial(
                    pad_w13,
                    inter_padding_size=inter_padding_size // pad_div,
                    dim=1),
                data_type=torch.int32,
                config=src_weight.config),
            FfnAtomicWeight(
                s, [CkptWeightInfo(w1_name + QS_SUFFIX, identity),
                    CkptWeightInfo(w3_name + QS_SUFFIX, identity)],
                functools.partial(
                    pad_w13,
                    inter_padding_size=inter_padding_size,
                    dim=1),
                config=src_weight.config),
            act_w
        ]
    else:
        w, z, s = (W.ffn_w1, W.ffn_z1, W.ffn_s1) if ffn_w_name == W.ffn_w1 \
            else (W.ffn_w3, W.ffn_z3, W.ffn_s3)
        return [
            FfnAtomicWeight(
                w, [CkptWeightInfo(w_name + QW_SUFFIX, identity)],
                functools.partial(
                    pad,
                    inter_padding_size=(inter_padding_size // pad_div)
                    if is_awq else inter_padding_size,
                    dim=1),
                data_type=torch.int32,
                config=src_weight.config),
            FfnAtomicWeight(
                z, [CkptWeightInfo(w_name + QZ_SUFFIX, identity)],
                functools.partial(
                    pad,
                    inter_padding_size=inter_padding_size // pad_div,
                    dim=1),
                data_type=torch.int32,
                config=src_weight.config),
            FfnAtomicWeight(
                s, [CkptWeightInfo(w_name + QS_SUFFIX, identity)],
                functools.partial(
                    pad,
                    inter_padding_size=inter_padding_size,
                    dim=1),
                config=src_weight.config),
            act_w
        ]
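
# Usage sketch (hypothetical; the constructor arguments and checkpoint tensor name
# below are illustrative assumptions, not taken from this file). Given the ffn_w2
# entry of a GPTQ checkpoint, the helper expands it into the packed qweight, qzeros,
# scales and optional activation-scale atomic weights:
#
#   src = FfnAtomicWeight(W.ffn_w2,
#                         [CkptWeightInfo("layers.0.mlp.down_proj.weight", identity)],
#                         identity, config=ffn_config)
#   quant_weights = get_ffn_quant_weight_info(src, quant_algo)
#   # -> [ffn_w2 qweight (int32), ffn_z2 qzeros (int32), ffn_s2 scales, act_w or None]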