# get_ffn_quant_weight_info
# (extracted from maga_transformer/model_loader/group_wise_quant_weight.py)


def get_ffn_quant_weight_info(src_weight: Union[FfnAtomicWeight, MoeAtomicWeight], quant_algo: Any) -> List[Union[FfnAtomicWeight, MoeAtomicWeight]]:
    """Expand one FFN/MoE weight spec into its group-wise-quantized parts.

    For the given source weight (one of ``ffn_w1/w2/w3/w13`` or
    ``moe_w1/w2``) build the atomic weight infos for the quantized tensor
    (``QW_SUFFIX``), zero points (``QZ_SUFFIX``) and scales (``QS_SUFFIX``),
    each with its padding/stacking post-processor attached.

    The returned list's last element is an activation-scale entry; it is a
    real weight only for ``ffn_w2`` when ``need_ffn_act_scale`` is set and is
    ``None`` in every other branch — presumably callers filter out ``None``
    entries (TODO confirm against callers).

    Args:
        src_weight: the unquantized FFN/MoE weight spec to expand; its
            ``weights[0].name`` must end with ``W_SUFFIX``.
        quant_algo: quantization config providing group size, weight bit
            width, and AWQ/GPTQ flags.

    Returns:
        List of FfnAtomicWeight/MoeAtomicWeight (may contain a trailing None).
    """
    weights = src_weight.weights
    ffn_w_name = src_weight.name
    # Checkpoint names must end with the plain-weight suffix; it is stripped
    # below and replaced by the quantized suffixes (QW/QZ/QS).
    assert weights[0].name.endswith(W_SUFFIX)
    assert ffn_w_name in [W.ffn_w1, W.ffn_w2, W.ffn_w3, W.ffn_w13, W.moe_w1, W.moe_w2]
    inter_padding_size = src_weight.config.inter_padding_size

    if ffn_w_name in [W.ffn_w1, W.ffn_w2, W.ffn_w3]:
        # Single-tensor FFN weights map 1:1 to one checkpoint entry.
        assert len(weights) == 1
    w_name = weights[0].name[:-len(W_SUFFIX)]
    group_size = quant_algo.getGroupSize()
    # int32 packing factor: 32 bits per word / bits per quantized value.
    pad_div = 32 // quant_algo.getWeightBits()
    is_awq = quant_algo.isAwq()
    is_gptq = quant_algo.isGptq()
    w: Union[str, None] = None
    s: Union[str, None] = None
    z: Union[str, None] = None
    stack: Union[Callable, None] = None
    act_w = None
    if ffn_w_name == W.ffn_w2:
        if src_weight.config.need_ffn_act_scale:
            # Activation scale is stored next to the weight under
            # "<prefix>.act.scales".
            act_w_name = w_name.rsplit('.', 1)[0] + '.act.scales'
            act_w = FfnAtomicWeight(
                W.ffn_act_s, [CkptWeightInfo(act_w_name, identity)],
                identity, config=src_weight.config)
        # ffn_w2 pads along dim 0.  The qweight pad target is divided by
        # pad_div only for GPTQ (int32-packed rows — TODO confirm layout);
        # zeros and scales are per-group, so their targets divide by
        # group_size.
        return [
            FfnAtomicWeight(
                W.ffn_w2, [CkptWeightInfo(w_name + QW_SUFFIX, identity)],
                functools.partial(pad,
                                inter_padding_size=inter_padding_size //
                                pad_div if is_gptq else inter_padding_size,
                                dim=0), data_type=torch.int32,
                                config=src_weight.config),
            FfnAtomicWeight(
                W.ffn_z2, [CkptWeightInfo(w_name + QZ_SUFFIX, identity)],
                functools.partial(pad,
                                inter_padding_size=inter_padding_size //
                                group_size,
                                dim=0), data_type=torch.int32,
                                config=src_weight.config),
            FfnAtomicWeight(
                W.ffn_s2, [CkptWeightInfo(w_name + QS_SUFFIX, identity)],
                functools.partial(pad,
                                inter_padding_size=inter_padding_size //
                                group_size,
                                dim=0),
                                config=src_weight.config),
            act_w
        ]
    elif ffn_w_name in [W.moe_w2, W.moe_w1]:
        # MoE path: one checkpoint tensor per expert, each transposed and
        # then stacked into a single tensor (stack_moe_w1 for w1, plain
        # stack_ for w2).  No padding is applied on this path.
        if ffn_w_name == W.moe_w1:
            w, z, s = (W.moe_w1, W.moe_z1, W.moe_s1)
            stack = stack_moe_w1
        elif ffn_w_name == W.moe_w2:
            w, z, s = (W.moe_w2, W.moe_z2, W.moe_s2)
            stack = stack_

        w_name = [weight.name[:-len(W_SUFFIX)] for weight in weights]
        # NOTE(review): act_w is always None here but is still appended,
        # keeping the return shape consistent with the other branches.
        return [
            MoeAtomicWeight(
                w, [CkptWeightInfo(name + QW_SUFFIX, transpose) \
                    for name in w_name], stack, data_type=torch.int32,
                    config=src_weight.config),
            MoeAtomicWeight(
                z, [CkptWeightInfo(name + QZ_SUFFIX, transpose) \
                    for name in w_name], stack, data_type=torch.int32,
                    config=src_weight.config),
           MoeAtomicWeight(
                s, [CkptWeightInfo(name + QS_SUFFIX, transpose) \
                    for name in w_name], stack,
                    config=src_weight.config),
           act_w
        ]
    elif ffn_w_name == W.ffn_w13:
        # Fused w1+w3: two checkpoint tensors combined by pad_w13 along
        # dim 1.  qweight's pad target divides by pad_div only for AWQ
        # (int32-packed columns — TODO confirm layout); zeros always divide
        # by pad_div; scales use the full inter_padding_size.
        w, z, s = (W.ffn_w13, W.ffn_z13, W.ffn_s13)
        w1_name = weights[0].name[:-len(W_SUFFIX)]
        w3_name = weights[1].name[:-len(W_SUFFIX)]
        return [
            FfnAtomicWeight(
                w, [CkptWeightInfo(w1_name + QW_SUFFIX, identity), CkptWeightInfo(w3_name + QW_SUFFIX, identity)],
                functools.partial(pad_w13,
                                  inter_padding_size=inter_padding_size //
                                  pad_div if is_awq else inter_padding_size,
                                  dim=1), data_type=torch.int32,
                    config=src_weight.config),
            FfnAtomicWeight(
                z, [CkptWeightInfo(w1_name + QZ_SUFFIX, identity), CkptWeightInfo(w3_name + QZ_SUFFIX, identity)],
                functools.partial(pad_w13,
                                  inter_padding_size=src_weight.config.inter_padding_size //
                                  pad_div,
                                  dim=1), data_type=torch.int32,
                    config=src_weight.config),
            FfnAtomicWeight(
                s, [CkptWeightInfo(w1_name + QS_SUFFIX, identity), CkptWeightInfo(w3_name + QS_SUFFIX, identity)],
                functools.partial(pad_w13,
                                  inter_padding_size=src_weight.config.inter_padding_size,
                                  dim=1),
                    config=src_weight.config),
            act_w
        ]
    else:
        # Remaining cases: ffn_w1 or ffn_w3 individually.  Same padding
        # scheme as the fused w13 branch, but with pad along dim 1 on a
        # single tensor.
        w, z, s = (W.ffn_w1, W.ffn_z1,
                   W.ffn_s1) if ffn_w_name == W.ffn_w1 else (W.ffn_w3,
                                                             W.ffn_z3,
                                                             W.ffn_s3)
        return [
            FfnAtomicWeight(
                w, [CkptWeightInfo(w_name + QW_SUFFIX, identity)],
                functools.partial(pad,
                                  inter_padding_size=inter_padding_size //
                                  pad_div if is_awq else inter_padding_size,
                                  dim=1), data_type=torch.int32,
                    config=src_weight.config),
            FfnAtomicWeight(
                z, [CkptWeightInfo(w_name + QZ_SUFFIX, identity)],
                functools.partial(pad,
                                  inter_padding_size=src_weight.config.inter_padding_size //
                                  pad_div,
                                  dim=1), data_type=torch.int32,
                    config=src_weight.config),
            FfnAtomicWeight(
                s, [CkptWeightInfo(w_name + QS_SUFFIX, identity)],
                functools.partial(pad,
                                  inter_padding_size=src_weight.config.inter_padding_size,
                                  dim=1),
                    config=src_weight.config),
            act_w
        ]