in tinynn/graph/quantization/quantizer.py
def prepare_qconfig(self, graph: TraceGraph, backend: str):
    """Prepare qconfig for various configurations.

    Args:
        graph (TraceGraph): The computation graph of the model
        backend (str): The quantization backend (e.g. 'qnnpack', 'fbgemm')
    """
    log.info('setting qat backend and calling prepare_qat')
    if not self.legacy_fq or LooseVersion(torch.__version__) < '1.12.0':
        qconfig = torch_q.get_default_qconfig(backend)
    else:
        qconfig = torch_q.get_default_qconfig(backend, 0)
    qconfig_c = None
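    # QNNPACK: use a symmetric quint8 histogram observer for activations when
    # asymmetric quantization is disabled, and build a per-channel weight qconfig
    # (qconfig_c) for convolutions when per-tensor quantization is disabled.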
    if self.backend == 'qnnpack':
        if not self.asymmetric:
            sym_fq = torch_q.HistogramObserver.with_args(
                dtype=torch.quint8, qscheme=torch.per_tensor_symmetric, reduce_range=False
            )
            qconfig = torch_q.QConfig(sym_fq, qconfig.weight)
        if not self.per_tensor:
            sym_fq = MinMaxObserver.with_args(
                dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False
            )
            qconfig = torch_q.QConfig(qconfig.activation, sym_fq)
            sym_fq = PerChannelMinMaxObserver.with_args(
                dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False, ch_axis=0
            )
            qconfig_c = torch_q.QConfig(qconfig.activation, sym_fq)
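    # FBGEMM: keep the default activation observer, but prepare a symmetric
    # per-tensor weight qconfig (qconfig_c) that is later applied to Linear layers.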
    elif self.backend == 'fbgemm':
        sym_fq = torch_q.MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
        qconfig_c = torch_q.QConfig(qconfig.activation, sym_fq)
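    # ONNX / TensorRT: same pattern as QNNPACK, but with signed (qint8) activation
    # observers for the symmetric case and per-channel weight overrides when
    # per-tensor quantization is disabled.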
    elif self.backend in ('onnx', 'tensorrt'):
        if not self.asymmetric:
            sym_fq = torch_q.HistogramObserver.with_args(
                dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False
            )
            qconfig = torch_q.QConfig(sym_fq, qconfig.weight)
        if not self.per_tensor:
            sym_fq = torch_q.MinMaxObserver.with_args(
                dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False
            )
            qconfig = torch_q.QConfig(qconfig.activation, sym_fq)
            sym_fq = torch_q.PerChannelMinMaxObserver.with_args(
                dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False, ch_axis=0
            )
            qconfig_c = torch_q.QConfig(qconfig.activation, sym_fq)
    else:
        log.warning(f'Quantization backend {self.backend} is not tested. Please use it at your own risk.')
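    # Calibration algorithm override: for 'kl', replace the activation observer with
    # a KL-divergence-based histogram observer (reduce_range follows the backend).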
    if self.algorithm != 'l2':
        if self.algorithm == 'kl':
            if self.backend == 'qnnpack':
                alg_sym_fq = HistogramObserverKL.with_args(qscheme=torch.per_tensor_symmetric, reduce_range=False)
                alg_asym_fq = HistogramObserverKL.with_args(reduce_range=False)
            elif self.backend == 'fbgemm':
                alg_sym_fq = HistogramObserverKL.with_args(qscheme=torch.per_tensor_symmetric, reduce_range=True)
                alg_asym_fq = HistogramObserverKL.with_args(reduce_range=True)
            else:
                alg_sym_fq = qconfig.activation
                alg_asym_fq = qconfig.activation
        if not self.asymmetric:
            qconfig = torch_q.QConfig(alg_sym_fq, qconfig.weight)
        else:
            qconfig = torch_q.QConfig(alg_asym_fq, qconfig.weight)
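    # Select the quantized engine and attach the resulting qconfig to the root module;
    # prepare()/prepare_qat() propagates it to children unless a child is overridden below.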
    torch.backends.quantized.engine = backend
    graph.module.qconfig = qconfig
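    # For QNNPACK, walk the module tree breadth-first and assign the per-channel
    # conv qconfig to (fused) convolution modules; everything else inherits the
    # global qconfig.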
    if self.backend == 'qnnpack':
        if qconfig_c is not None:
            q = queue.Queue()
            q.put(graph.module)

            while not q.empty():
                m = q.get()
                if type(m).__name__ in (
                    'Conv2d',
                    'ConvBnReLU2d',
                    'ConvBn2d',
                    'ConvReLU2d',
                    'Conv1d',
                    'ConvBnReLU1d',
                    'ConvBn1d',
                ):
                    m.qconfig = qconfig_c
                else:
                    for c in m.children():
                        q.put(c)
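    # For FBGEMM, the same walk applies the symmetric weight override to Linear layers.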
    elif self.backend == 'fbgemm':
        if qconfig_c is not None:
            q = queue.Queue()
            q.put(graph.module)

            while not q.empty():
                m = q.get()
                if type(m).__name__ in ('Linear', 'LinearReLU'):
                    m.qconfig = qconfig_c
                else:
                    for c in m.children():
                        q.put(c)
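    # Optionally mark LSTM nodes for dynamic quantization so they are converted
    # with the default dynamic qconfig instead of static observers.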
    def _lstm_node(node, custom_data):
        return isinstance(node.module, nn.LSTM)

    if self.dynamic_lstm_quant:
        lstm_nodes = graph.filter_forward_nodes(_lstm_node)
        for node in lstm_nodes:
            node.quantized = True
            node.module.qconfig = torch_q.default_dynamic_qconfig
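
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of this file). It assumes
# the usual TinyNN post-training-quantization entry point (PostQuantizer and the
# config keys shown below), which calls prepare_qconfig() internally on the
# traced graph.
import torch
import torch.nn as nn

from tinynn.graph.quantization.quantizer import PostQuantizer


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


if __name__ == '__main__':
    model = TinyModel().eval()
    dummy_input = torch.randn(1, 3, 32, 32)

    # 'backend', 'asymmetric' and 'per_tensor' mirror the branches handled above
    quantizer = PostQuantizer(
        model,
        dummy_input,
        work_dir='out',
        config={'backend': 'qnnpack', 'asymmetric': True, 'per_tensor': False},
    )
    ptq_model = quantizer.quantize()  # prepares the model; calibration and convert follow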