tinynn/graph/quantization/quantizer.py
def prepare_qconfig(self, graph: TraceGraph, backend: str):
    """Prepare the qconfig for various configurations.

    Args:
        graph (TraceGraph): The computation graph of the model
        backend (str): The quantization backend (e.g. 'qnnpack', 'fbgemm', 'onnx', 'tensorrt')
    """

    log.info('setting qat backend and call prepare_qat')

    # ONNX/TensorRT targets are mapped onto the QNNPACK quantized engine.
    actual_backend = backend
    if backend in ('onnx', 'tensorrt'):
        actual_backend = 'qnnpack'
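    # Base QAT qconfig: take PyTorch's default for the chosen engine, or rebuild
    # the pre-1.13 fake-quantize configuration when legacy mode is requested.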
    if not self.legacy_fq:
        qconfig = torch_q.get_default_qat_qconfig(actual_backend)
    else:
        if LooseVersion(torch.__version__) >= '1.13.0':
            # See https://github.com/pytorch/pytorch/pull/88876
            qconfig = torch_q.QConfig(
                activation=torch_q.FakeQuantize.with_args(
                    observer=torch_q.MovingAverageMinMaxObserver, quant_min=0, quant_max=255, reduce_range=False
                ),
                weight=torch_q.default_weight_fake_quant,
            )
        else:
            version = None
            if LooseVersion(torch.__version__) >= '1.12.0':
                version = 0
            qconfig = torch_q.get_default_qat_qconfig(actual_backend, version)
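    # 'tflite' rounding mode: swap in FakeQuantizeTFLite with the same observer
    # arguments so rounding matches TFLite behaviour.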
    qconfig_c = None
    if self.rounding_mode == 'tflite':
        q_a = FakeQuantizeTFLite.with_args(*qconfig.activation.p.args, **qconfig.activation.p.keywords)
        q_w = FakeQuantizeTFLite.with_args(*qconfig.weight.p.args, **qconfig.weight.p.keywords)
        qconfig = torch_q.QConfig(q_a, q_w)
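    # Backend-specific overrides: symmetric activations when asymmetric
    # quantization is disabled, plus an alternative weight qconfig (qconfig_c)
    # that is applied to selected layer types further below.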
    if backend == 'qnnpack':
        if not self.asymmetric:
            sym_fq = qconfig.activation.with_args(
                observer=torch_q.MovingAverageMinMaxObserver,
                quant_min=0,
                quant_max=255,
                dtype=torch.quint8,
                qscheme=torch.per_tensor_symmetric,
                reduce_range=False,
            )
            qconfig = torch_q.QConfig(sym_fq, qconfig.weight)
        if not self.per_tensor:
            sym_fq = qconfig.weight.with_args(
                observer=torch_q.MovingAveragePerChannelMinMaxObserver.with_args(quant_min=-127, quant_max=127),
                quant_min=-127,
                quant_max=127,
                dtype=torch.qint8,
                qscheme=torch.per_channel_symmetric,
                reduce_range=False,
                ch_axis=0,
            )
            qconfig_c = torch_q.QConfig(qconfig.activation, sym_fq)
    elif backend == 'fbgemm':
        fq_type = qconfig.weight.p.func
        sym_fq = fq_type.with_args(
            observer=torch_q.MovingAverageMinMaxObserver,
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8,
            qscheme=torch.per_tensor_symmetric,
            reduce_range=False,
        )
        qconfig_c = torch_q.QConfig(qconfig.activation, sym_fq)
    elif backend in ('onnx', 'tensorrt'):
        if not self.asymmetric:
            sym_fq = qconfig.activation.with_args(
                observer=torch_q.MovingAverageMinMaxObserver,
                quant_min=-128,
                quant_max=127,
                dtype=torch.qint8,
                qscheme=torch.per_tensor_symmetric,
                reduce_range=False,
            )
            qconfig = torch_q.QConfig(sym_fq, qconfig.weight)
        if not self.per_tensor:
            sym_fq = qconfig.weight.with_args(
                observer=torch_q.MovingAveragePerChannelMinMaxObserver,
                quant_min=-128,
                quant_max=127,
                dtype=torch.qint8,
                qscheme=torch.per_channel_symmetric,
                reduce_range=False,
                ch_axis=0,
            )
            qconfig_c = torch_q.QConfig(qconfig.activation, sym_fq)
    else:
        log.warning(f'Quantization backend {self.backend} is not tested. Please use it at your own risk.')
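    # Activate the selected quantized engine and install the default qconfig on
    # the root module.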
    torch.backends.quantized.engine = actual_backend
    graph.module.qconfig = qconfig
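    # Breadth-first walk over the module tree, overriding the qconfig of the
    # layer types that support the alternative weight scheme: (fused) Conv
    # modules for QNNPACK, Linear modules for FBGEMM.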
    if self.backend == 'qnnpack':
        if qconfig_c is not None:
            q = queue.Queue()
            q.put(graph.module)

            while not q.empty():
                m = q.get()
                if type(m).__name__ in (
                    'Conv2d',
                    'ConvBnReLU2d',
                    'ConvBn2d',
                    'ConvReLU2d',
                    'Conv1d',
                    'ConvBnReLU1d',
                    'ConvBn1d',
                ):
                    m.qconfig = qconfig_c
                else:
                    for c in m.children():
                        q.put(c)
    elif self.backend == 'fbgemm':
        if qconfig_c is not None:
            q = queue.Queue()
            q.put(graph.module)

            while not q.empty():
                m = q.get()
                if type(m).__name__ in ('Linear', 'LinearReLU'):
                    m.qconfig = qconfig_c
                else:
                    for c in m.children():
                        q.put(c)
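    # Dynamic LSTM quantization: mark matching nodes as quantized and attach
    # PyTorch's default dynamic qconfig to their modules.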
    def _lstm_node(node, custom_data):
        return isinstance(node.module, nn.LSTM)

    if self.dynamic_lstm_quant:
        lstm_nodes = graph.filter_forward_nodes(_lstm_node)
        for node in lstm_nodes:
            node.quantized = True
            node.module.qconfig = torch_q.default_dynamic_qconfig