def prepare_qconfig()

in tinynn/graph/quantization/quantizer.py [0:0]


    def prepare_qconfig(self, graph: TraceGraph, backend: str):
        """Prepare qconfig for various configurations.

        Builds a global ``QConfig`` that is assigned to ``graph.module.qconfig`` and, for
        some backends, an override ``QConfig`` (``qconfig_c``) with a different weight
        observer that is assigned to selected submodule types:

        * ``qnnpack``: per-channel symmetric weights for conv modules (only when
          ``self.per_tensor`` is false)
        * ``fbgemm``: per-tensor symmetric weights for linear modules

        When ``self.algorithm == 'kl'`` the activation observer of both the global and
        the override config is replaced, so every activation is calibrated with the same
        algorithm. If ``self.dynamic_lstm_quant`` is set, ``nn.LSTM`` nodes are marked as
        quantized with ``torch_q.default_dynamic_qconfig``.

        Note that the branches below test ``self.backend`` while ``backend`` selects the
        default qconfig and the quantized engine; presumably callers pass the same value
        (TODO confirm).

        Args:
            graph (TraceGraph): The computation graph of the model
            backend (str): The quantization backend; passed to
                ``torch_q.get_default_qconfig`` and set as ``torch.backends.quantized.engine``
        """

        # Module class names (matched by ``type(m).__name__``) that receive ``qconfig_c``.
        qnnpack_override_types = (
            'Conv2d',
            'ConvBnReLU2d',
            'ConvBn2d',
            'ConvReLU2d',
            'Conv1d',
            'ConvBnReLU1d',
            'ConvBn1d',
        )
        fbgemm_override_types = ('Linear', 'LinearReLU')

        def _assign_qconfig_by_type_name(root, type_names, module_qconfig):
            # Breadth-first walk from ``root``: a module whose class name is in
            # ``type_names`` gets ``module_qconfig`` and is not descended into; any other
            # module has its children enqueued.
            pending = queue.Queue()
            pending.put(root)
            while not pending.empty():
                m = pending.get()
                if type(m).__name__ in type_names:
                    m.qconfig = module_qconfig
                else:
                    for c in m.children():
                        pending.put(c)

        log.info('setting qat backend and call prepare_qat')
        if not self.legacy_fq or LooseVersion(torch.__version__) < '1.12.0':
            qconfig = torch_q.get_default_qconfig(backend)
        else:
            # With ``self.legacy_fq`` on torch >= 1.12, request version 0 explicitly.
            qconfig = torch_q.get_default_qconfig(backend, 0)

        # Optional override config applied to specific module types (see above).
        qconfig_c = None
        if self.backend == 'qnnpack':
            if not self.asymmetric:
                sym_fq = torch_q.HistogramObserver.with_args(
                    dtype=torch.quint8, qscheme=torch.per_tensor_symmetric, reduce_range=False
                )
                qconfig = torch_q.QConfig(sym_fq, qconfig.weight)
            if not self.per_tensor:
                # Global weights: per-tensor symmetric; conv weights: per-channel symmetric.
                sym_fq = MinMaxObserver.with_args(
                    dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False
                )
                qconfig = torch_q.QConfig(qconfig.activation, sym_fq)
                sym_fq = PerChannelMinMaxObserver.with_args(
                    dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False, ch_axis=0
                )
                qconfig_c = torch_q.QConfig(qconfig.activation, sym_fq)
        elif self.backend == 'fbgemm':
            # Linear layers use per-tensor symmetric qint8 weights.
            sym_fq = torch_q.MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
            qconfig_c = torch_q.QConfig(qconfig.activation, sym_fq)
        elif self.backend in ('onnx', 'tensorrt'):
            if not self.asymmetric:
                sym_fq = torch_q.HistogramObserver.with_args(
                    dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False
                )
                qconfig = torch_q.QConfig(sym_fq, qconfig.weight)
            if not self.per_tensor:
                sym_fq = torch_q.MinMaxObserver.with_args(
                    dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False
                )
                qconfig = torch_q.QConfig(qconfig.activation, sym_fq)
                sym_fq = torch_q.PerChannelMinMaxObserver.with_args(
                    dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False, ch_axis=0
                )
                qconfig_c = torch_q.QConfig(qconfig.activation, sym_fq)
        else:
            log.warning(f'Quantization backend {self.backend} is not tested. Please use at your risk.')

        # Only 'kl' changes the observers here; 'l2' and any other value leave them as is.
        if self.algorithm == 'kl':
            if self.backend == 'qnnpack':
                alg_sym_fq = HistogramObserverKL.with_args(qscheme=torch.per_tensor_symmetric, reduce_range=False)
                alg_asym_fq = HistogramObserverKL.with_args(reduce_range=False)
            elif self.backend == 'fbgemm':
                alg_sym_fq = HistogramObserverKL.with_args(qscheme=torch.per_tensor_symmetric, reduce_range=True)
                alg_asym_fq = HistogramObserverKL.with_args(reduce_range=True)
            else:
                alg_sym_fq = qconfig.activation
                alg_asym_fq = qconfig.activation
            alg_fq = alg_asym_fq if self.asymmetric else alg_sym_fq
            qconfig = torch_q.QConfig(alg_fq, qconfig.weight)
            # qconfig_c was built from the pre-KL activation observer; without this the
            # overridden modules (conv for qnnpack, linear for fbgemm) would silently keep
            # the default observer instead of the requested algorithm.
            if qconfig_c is not None:
                qconfig_c = torch_q.QConfig(qconfig.activation, qconfig_c.weight)

        torch.backends.quantized.engine = backend
        graph.module.qconfig = qconfig
        if qconfig_c is not None:
            if self.backend == 'qnnpack':
                _assign_qconfig_by_type_name(graph.module, qnnpack_override_types, qconfig_c)
            elif self.backend == 'fbgemm':
                _assign_qconfig_by_type_name(graph.module, fbgemm_override_types, qconfig_c)
            # NOTE(review): for 'onnx'/'tensorrt' a per-channel qconfig_c is built above but
            # never assigned to any module, so per_tensor=False currently yields per-tensor
            # symmetric weights everywhere. Confirm which module types should receive it.

        def _lstm_node(node, custom_data):
            return isinstance(node.module, nn.LSTM)

        if self.dynamic_lstm_quant:
            lstm_nodes = graph.filter_forward_nodes(_lstm_node)
            for node in lstm_nodes:
                node.quantized = True
                node.module.qconfig = torch_q.default_dynamic_qconfig