in tinynn/graph/quantization/quantizer.py [0:0]
def __init__(self, model, dummy_input, work_dir: typing.Optional[str] = None, config: typing.Optional[dict] = None):
    """ Constructs a new QATQuantizer object

    Args:
        model: The model to be quantized
        dummy_input: A viable input to the model
        work_dir (typing.Optional[str], optional): The working directory in which the intermediate files will be \
            generated. Defaults to None, in which case "out" will be used.
        config (typing.Optional[dict], optional): Options for the quantizer. Defaults to None.
    """
    super().__init__()
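
    # If the model is wrapped in DataParallel / DistributedDataParallel,
    # quantize the underlying module rather than the wrapper.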
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        self.model = model.module
    else:
        self.model = model

    self.dummy_input = dummy_input
    self.work_dir = 'out' if work_dir is None else work_dir
    self.parse_config(config)
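
    # Sanity-check the requested backend: qnnpack is known to be problematic on
    # Windows, and anything outside the known set below is untested.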
    if sys.platform == 'win32' and self.backend == 'qnnpack':
        log.error('Quantization backend qnnpack is likely unsupported on Windows. Please use fbgemm instead.')

    if self.backend not in ('fbgemm', 'qnnpack', 'onnx', 'tensorrt'):
        log.warning(f'Quantization backend {self.backend} is not tested. Please use it at your own risk.')
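
    # Enforce backend-specific constraints: FBGEMM requires asymmetric,
    # per-channel quantization, while TensorRT expects symmetric quantization.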
    if self.backend == 'fbgemm':
        assert self.asymmetric, "Symmetric quantization for FBGEMM is not supported"
        assert (
            not self.per_tensor
        ), "Per-tensor quantization for FBGEMM is not supported, please use per-channel quantization instead"

    if self.backend == 'tensorrt':
        if self.asymmetric:
            log.warning('Asymmetric quantization for TensorRT is not supported')
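
    # Unless the user sets it explicitly, requantization for `cat` is disabled
    # exactly when per-channel quantization is in use (see the assertion below).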
    if self.disable_requantization_for_cat is None:
        self.disable_requantization_for_cat = not self.per_tensor
    self.extra_qparams_mappings = []

    assert (
        self.per_tensor or self.disable_requantization_for_cat
    ), "`disable_requantization_for_cat=True` is required for per-channel quantization"
    if self.legacy_fq:
        version = None
        if type(self).__name__ == 'QATQuantizer':
            version = '1.10.0'
        elif type(self).__name__ == 'PostQuantizer':
            version = '1.12.0'

        if version is None or LooseVersion(torch.__version__) < LooseVersion(version):
            log.info(f'legacy_fq=True is only available for QATQuantizer and PostQuantizer with PyTorch {version}+')
            self.legacy_fq = False
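
    # Remaining internal state: graph-node bookkeeping, per-layer quantization
    # settings (optionally seeded from the user config) and LSTM weight backups.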
    self.leaf_nodes = None
    self.swap_nodes = None
    self.train_mode_dict = {}
    self.layerwise_config = CommentedMap()
    self.effective_layers = []
    self.layerwise_default = True
    if config is not None and 'layerwise_config' in config:
        self.layerwise_config.update(config['layerwise_config'])
    self.lstm_origin_weight_dict = {}