tinynn/graph/quantization/quantizer.py (3,279 lines of code) (raw):

import copy import functools import inspect import os import queue import re import sys import typing from distutils.version import LooseVersion import torch import torch.nn as nn import torch.nn.intrinsic as nni import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd import torch.quantization as torch_q from torch.nn.parallel.data_parallel import DataParallel from torch.nn.parallel.distributed import DistributedDataParallel from tinynn.graph.quantization.fake_quantize import ( FakeQuantizeBFloat16, FakeQuantizeTFLite, ) from tinynn.graph.quantization.modules import QGLU, QHardsigmoid, QPReLU, QSiLU, QLayerNorm, QRMSNorm from tinynn.graph.quantization.observer import ( HistogramObserverKL, MinMaxObserver, PerChannelMinMaxObserver, ) from tinynn.graph.quantization.qat_modules import ( Conv1d, ConvTranspose1d, ConvTranspose2d, ConvTransposeBn2d, ) if LooseVersion(torch.__version__) >= LooseVersion("1.13.0"): from tinynn.graph.quantization.quantizable.gru import GRU as QuantizableGRU from torch.ao.nn.quantizable.modules.rnn import LSTM as QuantizableLSTM else: QuantizableGRU = None QuantizableLSTM = None from tinynn.graph.tracer import ( ConstantNode, TraceFunction, TraceGraph, TraceNode, load_creation_funcs, module_constructor_lines, override_current_trace_graph, qualified_name, trace, ) from tinynn.util.train_util import get_logger, get_module_device from tinynn.util.util import import_from_path from tinynn.graph.quantization.utils import fake_quantize from . import fused_modules as fm try: import ruamel_yaml as yaml from ruamel_yaml import CommentedMap except ModuleNotFoundError: import ruamel.yaml as yaml from ruamel.yaml import CommentedMap # Fusable OPs for Quantize Aware Training FUSE_RULE_LIST = { (torch.nn.Conv1d, torch.nn.BatchNorm1d), (torch.nn.Conv1d, torch.nn.BatchNorm1d, torch.nn.ReLU), (torch.nn.Conv2d, torch.nn.BatchNorm2d), (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU), (torch.nn.Conv3d, torch.nn.BatchNorm3d), (torch.nn.Conv3d, torch.nn.BatchNorm3d, torch.nn.ReLU), (torch.nn.Conv1d, torch.nn.ReLU), (torch.nn.Conv2d, torch.nn.ReLU), (torch.nn.Conv3d, torch.nn.ReLU), (torch.nn.Linear, torch.nn.ReLU), (torch.nn.BatchNorm2d, torch.nn.ReLU), (torch.nn.BatchNorm3d, torch.nn.ReLU), (torch.nn.ConvTranspose2d, torch.nn.BatchNorm2d), } FUSE_RULE_LIST_PTQ_ONLY = { (nn.Linear, nn.BatchNorm1d): '1.8.0', (nn.ConvTranspose1d, nn.BatchNorm1d): '1.11.0', (nn.ConvTranspose3d, nn.BatchNorm3d): '1.11.0', (nn.BatchNorm2d, nn.Conv2d): None, (nn.BatchNorm2d, nn.Conv2d, nn.ReLU): None, } FUSE_RULE_LIST_EXTRA = { (torch.nn.Conv1d, torch.nn.BatchNorm1d, torch.nn.ReLU6), (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU6), (torch.nn.Conv2d, torch.nn.BatchNorm2d, 'clamp_with_fusion'), (torch.nn.Conv3d, torch.nn.BatchNorm3d, torch.nn.ReLU6), (torch.nn.Conv1d, torch.nn.ReLU6), (torch.nn.Conv2d, torch.nn.ReLU6), (torch.nn.Conv2d, 'clamp_with_fusion'), (torch.nn.Conv3d, torch.nn.ReLU6), (torch.nn.Linear, torch.nn.ReLU6), (torch.nn.Linear, 'clamp_with_fusion'), (torch.nn.Linear, torch.nn.BatchNorm1d, torch.nn.ReLU6), (torch.nn.Linear, torch.nn.BatchNorm1d, 'clamp_with_fusion'), (torch.nn.BatchNorm2d, torch.nn.ReLU6), (torch.nn.BatchNorm2d, 'clamp_with_fusion'), (torch.nn.BatchNorm3d, torch.nn.ReLU6), ('add', torch.nn.ReLU6), ('add', 'relu6'), ('add', 'clamp_with_fusion'), (torch.nn.ConvTranspose2d, torch.nn.ReLU), (torch.nn.ConvTranspose2d, torch.nn.ReLU6), (torch.nn.ConvTranspose2d, 'clamp_with_fusion'), (torch.nn.ConvTranspose2d, torch.nn.BatchNorm2d, torch.nn.ReLU), (torch.nn.ConvTranspose2d, torch.nn.BatchNorm2d, torch.nn.ReLU6), (torch.nn.ConvTranspose2d, torch.nn.BatchNorm2d, 'clamp_with_fusion'), } FUSE_FALLBACK_DICT = {'clamp_with_fusion': 'clamp'} FUSE_QAT_MODULES = { nn.Conv1d: Conv1d, nn.ConvTranspose1d: ConvTranspose1d, nn.ConvTranspose2d: ConvTranspose2d, fm.ConvTransposeBn2d: ConvTransposeBn2d, } FUSE_QAT_MODULES_CUSTOM = {} FUSE_PTQ_MODULES_CUSTOM = {} if LooseVersion(torch.__version__) >= '1.13.0': from .quantizable.gru import GRU FUSE_QAT_MODULES_CUSTOM.update({nn.GRU: GRU}) FUSE_PTQ_MODULES_CUSTOM.update({nn.GRU: GRU}) FUSE_QAT_MODULES_CVT = {Conv1d: nnq.Conv1d} if hasattr(nnq, 'ConvTranspose1d'): FUSE_QAT_MODULES_CVT.update( { ConvTranspose1d: nnq.ConvTranspose1d, ConvTranspose2d: nnq.ConvTranspose2d, } ) REWRITE_TO_FUSE_RULE_LIST = { (torch.nn.Linear, torch.nn.BatchNorm1d), } KNOWN_QSTATS = { nn.Softmax: (0, 256.0), 'softmax': (0, 256.0), nn.LogSoftmax: (255, 16.0), 'log_softmax': (255, 16.0), } REWRITE_QUANTIZABLE_RULE_LIST = { ('abs',), ('sum',), ('bmm',), ('matmul',), ('truediv',), ('clamp_min',), ('clamp_max',), ('sqrt', 'reciprocal'), } UNSUPPORTED_PYTORCH_QUANTIZATION_OP_LIST = { 'pad': '1.7.0', 'pow': None, 'truediv': None, 'sqrt': None, 'rsqrt': None, 'atan2': None, 'atan': None, 'sin': None, 'cos': None, 'hardsigmoid': None, 'silu': None, 'reciprocal': None, 'exp': None, 'layer_norm': None, 'instance_norm': None, 'softmax': None, 'log_softmax': None, 'mm': None, 'matmul': None, 'bmm': None, 'abs': None, 'sum': None, 'clamp_min': None, 'clamp_max': None, 'prelu': None, 'elu': None, 'glu': None, 'group_norm': None, 'log': None, 'std': None, 'var': None, 'norm': None, nn.LSTM: '1.13.0', nn.ConvTranspose2d: '1.7.0', nn.ConstantPad1d: '1.7.0', nn.ConstantPad2d: '1.7.0', nn.ConstantPad3d: '1.7.0', nn.ZeroPad2d: '1.7.0', nn.RNN: None, nn.GRU: '1.13.0', nn.LayerNorm: None, nn.InstanceNorm1d: None, nn.InstanceNorm2d: None, nn.Hardsigmoid: None, nn.Softmax: None, nn.LogSoftmax: None, nn.GLU: None, nn.PReLU: None, nn.GroupNorm: None, } Q_MODULES_MAPPING = { nn.PReLU: QPReLU, nn.GLU: QGLU, nn.Hardsigmoid: QHardsigmoid, } FUNCTIONAL_MODULE_MAPPING = { 'relu': nn.ReLU, 'relu6': nn.ReLU6, 'leaky_relu': nn.LeakyReLU, 'elu': nn.ELU, 'prelu': nn.PReLU, 'glu': nn.GLU, 'sigmoid': nn.Sigmoid, 'tanh': nn.Tanh, 'hardswish': nn.Hardswish, } TENSORRT_OBSERVED_NODES = [ (torch.nn, "Conv1d"), (torch.nn, "Conv2d"), (torch.nn, "Conv3d"), (torch.nn, "ConvTranspose1d"), (torch.nn, "ConvTranspose2d"), (torch.nn, "ConvTranspose3d"), (torch.nn, "Linear"), (torch.nn, "LSTM"), (torch.nn, "LSTMCell"), (torch.nn, "AvgPool1d"), (torch.nn, "AvgPool2d"), (torch.nn, "AvgPool3d"), (torch.nn, "AdaptiveAvgPool1d"), (torch.nn, "AdaptiveAvgPool2d"), (torch.nn, "AdaptiveAvgPool3d"), 'avg_pool2d', 'adaptive_avg_pool2d', ] if hasattr(nn, 'SiLU'): UNSUPPORTED_PYTORCH_QUANTIZATION_OP_LIST.update({nn.SiLU: None}) Q_MODULES_MAPPING.update({nn.SiLU: QSiLU}) FUNCTIONAL_MODULE_MAPPING.update({'silu': nn.SiLU}) if hasattr(nn, 'LayerNorm'): Q_MODULES_MAPPING.update({nn.LayerNorm: QLayerNorm}) if hasattr(nn, 'RMSNorm'): Q_MODULES_MAPPING.update({nn.RMSNorm: QRMSNorm}) # Processed QAT fuse rules processed_qat_rules = {} processed_ptq_rules = {} processed_extra_q_rules = {} processed_all_qat_rules = {} processed_all_ptq_rules = {} # Constant func names creation_func_names = [] # Processed rewrite rules for fusing processed_rewrite_to_fuse_rules = {} log = get_logger(__name__, 'WARNING') class QATQuantizer(object): rewrite_graph: bool force_overwrite: bool is_input_quantized: typing.Optional[typing.Tuple[bool]] quantized_input_stats: typing.Optional[typing.List[typing.Optional[typing.Tuple[float, float]]]] quantized_op_stats: typing.Optional[typing.List[typing.Optional[typing.Tuple[float, float]]]] set_quantizable_op_stats: bool backend: str remove_weights_after_load: bool asymmetric: bool per_tensor: bool disable_requantization_for_cat: bool dynamic_lstm_quant: bool rounding_mode: str leaf_nodes: typing.Optional[typing.List[nn.Module]] swap_nodes: typing.Optional[typing.List[typing.Tuple[nn.Module, nn.Module]]] legacy_fq: bool extra_tracer_opts: typing.Optional[typing.Dict] layerwise_config: typing.Dict[str, bool] layerwise_default: bool override_qconfig_func: typing.Optional[typing.Callable[[str, nn.Module], typing.Optional[torch_q.QConfig]]] effective_layers: typing.List[str] ignore_layerwise_config: bool fused_layerwise_config: bool inplace: bool train_mode_dict: typing.Dict[nn.Module, bool] def __init__(self, model, dummy_input, work_dir: typing.Optional[str] = None, config: typing.Optional[dict] = None): """ Constructs a new QATQuantizer object Args: model: The model to be quantized dummy_input: A viable input to the model work_dir (typing.Optional[str], optional): The working directory in which the intermediate files will be \ generated. Defaults to None, in which case "output" will be used. config (typing.Optional[dict]): Options for the quantizer """ super().__init__() if isinstance(model, DataParallel) or isinstance(model, DistributedDataParallel): self.model = model.module else: self.model = model self.dummy_input = dummy_input self.work_dir = 'out' if work_dir is None else work_dir self.parse_config(config) if sys.platform == 'win32' and self.backend == 'qnnpack': log.error('Quantization backend qnnpack is likely unsupported on Windows. Please use fbgemm instead.') if self.backend not in ('fbgemm', 'qnnpack', 'onnx', 'tensorrt'): log.warning(f'Quantization backend {self.backend} is not tested. Please use at your risk.') if self.backend == 'fbgemm': assert self.asymmetric, "Symmetric quantizaton for FBGEMM not supported" assert ( not self.per_tensor ), "Per-tensor quantizaton for FBGEMM not supported, please use per-channel quantization instead" if self.backend == 'tensorrt': if self.asymmetric: log.warning('Asymmetric quantizaton for TensorRT not supported') if self.disable_requantization_for_cat is None: if not self.per_tensor: self.disable_requantization_for_cat = True else: self.disable_requantization_for_cat = False self.extra_qparams_mappings = [] assert ( self.per_tensor or self.disable_requantization_for_cat ), "`disable_requantization_for_cat=True` is required for per-channel quantization" if self.legacy_fq: version = None if type(self).__name__ == 'QATQuantizer': version = '1.10.0' elif type(self).__name__ == 'PostQuantizer': version = '1.12.0' if version is None or LooseVersion(torch.__version__) < version: log.info(f'legacy_fq=True is only available for QATQuantizer and PostQuantizer with PyTorch {version}+') self.legacy_fq = False self.leaf_nodes = None self.swap_nodes = None self.train_mode_dict = {} self.layerwise_config = CommentedMap() self.effective_layers = [] self.layerwise_default = True if config is not None and 'layerwise_config' in config: self.layerwise_config.update(config['layerwise_config']) self.lstm_origin_weight_dict = {} def parse_config(self, config: typing.Optional[dict]): default_values = { 'rewrite_graph': True, 'force_overwrite': True, 'is_input_quantized': None, 'backend': 'qnnpack', 'remove_weights_after_load': False, 'asymmetric': True, 'per_tensor': True, 'disable_requantization_for_cat': None, 'quantized_input_stats': None, 'dynamic_lstm_quant': False, 'quantized_op_stats': None, 'set_quantizable_op_stats': False, 'rounding_mode': 'pytorch', 'algorithm': 'l2', 'fuse_only': False, 'legacy_fq': True, 'extra_tracer_opts': {}, 'fused_layerwise_config': True, 'ignore_layerwise_config': False, 'inplace': False, 'override_qconfig_func': None, 'quantize_op_action': {}, } if config is None: config = dict() for k, v in default_values.items(): actual_v = config.get(k, v) setattr(self, k, actual_v) def set_model_mode(self, model: nn.Module = None, preserve: bool = False): """Sets the mode (train/eval) of the model Args: model (nn.Module, optional): A PyTorch model. Defaults to None. preserve (bool, optional): Whether the old state is preserved. Defaults to False. """ if model is None: model = self.model if preserve: for m in model.modules(): training = getattr(m, 'training', None) if isinstance(training, bool): self.train_mode_dict[m] = training model.train() def restore_model_mode(self, model: nn.Module = None): """Restores the mode (train/eval) of the model Args: model (nn.Module, optional): A PyTorch model. Defaults to None. """ if model is None: model = self.model for m, training in self.train_mode_dict.items(): if training: m.train() else: m.eval() self.train_mode_dict.clear() def quantize(self) -> nn.Module: """Performs QAT rewrite and preparation Returns: nn.Module: The QAT-ready model """ # We need a model in training mode so that QAT could take place self.set_model_mode(preserve=True) if self.rewrite_graph: # Retrives the name of the model from type info model_name = type(self.model).__name__ model_name_q = f'Q{model_name}' model_name_q_lower = f'{model_name}_q'.lower() model_ns = f'tinynn_rewritten_models.{model_name_q}' model_code_path = os.path.join(self.work_dir, f'{model_name_q_lower}.py') model_weights_path = os.path.join(self.work_dir, f'{model_name_q_lower}.pth') config_path = os.path.join(self.work_dir, f'{model_name_q_lower}_config.yml') yaml_ = yaml.YAML() if len(self.layerwise_config) == 0 and not self.ignore_layerwise_config: if os.path.exists(config_path): with open(config_path, 'r') as f: self.layerwise_config = yaml_.load(f) if 'default' in self.layerwise_config: self.layerwise_default = self.layerwise_config['default'] del self.layerwise_config['default'] # Generate the code for the modified model # We will try to do some op fusion and rewriting in the `rewrite_quantize_graph` function. # By default, we will try to insert QuantStubs before every input and DeQuantStubs after every output in # the generated graph. If this doesn't suit your needs, e.g. you have intergal/quantized inputs or want to # skip the quantization rewrite for some ops, then you may modify the code generated freely and remember to # skip this step so it won't be overwritten. if self.force_overwrite or not os.path.exists(model_code_path) or not os.path.exists(model_weights_path): # After tracing the model, we will get a TraceGraph object graph = trace(self.model, self.dummy_input, **self.extra_tracer_opts) self.rewrite_quantize_graph(graph) if self.inplace: graph.inplace_commit() else: graph.generate_code(model_code_path, model_weights_path, model_name_q) os.makedirs(self.work_dir, exist_ok=True) with open(config_path, 'w') as f: yaml_.dump(self.layerwise_config, f) if self.inplace: rewritten_model = self.model else: # Import the new model rewritten_model = import_from_path(model_ns, model_code_path, model_name_q)() rewritten_model.load_state_dict(torch.load(model_weights_path)) device = get_module_device(self.model) if device is not None: rewritten_model.to(device=device) # Remove the weights file to save space if self.remove_weights_after_load: os.unlink(model_weights_path) # Restores the mode of the model self.restore_model_mode() # Set the model to training mode before tracing again, since we need to perform QAT self.set_model_mode(rewritten_model) rewritten_graph = trace(rewritten_model, self.dummy_input) # Update the model so that the original one can be released self.model = rewritten_model else: # After tracing the model, we will get a TraceGraph object graph = trace(self.model, self.dummy_input, **self.extra_tracer_opts) rewritten_graph = graph # Fuse the modules (e.g conv, bn, relu) in the computation graph according to the fuse rules. # By default, we assume all input tensors are of floating type. # If you want to use quantized/integral inputs, then you may need to pass in `is_input_quantized`. qat_model = self.prepare_qat(rewritten_graph, self.is_input_quantized, self.backend, self.fuse_only) return qat_model def prepare_qat_prep( self, graph: TraceGraph, is_input_quantized: typing.Optional[typing.Tuple[bool]] = None, backend: str = 'qnnpack', ): """Some common logic before calling torch.quantization.prepare[_qat] Args: graph (TraceGraph): The computation graph of the model is_input_quantized (typing.Union[typing.Tuple[bool]], optional): Whether the input tensor(s) is (are) \ quantized. Defaults to None. backend (str, optional): The backend of quantization. Defaults to 'qnnpack'. """ qat_analysis_queue = queue.Queue() visited = set() def _qat_analysis(node: TraceNode, quantized: bool): # Find quantized subgraphs in the whole computation graph if node.unique_name in visited: return visited.add(node.unique_name) if node in graph.output_nodes: return # TODO: Enable QAT analysis for TensorRT if self.backend == 'tensorrt': quantized = True node.quantized = True elif type(node.module) is torch_q.QuantStub: quantized = True node.quantized = quantized elif type(node.module) is torch_q.DeQuantStub: node.quantized = True quantized = False else: node.quantized = quantized log.debug(f"[QUANTIZED]{node.unique_name}:{quantized}") for i, n in enumerate(node.next_nodes): if type(n.module) is TraceFunction: if n.kind() in ('shape', 'size', 'dtype', 'device'): continue if n.kind() == 'expand_as' and i > 0: continue qat_analysis_queue.put((n, quantized)) if is_input_quantized is not None: assert len(is_input_quantized) == len(graph.input_nodes) for n, q in zip(graph.input_nodes, is_input_quantized): qat_analysis_queue.put((n, q)) else: for n in graph.input_nodes: qat_analysis_queue.put((n, False)) creation_func_names = load_creation_func_names() def _is_extra_constant_nodes(node, custom_data): return node.full_name() in creation_func_names extra_constant_nodes = graph.filter_forward_nodes(_is_extra_constant_nodes) def _is_params_in_module(node, custom_data): if len(node.prev_nodes) == 1 and len(node.next_nodes) == 1: if len(node.prev_tensors) == 1 and len(node.next_tensors) == 1: if isinstance(node.prev_tensors[0], nn.Module) and not isinstance( node.prev_tensors[0], nnq.FloatFunctional ): return True return False param_nodes = graph.filter_forward_nodes(_is_params_in_module) for n in graph.constant_nodes + extra_constant_nodes: qat_analysis_queue.put((n, not torch.is_floating_point(n.next_tensors[0]))) while not qat_analysis_queue.empty(): node, quantized = qat_analysis_queue.get() if not graph.quantized: graph.quantized = graph.quantized or quantized _qat_analysis(node, quantized) q_dict = {} for n in graph.forward_nodes: if isinstance(n.module, nn.Module): q_dict[n.module] = q_dict.get(n.module, False) or n.quantized for n in graph.forward_nodes: if n.module in q_dict: n.quantized = q_dict[n.module] for n in param_nodes: prev_node = n.prev_nodes[0] next_node = n.next_nodes[0] is_known_mod = prev_node.kind().__name__ in ( 'Conv1d', 'Conv2d', 'Linear', 'ConvTranspose1d', 'ConvTranspose2d', ) if is_known_mod and n.module.full_name == 'weight' and prev_node.quantized: if next_node.type() is torch_q.QuantStub: mod = nn.Sequential(torch_q.DeQuantStub(), torch_q.QuantStub()) orig_mod = next_node.module next_node.module = mod setattr(graph.module, next_node.original_name, next_node.module) graph.module_original_name_dict[id(mod)] = graph.module_original_name_dict[id(orig_mod)] graph.module_unique_name_dict[id(mod)] = graph.module_unique_name_dict[id(orig_mod)] continue qat_analysis_queue.put((n, not torch.is_floating_point(n.next_tensors[0]))) while not qat_analysis_queue.empty(): node, quantized = qat_analysis_queue.get() if not graph.quantized: graph.quantized = graph.quantized or quantized _qat_analysis(node, quantized) log.debug("qat analysis over") if not graph.quantized: return if isinstance(self, PostQuantizer): processed_rules = load_processed_ptq_rules() else: processed_rules = load_processed_qat_rules() is_fusable = functools.partial(self.is_fusable, current_rules=processed_rules, graph=graph) def _find_quantized_module_nodes(node: TraceNode, custom_node): # Find quantized module nodes return node.type() in Q_MODULES_MAPPING and node.quantized # Replace module nodes with our custom variants quantized_mod_nodes = graph.filter_forward_nodes(_find_quantized_module_nodes) type_dict = {} for node in quantized_mod_nodes: node_type = node.type() type_dict.setdefault(node_type, []) type_dict[node_type].append(node) for node_type, nodes in type_dict.items(): graph.update_submodule_in_nodes_from_predicate(nodes, Q_MODULES_MAPPING[node_type], self.inplace) custom_data = ([], set()) graph.filter_forward_nodes(is_fusable, custom_data, reverse=True) quant_list = custom_data[0] log.info(f'found nodes to fuse: {quant_list}') new_fuser_func = fm.gen_fuse_known_modules_wrapper( sys.modules['torch.quantization.fuse_modules'].fuse_known_modules ) if self.backend != 'tensorrt': is_qat = type(self) is QATQuantizer for quant_nodes in quant_list: if self.inplace: quant_nodes = [re.sub('get_submodule\\("(.*?)"\\)', '\\1', x) for x in quant_nodes] quant_nodes = [re.sub('\\[("|)(.*?)("|)\\]', '.\\2', x) for x in quant_nodes] if LooseVersion(torch.__version__) >= '1.11.0' and LooseVersion(torch.__version__) < '1.14.0': # See https://github.com/pytorch/pytorch/pull/88193 sys.modules['torch.quantization.fuse_modules']._fuse_modules( graph.module, quant_nodes, is_qat=is_qat, inplace=True, fuser_func=new_fuser_func ) elif is_qat and LooseVersion(torch.__version__) >= '1.14.0': # See https://github.com/pytorch/pytorch/issues/74028 torch.ao.quantization.fuse_modules_qat( graph.module, quant_nodes, fuser_func=new_fuser_func, inplace=True ) else: torch_q.fuse_modules(graph.module, quant_nodes, fuser_func=new_fuser_func, inplace=True) self.prepare_qconfig(graph, backend) self.override_qconfig(graph.module) def override_qconfig(self, module: nn.Module): """Sets qconfig according to the provided `override_qconfig_func` Args: module (nn.Module): The model to be prepared for quantization """ if self.override_qconfig_func is not None: q = queue.Queue() q.put(('', module)) while not q.empty(): n, m = q.get() new_qconfig = self.override_qconfig_func(n, m) if new_qconfig is not None: m.qconfig = new_qconfig else: for cn, cm in m.named_children(): fn = cn if len(n) > 0: fn = f'{n}.{fn}' q.put((fn, cm)) def prepare_qconfig(self, graph: TraceGraph, backend: str): """Prepare qconfig for various configurations. Args: graph (TraceGraph): The computation graph of the model backend (str, optional): The backend of quantization """ log.info('setting qat backend and call prepare_qat') actual_backend = backend if backend in ('onnx', 'tensorrt'): actual_backend = 'qnnpack' if not self.legacy_fq: qconfig = torch_q.get_default_qat_qconfig(actual_backend) else: if LooseVersion(torch.__version__) >= '1.13.0': # See https://github.com/pytorch/pytorch/pull/88876 qconfig = torch_q.QConfig( activation=torch_q.FakeQuantize.with_args( observer=torch_q.MovingAverageMinMaxObserver, quant_min=0, quant_max=255, reduce_range=False ), weight=torch_q.default_weight_fake_quant, ) else: version = None if LooseVersion(torch.__version__) >= '1.12.0': version = 0 qconfig = torch_q.get_default_qat_qconfig(actual_backend, version) qconfig_c = None if self.rounding_mode == 'tflite': q_a = FakeQuantizeTFLite.with_args(*qconfig.activation.p.args, **qconfig.activation.p.keywords) q_w = FakeQuantizeTFLite.with_args(*qconfig.weight.p.args, **qconfig.weight.p.keywords) qconfig = torch_q.QConfig(q_a, q_w) if backend == 'qnnpack': if not self.asymmetric: sym_fq = qconfig.activation.with_args( observer=torch_q.MovingAverageMinMaxObserver, quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_symmetric, reduce_range=False, ) qconfig = torch_q.QConfig(sym_fq, qconfig.weight) if not self.per_tensor: sym_fq = qconfig.weight.with_args( observer=torch_q.MovingAveragePerChannelMinMaxObserver.with_args(quant_min=-127, quant_max=127), quant_min=-127, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False, ch_axis=0, ) qconfig_c = torch_q.QConfig(qconfig.activation, sym_fq) elif backend == 'fbgemm': fq_type = qconfig.weight.p.func sym_fq = fq_type.with_args( observer=torch_q.MovingAverageMinMaxObserver, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False, ) qconfig_c = torch_q.QConfig(qconfig.activation, sym_fq) elif backend in ('onnx', 'tensorrt'): if not self.asymmetric: sym_fq = qconfig.activation.with_args( observer=torch_q.MovingAverageMinMaxObserver, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False, ) qconfig = torch_q.QConfig(sym_fq, qconfig.weight) if not self.per_tensor: sym_fq = qconfig.weight.with_args( observer=torch_q.MovingAveragePerChannelMinMaxObserver, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False, ch_axis=0, ) qconfig_c = torch_q.QConfig(qconfig.activation, sym_fq) else: log.warning(f'Quantization backend {self.backend} is not tested. Please use at your risk.') torch.backends.quantized.engine = actual_backend graph.module.qconfig = qconfig if self.backend == 'qnnpack': if qconfig_c is not None: q = queue.Queue() q.put(graph.module) while not q.empty(): m = q.get() if type(m).__name__ in ( 'Conv2d', 'ConvBnReLU2d', 'ConvBn2d', 'ConvReLU2d', 'Conv1d', 'ConvBnReLU1d', 'ConvBn1d', ): m.qconfig = qconfig_c else: for c in m.children(): q.put(c) elif self.backend == 'fbgemm': if qconfig_c is not None: q = queue.Queue() q.put(graph.module) while not q.empty(): m = q.get() if type(m).__name__ in ('Linear', 'LinearReLU'): m.qconfig = qconfig_c else: for c in m.children(): q.put(c) def _lstm_node(node, custom_data): return isinstance(node.module, nn.LSTM) if self.dynamic_lstm_quant: lstm_nodes = graph.filter_forward_nodes(_lstm_node) for node in lstm_nodes: node.quantized = True node.module.qconfig = torch_q.default_dynamic_qconfig def prepare_qat( self, graph: TraceGraph, is_input_quantized: typing.Optional[typing.Tuple[bool]] = None, backend: str = 'qnnpack', fuse_only: bool = False, ) -> torch.nn.Module: """Prepare model for QAT training Args: graph (TraceGraph): The computation graph of the model is_input_quantized (typing.Union[typing.Tuple[bool]], optional): Whether the input tensor(s) is (are) \ quantized. Defaults to None. backend (str, optional): The backend of quantization. Defaults to 'qnnpack'. fuse_only (bool, optional): Whether the returned model is only fused in PostQuantizer. Defaults to False. Returns: torch.nn.Module: The QAT-ready model """ graph.module.train() self.prepare_qat_prep(graph, is_input_quantized, backend) if not graph.quantized: log.warning('Graph is not quantized, skip preparation') return graph.module # Unfornately, the suggested way below will try to fuse all the modules # even if some of the nodes are not in a quantized computation graph. # So we wrote some alternatives for the function. # torch.quantization.prepare_qat(graph.module, inplace=True) if hasattr(torch_q, 'get_default_qat_module_mappings'): mapping = torch_q.get_default_qat_module_mappings() elif hasattr(torch_q, 'get_qat_module_mappings'): mapping = copy.deepcopy(torch_q.get_qat_module_mappings()) else: mapping = copy.deepcopy(torch_q.DEFAULT_QAT_MODULE_MAPPING) if self.dynamic_lstm_quant: mapping = dict(mapping) mapping.update({nn.LSTM: nnqd.LSTM}) mapping.update(FUSE_QAT_MODULES) q_dict = {} for n in graph.forward_nodes: if isinstance(n.module, nn.Module): q_dict[n.module] = q_dict.get(n.module, False) or n.quantized non_quantized_mods = set(m for m, s in q_dict.items() if s is False) if LooseVersion(torch.__version__) < LooseVersion("1.7.0"): model = torch_q.prepare(graph.module, inplace=True) for m in non_quantized_mods: if hasattr(m, "_forward_hooks"): if len(m._forward_hooks) > 0: m._forward_hooks.popitem() if hasattr(m, "qconfig"): delattr(m, "qconfig") if hasattr(m, "activation_post_process"): delattr(m, "activation_post_process") if hasattr(graph.module, 'activation_post_process'): if hasattr(graph.module, "_forward_hooks"): if len(graph.module._forward_hooks) > 0: graph.module._forward_hooks.popitem() delattr(graph.module, "activation_post_process") torch_q.convert(model, mapping, inplace=True) else: torch_q.propagate_qconfig_(graph.module, qconfig_dict=None) for m in non_quantized_mods: if hasattr(m, "qconfig"): delattr(m, "qconfig") model = torch_q.convert(graph.module, mapping=mapping, inplace=True, remove_qconfig=False) if self.dynamic_lstm_quant: mapping.pop(nn.LSTM) if LooseVersion(torch.__version__) >= LooseVersion("1.13.0"): torch_q.propagate_qconfig_(model, qconfig_dict=None) for m in non_quantized_mods: if hasattr(m, "qconfig"): delattr(m, "qconfig") prepare_custom_config_dict = torch.ao.quantization.get_default_custom_config_dict() custom_module_class_mapping = prepare_custom_config_dict.get( "float_to_observed_custom_module_class", {} ) custom_module_class_mapping.update(FUSE_QAT_MODULES_CUSTOM) qconfig_propagation_list = torch_q.get_default_qconfig_propagation_list() from .quantizable import lstm, gru orig_from_float = torch.ao.nn.quantizable.LSTM.from_float torch.ao.nn.quantizable.LSTM.from_float = lstm.from_float GRU.from_float = gru.from_float def patch_observer_set(orig_func): def new_no_observer_set(): return set(FUSE_QAT_MODULES_CUSTOM.values()) | orig_func() return new_no_observer_set orig_no_observer_set = sys.modules['torch.ao.quantization.quantize'].no_observer_set sys.modules['torch.ao.quantization.quantize'].no_observer_set = patch_observer_set(orig_no_observer_set) if hasattr(torch_q, 'add_observer_'): add_observer_func = torch_q.add_observer_ else: add_observer_func = sys.modules['torch.ao.quantization.quantize']._add_observer_ add_observer_func( model, qconfig_propagation_list, set(mapping.values()), custom_module_class_mapping=custom_module_class_mapping, ) torch.ao.nn.quantizable.LSTM.from_float = orig_from_float sys.modules['torch.ao.quantization.quantize'].no_observer_set = orig_no_observer_set else: torch_q.prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True) for m in non_quantized_mods: if hasattr(m, "_forward_hooks"): if len(m._forward_hooks) > 0: m._forward_hooks.popitem() if hasattr(m, "qconfig"): delattr(m, "qconfig") if hasattr(m, "activation_post_process"): delattr(m, "activation_post_process") if not self.per_tensor: if self.backend == 'qnnpack': for n, m in graph.module.named_modules(): if n.endswith('.weight_fake_quant'): observer = getattr(m, 'activation_post_process', None) if observer is not None: m.quant_min = -127 m.quant_max = 127 observer.quant_min = -127 observer.quant_max = 127 self.extra_qat_fusion_postprocess(graph) if self.disable_requantization_for_cat: self.disable_requantization_for_cat_pass(graph) if self.quantized_input_stats is not None: self.prepare_quantized_inputs_pass(graph) if self.set_quantizable_op_stats: if self.quantized_op_stats is None: self.quantized_op_stats = {} self.quantized_op_stats.update(KNOWN_QSTATS) if self.quantized_op_stats is not None: self.prepare_quantized_ops_pass(graph) if self.backend == 'tensorrt': q = queue.Queue() q.put(('', graph.module)) while not q.empty(): s, m = q.get() if isinstance(m, torch_q.QuantStub): m.apply(torch.quantization.enable_fake_quant) m.apply(torch.quantization.enable_observer) elif hasattr(m, 'activation_post_process'): m.activation_post_process.apply(torch.quantization.disable_fake_quant) m.activation_post_process.apply(torch.quantization.disable_observer) if hasattr(m, 'weight_fake_quant'): m.weight_fake_quant.apply(torch.quantization.enable_fake_quant) m.weight_fake_quant.apply(torch.quantization.enable_observer) else: for n, c in m.named_children(): q.put((f'{s}{n}', c)) if self.backend == 'onnx': self.leaf_nodes = [] self.swap_nodes = [] q = queue.Queue() for node in graph.output_nodes: for prev_node in node.prev_nodes: q.put((prev_node, False)) while not q.empty(): n, state = q.get() if isinstance(n.module, nn.Module): orig_name = graph.module_original_name_dict.get(id(n.module)) new_mod, _ = graph.get_submodule_with_parent_from_name(orig_name, self.inplace) if isinstance(new_mod, torch_q.DeQuantStub): state = True else: if state: if isinstance(new_mod, torch_q.QuantStub): state = False elif isinstance(new_mod, nn.Module) and hasattr(new_mod, 'activation_post_process'): self.leaf_nodes.append(new_mod) state = False elif ( isinstance(new_mod, nn.Sequential) and type(new_mod).__module__.startswith(nni.__name__) and len(new_mod) > 0 and hasattr(new_mod[-1], 'activation_post_process') ): self.leaf_nodes.append(new_mod[-1]) state = False for pn in n.prev_nodes: q.put((pn, state)) q = queue.Queue() visited = set() for node in graph.input_nodes: q.put((node, None, False, 0)) while not q.empty(): n, prev_q_mod, state, idx = q.get() key = f'{n.unique_name}:{idx}' if key in visited: continue else: visited.add(key) q_mod = prev_q_mod if n.quantized: if isinstance(n.module, nn.Module): orig_name = graph.module_original_name_dict.get(id(n.module)) new_mod, _ = graph.get_submodule_with_parent_from_name(orig_name, self.inplace) if isinstance(new_mod, nn.Module) and hasattr(new_mod, 'activation_post_process'): q_mod = new_mod elif ( isinstance(new_mod, nn.Sequential) and type(new_mod).__module__.startswith(nni.__name__) and len(new_mod) > 0 and hasattr(new_mod[-1], 'activation_post_process') ): q_mod = new_mod[-1] elif isinstance(new_mod, torch_q.DeQuantStub): q_mod = new_mod elif type(new_mod) is not nn.Identity: state = True else: is_prev_float_functional = ( len(n.prev_nodes) > 1 and n.prev_nodes[0].type() is torch.nn.quantized.FloatFunctional ) if is_prev_float_functional: q_mod = getattr(n.prev_nodes[0].module, n.kind()) else: state = True if state and prev_q_mod is not None and q_mod != prev_q_mod: self.swap_nodes.append((prev_q_mod, q_mod, idx)) state = False for next_n in n.next_nodes: idx = next_n.prev_nodes.index(n) q.put((next_n, q_mod, state, idx)) if self.quantize_op_action.get(nn.LSTM, None) and self.backend == 'qnnpack': self.fake_quantize_lstm_weights() return graph.module def extra_qat_fusion_postprocess(self, graph): # Process additional fusable nodes processed_extra_q_rules = load_processed_extra_q_rules() is_extra_fusable = functools.partial( self.is_fusable, current_rules=processed_extra_q_rules, graph=graph, ) custom_data = ([], set()) graph.filter_forward_nodes(is_extra_fusable, custom_data, reverse=True) quant_list = custom_data[0] log.debug(f'Extra q postprocess for nodes: {quant_list}') rev_dict = dict((v, k) for k, v in graph.module_original_name_dict.items()) for quant_nodes in quant_list: for orig_name in quant_nodes: if orig_name in rev_dict: new_mod, _ = graph.get_submodule_with_parent_from_name(orig_name, self.inplace) acp = getattr(new_mod, 'activation_post_process', None) if acp is not None: torch.quantization.disable_fake_quant(acp) torch.quantization.disable_observer(acp) activ_name = quant_nodes[-1] if activ_name in rev_dict: unique_name = graph.module_unique_name_dict[rev_dict[activ_name]] else: unique_name = activ_name activ_name = None node = graph.nodes_map[unique_name] activ_type = node.kind() post_dq = node.next_nodes[0].module post_q = node.next_nodes[0].next_nodes[0].module assert isinstance(post_dq, torch_q.DeQuantStub) assert isinstance(post_q, torch_q.QuantStub) dq_name = graph.module_original_name_dict[id(post_dq)] q_name = graph.module_original_name_dict[id(post_q)] post_acp = getattr(post_q, 'activation_post_process', None) assert post_acp is not None self.extra_qparams_mappings.append([acp, post_acp, dq_name, q_name, activ_name, activ_type]) def disable_requantization_for_cat_pass(self, graph): def _find_quantized_cat_nodes(node: TraceNode, custom_node): # Find quantized cat nodes return node.type() == 'cat' and node.quantized # For cat nodes, the `activation_post_process` around it needs to be unified quantized_cat_nodes = graph.filter_forward_nodes(_find_quantized_cat_nodes) q = queue.Queue() visited_center = set() for n in quantized_cat_nodes: q.put((n, 'both', 0)) parents = [] names = [] props = [] visited_other = dict() while not q.empty(): n, mode, fq_count = q.get() if ( n.kind() in ('shape', 'size') or n.unique_name in visited_center or visited_other.get(n.unique_name, 2) <= fq_count ): continue if n.type() == 'cat': visited_center.add(n.unique_name) else: visited_other[n.unique_name] = fq_count new_fq_count = fq_count if isinstance(n.module, nn.Module): is_prev_float_functional = False orig_name = graph.module_original_name_dict.get(id(n.module)) new_mod, parent = graph.get_submodule_with_parent_from_name(orig_name, self.inplace) prop = orig_name.split('.')[-1] if QuantizableLSTM is not None and isinstance(new_mod, QuantizableLSTM): if new_fq_count == 0: if new_mod.bidirectional is False: parents.append(new_mod.layers[-1].layer_fw.cell.ogate_cy) names.append(f'{orig_name}.layer_fw.cell.ogate_cy.activation_post_process') props.append('activation_post_process') else: parents.append(new_mod.layers[-1].layer_fw.cell.ogate_cy) names.append(f'{orig_name}.layer_fw.cell.ogate_cy.activation_post_process') props.append('activation_post_process') parents.append(new_mod.layers[-1].layer_bw.cell.ogate_cy) names.append(f'{orig_name}.layer_bw.cell.ogate_cy.activation_post_process') props.append('activation_post_process') new_fq_count += 1 elif QuantizableGRU is not None and isinstance(new_mod, QuantizableGRU): if new_fq_count == 0: if new_mod.bidirectional is False: parents.append(new_mod.layers[-1].layer_fw.cell.add4) names.append(f'{orig_name}.layer_fw.cell.add4.activation_post_process') props.append('activation_post_process') else: parents.append(new_mod.layers[-1].layer_fw.cell.add4) names.append(f'{orig_name}.layer_bw.cell.add4.activation_post_process') props.append('activation_post_process') parents.append(new_mod.layers[-1].layer_bw.cell.add4) names.append(f'{orig_name}.layer_bw.cell.add4.activation_post_process') props.append('activation_post_process') new_fq_count += 1 elif isinstance(new_mod, (torch_q.FakeQuantize, torch_q.ObserverBase)): if new_fq_count == 0: parents.append(parent) names.append(orig_name) props.append(prop) new_fq_count += 1 elif hasattr(new_mod, 'activation_post_process'): if new_fq_count == 0: parents.append(new_mod) names.append(f'{orig_name}.activation_post_process') props.append('activation_post_process') new_fq_count += 1 elif ( isinstance(new_mod, nn.Sequential) and type(new_mod).__module__.startswith(nni.__name__) and len(new_mod) > 0 and hasattr(new_mod[-1], 'activation_post_process') ): if new_fq_count == 0: parents.append(new_mod[-1]) names.append(f'{orig_name}[-1].activation_post_process') props.append('activation_post_process') new_fq_count += 1 if isinstance(new_mod, (torch_q.DeQuantStub, torch_q.QuantStub)): new_fq_count = 2 else: is_prev_float_functional = ( len(n.prev_nodes) > 1 and n.prev_nodes[0].type() is torch.nn.quantized.FloatFunctional ) if n.type() == 'cat': mode = 'both' fq_count = 0 new_fq_count = 0 if is_prev_float_functional: m = n.prev_nodes[0].module orig_name = graph.module_original_name_dict.get(id(m)) if new_fq_count == 0: parents.append(m) names.append(f'{orig_name}.activation_post_process') props.append('activation_post_process') new_fq_count += 1 if mode in ('both', 'down'): fq_up = fq_count fq_down = new_fq_count elif mode == 'up': fq_up = new_fq_count fq_down = fq_count if mode == 'up' and len(n.next_nodes) > 1: mode = 'both' fq_down += 1 if mode in ('both', 'up'): for i, node in enumerate(n.prev_nodes): if is_prev_float_functional and i == 0: continue if fq_up < 2: q.put((node, 'up', fq_up)) if mode in ('both', 'down'): for node in n.next_nodes: if fq_down < 2: q.put((node, 'down', fq_down)) if len(names) > 1: log.debug(f'Unifying the following nodes into one: {", ".join(names)}') unified = getattr(parents[0], props[0]) for parent, prop in zip(parents[1:], props[1:]): setattr(parent, prop, unified) def prepare_quantized_inputs_pass(self, graph): if self.quantized_input_stats is not None: assert len(self.quantized_input_stats) == len(graph.input_nodes), ( f"quantized_input_stats contains {len(self.quantized_input_stats)} elements, but" f" {len(graph.input_nodes)} is expected" ) for qstats, node in zip(self.quantized_input_stats, graph.input_nodes): if qstats is not None: assert ( isinstance(qstats, (list, tuple)) and len(qstats) == 2 and all((isinstance(q, (int, float)) for q in qstats)) ), "quantized_input_stats format: [(mean_1, std_1), (mean_2, std_2), ..., (mean_n, std_n)]" for next_node in node.next_nodes: if isinstance(next_node.module, torch_q.QuantStub): self.set_module_quantization(next_node.module, qstats[0], qstats[1]) def prepare_quantized_ops_pass(self, graph): if self.quantized_op_stats: def _find_quantized_ops_with_type(node, custom_data): if isinstance(node.module, nn.Module) and hasattr(node.module, 'activation_post_process'): qstats = self.quantized_op_stats.get(node.kind(), None) if qstats is not None: custom_data.append((qstats, node)) return True else: if isinstance(node.module, torch_q.QuantStub): qstats = self.quantized_op_stats.get(node.prev_nodes[0].kind(), None) if qstats is not None: custom_data.append((qstats, node)) return True return False node_with_qstats = [] graph.filter_forward_nodes(_find_quantized_ops_with_type, node_with_qstats) for qstats, node in node_with_qstats: if qstats is not None: assert ( isinstance(qstats, (list, tuple)) and len(qstats) == 2 and all((isinstance(q, (int, float)) for q in qstats)) ), ( "quantized_op_stats format: {'op_0': (mean_1, std_1), 'op_1': (mean_2, std_2), ..., 'op_n':" " (mean_n, std_n)}" ) self.set_module_quantization(node.module, qstats[0], qstats[1]) def set_module_quantization(self, module: nn.Module, mean: int, std: float): acp = getattr(module, 'activation_post_process', None) if acp is not None: fq_base_cls = getattr(torch_q, 'FakeQuantizeBase', torch_q.FakeQuantize) if isinstance(acp, fq_base_cls): fake_quant = acp module.apply(torch_q.disable_observer) scale = torch.tensor(1.0 / std, dtype=fake_quant.scale.dtype) offset = torch.tensor(mean, dtype=fake_quant.zero_point.dtype) fake_quant.scale.copy_(scale) fake_quant.zero_point.copy_(offset) quant_min = fake_quant.quant_min quant_max = fake_quant.quant_max observer = getattr(fake_quant, 'activation_post_process', None) elif isinstance(acp, torch_q.ObserverBase): observer = acp # We cannot use `disable_observer` which is designed for `FakeQuant` modules`. # Instead, we need to monkey-patch the forward function of the observer. identity = nn.Identity() observer.forward = identity.forward scale = torch.tensor(1.0 / std, dtype=torch.float32) offset = torch.tensor(mean, dtype=torch.int32) quant_min = observer.quant_min quant_max = observer.quant_max else: log.warning(f'Given module {type(module)} doesn\'t seem to be a quantized module') return if observer is not None: if quant_min is None and quant_max is None: if observer.reduce_range: quant_min, quant_max = 0, 127 else: quant_min, quant_max = 0, 255 observer.min_val = scale * (quant_min - offset) observer.max_val = scale * (quant_max - offset) def rescale_activations_with_quant_min_max(self, quant_min: int, quant_max: int) -> None: """Rescales activations with provided quant_min and quant_max""" for n, m in self.model.named_modules(): if '.weight_fake_quant' in n: continue if isinstance(m, torch.nn.Sigmoid) and quant_min == 0 and quant_max == 65535: m.activation_post_process.scale = m.activation_post_process.scale * 2 m.activation_post_process.zero_point = torch.zeros_like(m.activation_post_process.zero_point) + 32768 log.info(f"Rescale {n}(sigmoid mod) quant param to [1/32768, 0] to align TFLite.") if isinstance(m, torch.quantization.FakeQuantize): observer = getattr(m, 'activation_post_process', None) if observer is not None: old_quant_range = m.quant_max - m.quant_min + 1 m.quant_min = quant_min m.quant_max = quant_max new_quant_range = m.quant_max - m.quant_min + 1 m.scale = m.scale * old_quant_range / new_quant_range zero_point = new_quant_range // 2 if m.zero_point != 0 else 0 m.zero_point = torch.zeros_like(m.zero_point) + zero_point observer.quant_min = quant_min observer.quant_max = quant_max def rewrite_quantize_graph_for_tensorrt(self, graph: TraceGraph) -> None: """Rewrites the computation graph for TensorRT quantization""" processed_types = set() for types in TENSORRT_OBSERVED_NODES: if isinstance(types, tuple): if hasattr(types[0], types[1]): processed_types.add(getattr(types[0], types[1])) else: processed_types.add(types) def _is_observed_nodes_for_tensorrt(node, custom_data): return node.kind() in processed_types observed_nodes = graph.filter_forward_nodes(_is_observed_nodes_for_tensorrt) node_map = dict() for idx, node in enumerate(observed_nodes): fake_quant_cls = torch_q.QuantStub assert node.rev_index is False prev_nodes = {n.unique_name: n for n in node.prev_nodes}.values() for inner_idx, prev_node in enumerate(prev_nodes): if prev_node.kind() in ('shape', 'device', 'size', 'dtype'): continue prev_tensor_ptrs = [] for pt in node.prev_tensors: for nt in prev_node.next_tensors: if isinstance(nt, (list, tuple)): for k, ntt in enumerate(nt): if id(pt) == id(ntt): if id(pt) not in prev_tensor_ptrs: prev_tensor_ptrs.append(id(pt)) break elif id(pt) == id(nt): if id(pt) not in prev_tensor_ptrs: prev_tensor_ptrs.append(id(pt)) break for ptr_idx, ptr in enumerate(prev_tensor_ptrs): if ptr in node_map: fake_quant = node_map[ptr] graph.insert_between(prev_node, node, fake_quant, move_idx=True, tensor_ptrs=set([ptr])) else: fake_quant = fake_quant_cls() fake_quant_name = f'fake_quant_inner_{idx}_{inner_idx}_{ptr_idx}' graph.module_unique_name_dict[id(fake_quant)] = fake_quant_name graph.module_original_name_dict[id(fake_quant)] = fake_quant_name module_constructor_lines[id(fake_quant)] = f'{qualified_name(fake_quant_cls)}()' graph.insert_between(prev_node, node, fake_quant, move_idx=True, tensor_ptrs=set([ptr])) node_map[ptr] = graph.nodes_map[fake_quant_name] graph.quantized = True graph.recompute_forward_order() def rewrite_quantize_graph(self, graph: TraceGraph) -> None: """Rewrites the computation graph for quantization""" if graph.quantized: return graph_quantized = False for n in graph.forward_nodes: if n.type() in (torch_q.QuantStub, torch_q.DeQuantStub): graph_quantized = True break for n in graph.other_init_nodes: if n.type() is nnq.FloatFunctional: graph_quantized = True break if graph_quantized: log.warning( 'Graph is quantized. No need to rewrite. Please pass in `config={"rewrite_graph": False}` to suppress' ' this warning' ) return if self.backend == 'tensorrt': return self.rewrite_quantize_graph_for_tensorrt(graph) if self.layerwise_default is False: for n in graph.forward_nodes: self.layerwise_config.setdefault(n.unique_name, False) creation_func_names = load_creation_func_names() def _is_extra_constant_nodes(node, custom_data): return node.full_name() in creation_func_names extra_constant_nodes = graph.filter_forward_nodes(_is_extra_constant_nodes) def _is_int_to_float_nodes(node, custom_data): if node.full_name() in creation_func_names: return False if len(node.prev_nodes) == 1 and len(node.next_nodes) == 1: if len(node.prev_tensors) == 1 and len(node.next_tensors) == 1: if node.prev_nodes[0].kind() == 'shape' and node.prev_nodes[0].module.is_property: return False if ( isinstance(node.prev_tensors[0], torch.Tensor) and isinstance(node.next_tensors[0], torch.Tensor) and node.prev_tensors[0].dtype in (torch.int32, torch.int64) and torch.is_floating_point(node.next_tensors[0]) ): return True else: return False int_to_float_nodes = graph.filter_forward_nodes(_is_int_to_float_nodes) # When converting float-tensor to int16/int32/int64, we need to add 'fake_dequant' node before convert-node. def _is_float_to_non_float_nodes(node, custom_data): if isinstance(node.module, TraceFunction) and node.module.kind in ('shape', 'size'): return False return ( len(node.next_tensors) == 1 and len(node.prev_tensors) > 0 and isinstance(node.prev_tensors[0], torch.Tensor) and isinstance(node.next_tensors[0], torch.Tensor) and torch.is_floating_point(node.prev_tensors[0]) and not torch.is_floating_point(node.next_tensors[0]) ) float_to_non_float_nodes = graph.filter_forward_nodes(_is_float_to_non_float_nodes) def _is_params_in_module(node, custom_data): if len(node.prev_nodes) == 1 and len(node.next_nodes) == 1: if len(node.prev_tensors) == 1 and len(node.next_tensors) == 1: if isinstance(node.prev_tensors[0], nn.Module) and not isinstance( node.prev_tensors[0], nnq.FloatFunctional ): return True return False param_nodes = graph.filter_forward_nodes(_is_params_in_module) for node in param_nodes: prev_node = node.prev_nodes[0] is_known_mod = prev_node.kind().__name__ in ( 'Conv1d', 'Conv2d', 'Linear', 'ConvTranspose1d', 'ConvTranspose2d', ) prop = node.module.full_name if is_known_mod and prop in ('weight', 'bias') and self.layerwise_config.get(prev_node.unique_name, True): node.module.is_class = False node.module.is_property = False node.module.full_name = 'tinynn.graph.quantization.utils.get_parameter' node.module.args_template = f'{{}}, "{prop}"' node.module.args_template_no_self = f'"{prop}"' node.module.args_offset = 0 node.module.update_args_string() # First, we insert the QuantStub nodes for every input/constant node for idx, node in reversed( list( enumerate( graph.input_nodes + graph.constant_nodes + extra_constant_nodes + int_to_float_nodes + param_nodes ) ) ): if node.next_tensors[0].dtype in (torch.int32, torch.int64): continue fake_quant = torch_q.QuantStub() graph.module_unique_name_dict[id(fake_quant)] = f'fake_quant_{idx}' graph.module_original_name_dict[id(fake_quant)] = f'fake_quant_{idx}' fake_quant_cls = type(fake_quant) module_constructor_lines[id(fake_quant)] = f'{qualified_name(fake_quant_cls)}()' graph.insert_after(node, fake_quant) # Second, we insert the DeQuantStub nodes for every output node for idx, node in enumerate(graph.output_nodes + float_to_non_float_nodes): fake_dequant_cls = torch_q.DeQuantStub if node.rev_index: modules = [] for rev_idx in range(len(node.prev_nodes)): if node.prev_tensors[rev_idx].dtype in (torch.int32, torch.int64): fake_dequant = nn.Identity() else: fake_dequant = fake_dequant_cls() graph.module_unique_name_dict[id(fake_dequant)] = f'fake_dequant_{idx}_{rev_idx}' graph.module_original_name_dict[id(fake_dequant)] = f'fake_dequant_{idx}_{rev_idx}' module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()' modules.append(fake_dequant) graph.insert_before(node, modules, move_idx=True) else: if node.prev_tensors[0].dtype in (torch.int32, torch.int64): continue fake_dequant = fake_dequant_cls() graph.module_unique_name_dict[id(fake_dequant)] = f'fake_dequant_{idx}' graph.module_original_name_dict[id(fake_dequant)] = f'fake_dequant_{idx}' module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()' if len(node.prev_nodes) > 1: # Insert 'fake_dequant' node before type conversion operators graph.insert_between(node.prev_nodes[0], node, fake_dequant, move_idx=True) else: graph.insert_before(node, fake_dequant, move_idx=True) # Third, we rewrite neg/sub/div using supported functions(e.g add, mul) def _is_neg_node(node: TraceNode, custom_data): cur_module = node.module cur_class = type(cur_module) if cur_class == TraceFunction: return ( cur_module.kind == 'neg' and torch.is_floating_point(cur_module.prev_tensors[0]) and self.layerwise_config.get(node.unique_name, True) ) neg_nodes = graph.filter_forward_nodes(_is_neg_node) log.info(f'rewriting neg for {[node.unique_name for node in neg_nodes]}') for idx, node in enumerate(neg_nodes): node.module.func_type = '__mul__' node.module.kind = 'mul' full_name_parts = node.module.full_name.split('.') full_name_parts[-1] = node.module.func_type node.module.full_name = '.'.join(full_name_parts) with override_current_trace_graph(graph): node.module.parse_args(node.prev_tensors[0], -1.0) def _is_div_node(node: TraceNode, custom_data): cur_module = node.module cur_class = type(cur_module) # Current, only the following condition could be handled. # a / constant => a * (1 / constant) if cur_class == TraceFunction: return ( (cur_module.kind == 'truediv' or cur_module.func_type in ('div', 'div_')) and len(cur_module.prev_tensors) == 1 and torch.is_floating_point(cur_module.prev_tensors[0]) and cur_module.func_type != '__rtruediv__' and torch.is_floating_point(node.next_tensors[0]) and node.prev_nodes[0].kind() not in ('size', 'shape') and self.layerwise_config.get(node.unique_name, True) ) div_nodes = graph.filter_forward_nodes(_is_div_node) log.info(f'rewriting div for {[node.unique_name for node in div_nodes]}') for idx, node in enumerate(div_nodes): op_type = node.module.func_type inplace = op_type in ('__itruediv__', 'itruediv', 'div_') if inplace: node.module.func_type = '__imul__' else: node.module.func_type = '__mul__' node.module.kind = 'mul' full_name_parts = node.module.full_name.split('.') full_name_parts[-1] = node.module.func_type node.module.full_name = '.'.join(full_name_parts) # Here we make a simple guess, if dot is in the string, then it's a floating number. # Otherwise, it is an integral number. if '.' in node.module.args_string_no_self: other_arg = float(node.module.args_string_no_self) else: other_arg = int(node.module.args_string_no_self) with override_current_trace_graph(graph): node.module.parse_args(node.prev_tensors[0], 1.0 / other_arg) def _is_sub_node(node: TraceNode, custom_data): cur_module = node.module cur_class = type(cur_module) if cur_class == TraceFunction: return ( cur_module.kind == 'sub' and torch.is_floating_point(cur_module.prev_tensors[0]) and torch.is_floating_point(node.next_tensors[0]) and self.layerwise_config.get(node.unique_name, True) ) sub_nodes = graph.filter_forward_nodes(_is_sub_node) log.info(f'rewriting sub for {[node.unique_name for node in sub_nodes]}') for idx, node in enumerate(sub_nodes): op_type = node.module.func_type full_name_parts = node.module.full_name.split('.') full_name_parts[-1] = node.module.func_type if len(node.module.prev_tensors) == 1 and node.module.func_type != '__rsub__': node.module.func_type = '__add__' node.module.kind = 'add' node.module.full_name = '.'.join(full_name_parts) # Here we make a simple guess, if dot is in the string, then it's a floating number. # Otherwise, it is an integral number. if '.' in node.module.args_string_no_self or 'e' in node.module.args_string_no_self: other_arg = float(node.module.args_string_no_self) else: other_arg = int(node.module.args_string_no_self) with override_current_trace_graph(graph): node.module.parse_args(node.prev_tensors[0], -other_arg) elif len(node.module.prev_tensors) == 2 and len(node.prev_nodes) == 2: new_fullname_parts = copy.deepcopy(full_name_parts) new_fullname_parts[-1] = '__mul__' new_fullname = '.'.join(new_fullname_parts) current_tensor = node.prev_tensors[0] input_node = node.prev_nodes[1] input_tensor = node.prev_tensors[1] output_tensor = input_tensor * -1 with override_current_trace_graph(graph): trace_func = TraceFunction(new_fullname, True, prefix='fuse_').parse_args(input_tensor, -1) graph.insert_between(input_node, node, trace_func, [output_tensor], True) node.module.func_type = '__add__' node.module.kind = 'add' full_name_parts[-1] = node.module.func_type node.module.full_name = '.'.join(full_name_parts) node.module.parse_args(current_tensor, output_tensor) elif node.module.func_type == '__rsub__' and len(node.module.prev_tensors) == 1: new_fullname_parts = copy.deepcopy(full_name_parts) new_fullname_parts[-1] = '__mul__' new_fullname = '.'.join(new_fullname_parts) input_node = node.prev_nodes[0] input_tensor = node.prev_tensors[0] output_tensor = input_tensor * -1 if '.' in node.module.args_string_no_self: other_arg = float(node.module.args_string_no_self) else: other_arg = int(node.module.args_string_no_self) with override_current_trace_graph(graph): trace_func = TraceFunction(new_fullname, True, prefix='fuse_').parse_args(input_tensor, -1) graph.insert_between(input_node, node, trace_func, [output_tensor], True) node.module.func_type = '__radd__' node.module.kind = 'add' full_name_parts[-1] = node.module.func_type node.module.full_name = '.'.join(full_name_parts) # Then, we write torch.stack nodes to torch.unsqueeze + torch.cat def _is_stack_node(node: TraceNode, custom_data): cur_module = node.module cur_class = type(cur_module) return ( cur_class == TraceFunction and cur_module.kind == 'stack' and torch.is_floating_point(node.next_tensors[0]) and self.layerwise_config.get(node.unique_name, True) ) stack_nodes = graph.filter_forward_nodes(_is_stack_node) for idx, node in enumerate(stack_nodes): args = getattr(node.module, 'args_string_no_self', '') if ',' in args: log.error('rewrite doesn\'t support multiple args for torch.stack') assert False if len(args) > 0: if '=' in args: k, v = args.split('=') assert k in ('axis', 'dim') dim = int(v) if k == 'axis': node.module.args_template_no_self = node.module.args_template_no_self.replace('axis', 'dim') node.module.args_template = node.module.args_template.replace('axis', 'dim') node.module.update_args_string() else: dim = int(args) else: dim = 0 unique_prev_nodes = {n.unique_name: n for n in node.prev_nodes}.values() for n in unique_prev_nodes: shared_tensors = list(set(node.prev_tensors).intersection(set(n.next_tensors))) if len(shared_tensors) == 0: log.debug('tensor rewrite already done, skipping') continue for t in shared_tensors: with override_current_trace_graph(graph): trace_func = TraceFunction('torch.unsqueeze', prefix='fuse_').parse_args(t, dim) next_tensors = [torch.unsqueeze(t, dim)] graph.insert_between(n, node, trace_func, next_tensors, move_idx=True, tensor_ptrs=set([id(t)])) node.module.func_type = 'cat' node.module.kind = 'cat' full_name_parts = node.module.full_name.split('.') full_name_parts[-1] = node.module.func_type node.module.full_name = '.'.join(full_name_parts) # Next, we rewrite add/mul/cat with one float32 output using torch.nn.quantized.FloatFunctional def _is_convertible_node(node: TraceNode, custom_data): cur_module = node.module cur_class = type(cur_module) if cur_class == TraceFunction: return ( cur_module.kind in ('add', 'mul', 'cat') and torch.is_floating_point(node.next_tensors[0]) and self.layerwise_config.get(node.unique_name, True) and all((torch.is_floating_point(pt) for pt in node.prev_tensors)) ) convertible_nodes = graph.filter_forward_nodes(_is_convertible_node) unfusable_add_nodes = [] # noqa: F841 log.info(f'rewriting add/mul/cat for {[node.unique_name for node in convertible_nodes]}') for idx, node in enumerate(convertible_nodes): old_full_name = node.module.full_name # noqa: F841 old_is_class = node.module.is_class # noqa: F841 op_kind = node.module.kind float_functional = nnq.FloatFunctional() module_name = f'float_functional_simple_{idx}' graph.module_unique_name_dict[id(float_functional)] = module_name graph.module_original_name_dict[id(float_functional)] = module_name float_functional_cls = type(float_functional) module_constructor_lines[id(float_functional)] = f'{qualified_name(float_functional_cls, short=True)}()' new_node = TraceNode(float_functional, cur_graph=graph) graph.nodes_map[new_node.unique_name] = new_node graph.other_init_nodes.append(new_node) node.module.is_class = False prev_tensor_size = len(node.prev_tensors) if op_kind in ('add', 'mul'): if prev_tensor_size == 2: if node.prev_nodes[1].prev_nodes[0].kind() not in ('shape', 'size'): op_type = op_kind else: op_type = f'{op_kind}_scalar' elif prev_tensor_size == 1: op_type = f'{op_kind}_scalar' else: log.error(f'Unknown add/mul type for {node.unique_name}, prev tensor size: {prev_tensor_size}') assert False else: # Don't check anything for other OPs. # It is simply too complex for us. op_type = op_kind node.module.func_type = op_type node.module.full_name = f'self.{module_name}.{op_type}' # Inplace operations if node.module.func_type in ['__iadd__', '__imul__', 'add_', 'mul_']: node.module.add_alias(node.module.tensor_names[0]) q = queue.Queue() q.put(node.prev_nodes[0]) while not q.empty(): n = q.get() if type(n.module) is TraceFunction: prev_aliases = n.module.get_aliases() if prev_aliases is not None: for pa in reversed(prev_aliases): node.module.add_alias(pa, head=True) else: if getattr(n.module, 'inplace', False): q.put(n.prev_nodes[0]) node.module.add_alias(n.prev_node_unique_name(0), head=True) # We need to convert radd to normal add here if node.module.func_type in ['__radd__', '__rmul__']: if '=' in node.module.args_string_no_self or ', ' in node.module.args_string_no_self: log.error(f'Don\'t know how to translate {node.module.args_string_no_self} for __radd__/__rmul__') assert False if prev_tensor_size == 1: # Here we make a simple guess, if dot is in the string, then it's a floating number. # Otherwise, it is an integral number. if '.' in node.module.args_string_no_self or 'e' in node.module.args_string_no_self: other_arg = float(node.module.args_string_no_self) else: other_arg = int(node.module.args_string_no_self) with override_current_trace_graph(graph): node.module.parse_args(node.prev_tensors[0], other_arg) else: with override_current_trace_graph(graph): # It is even simple here. We only need to swap the order of the tensors. node.module.parse_args(node.prev_tensors[1], node.prev_tensors[0]) # Rewrite torch.clamp_{min,max} to torch.clamp def _is_clamp_min_max_node(node: TraceNode, custom_data): cur_module = node.module cur_class = type(cur_module) if cur_class == TraceFunction: return ( cur_module.kind in ('clamp_min', 'clamp_max') and torch.is_floating_point(node.next_tensors[0]) and self.layerwise_config.get(node.unique_name, True) ) clamp_min_max_nodes = graph.filter_forward_nodes(_is_clamp_min_max_node) log.info(f'rewriting clamp_{{min,max}} for {[node.unique_name for node in convertible_nodes]}') for node in clamp_min_max_nodes: kw = node.module.kind[-3:] _parse_args = eval(f'lambda {kw}, *args, **kwargs: {kw}') val = eval(f'_parse_args({node.module.args_string_no_self})') if kw != 'min' or val != 0.0: continue if kw == 'min': arg_str = f'{val}' else: arg_str = f'None, {val}' sub_str = f'_{kw}' node.module.kind = node.module.kind.replace(sub_str, '') node.module.func_type = node.module.func_type.replace(sub_str, '') node.module.full_name = '.'.join(node.module.full_name.split('.')[:-1] + [node.module.func_type]) node.module.args_template_no_self = arg_str node.module.args_template = f'{{}}, {arg_str}' node.module.update_args_string() # Rewrite torch.clamp to relu, relu6 and clamp_{min, max} def _is_clamp_node(node: TraceNode, custom_data): cur_module = node.module cur_class = type(cur_module) if cur_class == TraceFunction: return ( cur_module.kind == 'clamp' and torch.is_floating_point(node.next_tensors[0]) and self.layerwise_config.get(node.unique_name, True) ) clamp_nodes = graph.filter_forward_nodes(_is_clamp_node) log.info(f'rewriting clamp for {[node.unique_name for node in convertible_nodes]}') for node in clamp_nodes: def _parse_args(min=None, max=None, *args, **kwargs): return min, max min, max = eval(f'_parse_args({node.module.args_string_no_self})') kind = None if min == 0.0: if max is None: kind = 'relu' elif max == 6.0: kind = 'relu6' if kind is None: if max is None: kind = 'clamp_min' elif min is None: kind = 'clamp_max' else: kind = 'clamp_with_fusion' if kind in ('clamp_min', 'clamp_max'): node.module.kind = kind node.module.func_type = node.module.func_type.replace('clamp', kind) node.module.full_name = '.'.join(node.module.full_name.split('.')[:-1] + [node.module.func_type]) if max is None: arg_str = f'{min}' else: arg_str = f'{max}' elif kind == 'clamp_with_fusion': node.module.is_class = False node.module.kind = kind node.module.func_type = node.module.func_type.replace('clamp', kind) node.module.full_name = f'tinynn.graph.quantization.utils.{node.module.func_type}' arg_str = f'{min}, {max}' else: inplace = node.module.func_type == f'{node.module.kind}_' node.module.kind = kind node.module.func_type = kind node.module.full_name = f'torch.nn.functional.{kind}' if inplace: arg_str = 'inplace=True' else: arg_str = '' node.module.args_template_no_self = arg_str node.module.args_template = f'{{}}, {arg_str}' node.module.update_args_string() # Rewrite other fusable functions # e.g. add_relu(x, y) => # r = torch.add(x, y) # r = torch.nn.functional.relu(r) def _is_add_relu_fusable_node(node: TraceNode, custom_data) -> bool: cur_module = node.module cur_class = type(cur_module) visited_nodes = [node] if self.layerwise_config.get(node.unique_name, True) is False: return False if cur_class == TraceFunction: # The intermediate result cannot be used if fused. # So we don't fuse the nodes under such circumstances. if cur_module.kind != 'add' or len(node.next_nodes) != 1: return False # The input for the add operations should be two tensors. if len(node.prev_tensors) != 2: return False # We accept inplace operations for both add and relu. # The inplace property could be elinimated because we track the tensors # instead of their names. next_node = node.next_nodes[0] next_module = next_node.module next_class = type(next_module) if next_class == TraceFunction: fuse = next_module.kind == 'relu' else: while next_class == nn.Identity: cur_node = next_node visited_nodes.append(cur_node) if len(cur_node.next_nodes) != 1: return False next_node = cur_node.next_nodes[0] next_module = next_node.module next_class = type(next_module) fuse = next_class.__name__ == 'ReLU' if not fuse: return False if type(next_node.module) is TraceFunction: inplace = next_node.module.func_type == 'relu_' or 'True' in next_node.module.args_string else: inplace = getattr(next_node.module, 'inplace', False) # Inplace check # If add is inplace and relu is not inplace, we need to ensure that all the aliases of # the first operand of add are not used when relu is called. if not inplace: aliases = cur_module.get_aliases() if aliases: q = queue.Queue() q.put(node.prev_nodes[0]) while not q.empty(): n = q.get() last_order = max((x.forward_order for x in n.next_nodes)) if last_order > node.forward_order: fuse = False break if type(n.module) is TraceFunction and n.module.get_aliases(): q.put(n.prev_nodes[0]) elif getattr(n.module, 'inplace', False): q.put(n.prev_nodes[0]) return fuse add_relu_fusable_nodes = graph.filter_forward_nodes(_is_add_relu_fusable_node) for node in add_relu_fusable_nodes: full_name = node.module.full_name.replace('add', 'add_relu') next_node = node.next_nodes[0] kind = 'add_relu' func_type = kind is_class = False nodes_to_fuse = [node, next_node] while next_node.type() is nn.Identity: next_node = next_node.next_nodes[0] nodes_to_fuse.append(next_node) if type(next_node.module) is TraceFunction: inplace = next_node.module.func_type == 'relu_' or 'True' in next_node.module.args_string else: inplace = next_node.module.inplace graph.fuse_nodes_to_func(nodes_to_fuse, full_name, kind, func_type, is_class) # Propagate aliases for inplace nodes if inplace: aliases = node.module.get_aliases() if aliases: node.module.add_alias(node.module.tensor_names[0]) else: node.module.aliases = None # Rewrite relu, relu6 as nn.ReLU() and nn.ReLU6() for Module fusable rules def _is_functional_rewrite_node(node: TraceNode, custom_data): cur_module = node.module cur_class = type(cur_module) if cur_class == TraceFunction: return cur_module.kind in FUNCTIONAL_MODULE_MAPPING and self.layerwise_config.get( node.unique_name, True ) func_nodes_to_rewrite = graph.filter_forward_nodes(_is_functional_rewrite_node) log.info(f'rewriting functional to module for {[node.unique_name for node in func_nodes_to_rewrite]}') for idx, node in enumerate(func_nodes_to_rewrite): kind = node.module.kind inplace = node.module.func_type == f'{kind}_' or 'True' in node.module.args_string klass = FUNCTIONAL_MODULE_MAPPING[kind] arguments = getattr(klass, '__constants__', None) if arguments is None: new_func = klass() elif node.module.kind in ('relu', 'relu6', 'silu', 'hardswish'): new_func = klass(inplace=inplace) elif node.module.kind in ('elu', 'leaky_relu'): if hasattr(node.module, 'args_string_no_self'): def _parse_args(alpha=1.0, *args, **kwargs): # noqa: F811 return alpha alpha = eval(f'_parse_args({node.module.args_string_no_self})') if node.module.kind == 'leaky_relu': new_func = klass(alpha, inplace=inplace) else: new_func = klass(alpha, inplace=inplace) else: alpha = None if node.module.kind == 'leaky_relu': new_func = klass(inplace=inplace) else: new_func = klass() elif node.module.kind == 'prelu': weight_t = node.prev_tensors[1] weight_node = node.prev_nodes[1] if weight_node.type() is torch_q.QuantStub: weight_node = weight_node.prev_nodes[0] if weight_node.type() is not ConstantNode or not weight_node.module.is_parameter: log.warning('Rewrite for F.prelu(x, buffer) to nn.PReLU is skipped as it changes the semantics') continue num_parameters = weight_t.nelement() new_func = klass(num_parameters) new_func.weight.data.copy_(node.prev_tensors[1]) # Drop last input node.prev_tensors.pop() last_node = node.prev_nodes.pop() last_node.next_nodes.remove(node) # Remove unused forward functions iteratively new_last_node = last_node while new_last_node is not None: last_node = new_last_node if ( len(last_node.next_nodes) == 0 and len(last_node.prev_nodes) < 2 and last_node in graph.forward_nodes ): if len(last_node.prev_nodes) == 1: new_last_node = last_node.prev_nodes[0] else: new_last_node = None graph.remove_node(last_node) else: new_last_node = None elif node.module.kind == 'glu': if hasattr(node.module, 'args_string_no_self'): def _parse_args_dim(dim=-1, *args, **kwargs): # noqa: F811 return dim dim = eval(f'_parse_args_dim({node.module.args_string_no_self})') new_func = klass(dim) else: dim = None new_func = klass() else: raise NotImplementedError(f"Don't know how to parse {klass.__name__} with argument {arguments}") graph.module_unique_name_dict[id(new_func)] = f'rewritten_{kind}_{idx}' graph.module_original_name_dict[id(new_func)] = f'rewritten_{kind}_{idx}' relu_cls = type(new_func) if inplace: arg_str = 'inplace=True' else: arg_str = '' if node.module.kind in ('elu', 'leaky_relu') and alpha is not None: if arg_str: arg_str = f'{alpha}, {arg_str}' else: arg_str = f'{alpha}' if node.module.kind == 'prelu': if num_parameters != 1: arg_str = f'{num_parameters}' else: arg_str = '' if node.module.kind == 'glu': if dim is not None and dim != -1: arg_str = f'{dim}' else: arg_str = '' module_constructor_lines[id(new_func)] = f'{qualified_name(relu_cls)}({arg_str})' graph.replace_node_module(node, new_func) # Rewrite dropout as nn.Dropout() for models in training mode def _is_dropout_functional_node(node: TraceNode, custom_data): cur_module = node.module cur_class = type(cur_module) if cur_class == TraceFunction: return cur_module.kind == 'dropout' and self.layerwise_config.get(node.unique_name, True) def _dropout_args(p=0.5, training=True, inplace=False): return p, training, inplace dropout_nodes_to_rewrite = graph.filter_forward_nodes(_is_dropout_functional_node) log.info(f'rewriting dropout for {[node.unique_name for node in dropout_nodes_to_rewrite]}') for idx, node in enumerate(dropout_nodes_to_rewrite): args = getattr(node.module, 'args_string_no_self', '') p, _, inplace = eval(f'_dropout_args({args})') kind = node.module.kind inplace = node.module.func_type == f'{kind}_' or inplace dropout_cls = nn.Dropout new_dropout = dropout_cls(p, inplace=inplace) graph.module_unique_name_dict[id(new_dropout)] = f'rewritten_{kind}_{idx}' graph.module_original_name_dict[id(new_dropout)] = f'rewritten_{kind}_{idx}' if inplace: arg_str = f'{p}, inplace={inplace}' else: arg_str = f'{p}' module_constructor_lines[id(new_dropout)] = f'{qualified_name(dropout_cls)}({arg_str})' graph.replace_node_module(node, new_dropout) # Add contiguous nodes for partially-supported OPs # Some of the operations support quantization, but they only accept contiguous input tensors. def _is_partially_quantizable(node, custom_data): if self.layerwise_config.get(node.unique_name, True) is False: return False cur_module = node.module cur_class = type(cur_module) if cur_class == ConstantNode: return False elif cur_class == TraceFunction: return cur_module.kind in ('pad',) else: return cur_class in (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d, nn.ZeroPad2d) if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'): partially_supported_nodes = graph.filter_forward_nodes(_is_partially_quantizable) for idx, node in enumerate(partially_supported_nodes): for n in node.prev_nodes[:1]: shared_tensors = list(set(node.prev_tensors).intersection(set(n.next_tensors))) if len(shared_tensors) > 1: log.error('rewrite for partially-supported ops supports with nodes with exact one input') assert False with override_current_trace_graph(graph): trace_func = TraceFunction('torch.Tensor.contiguous', True, prefix='fuse_').parse_args( shared_tensors[0] ) next_tensors = [x.contiguous() for x in shared_tensors] graph.insert_between(n, node, trace_func, next_tensors, True) # Remove non-leaf `.data` nodes def _is_non_leaf_data_nodes(node, custom_data): if self.layerwise_config.get(node.unique_name, True) is False: return False cur_module = node.module cur_class = type(cur_module) if cur_class == TraceFunction: return cur_module.kind == 'data' and cur_module.is_property and len(node.next_nodes) > 0 return False non_leaf_data_nodes = graph.filter_forward_nodes(_is_non_leaf_data_nodes) for idx, node in enumerate(non_leaf_data_nodes): graph.remove_node(node) # Handle PoolNd with kernel_size=1 def _is_pool_nd_with_one_kernel_size(node, custom_data): if self.layerwise_config.get(node.unique_name, True) is False: return False cur_module = node.module cur_class = type(cur_module) kernel_size, stride = None, None if cur_class == TraceFunction: if cur_module.kind in ('avg_pool1d', 'avg_pool2d', 'max_pool1d', 'max_pool2d'): def _avgpool_kernel_size_and_stride(kernel_size, stride=None, *args, **kwargs): return kernel_size, stride kernel_size, stride = eval(f'_avgpool_kernel_size_and_stride({cur_module.args_string_no_self})') else: if cur_class in (nn.AvgPool1d, nn.AvgPool2d, nn.MaxPool1d, nn.MaxPool2d): kernel_size = cur_module.kernel_size stride = cur_module.stride if kernel_size is not None: if isinstance(kernel_size, (tuple, list)): is_match = all((ks == 1 for ks in kernel_size)) else: is_match = kernel_size == 1 else: is_match = False if is_match: custom_data.append((node, kernel_size, stride)) return True else: return False pool_one_kernel_size_nodes = [] graph.filter_forward_nodes(_is_pool_nd_with_one_kernel_size, pool_one_kernel_size_nodes) for idx, (node, kernel_size, stride) in enumerate(pool_one_kernel_size_nodes): slices = [slice(None)] * 2 t = node.prev_tensors[0] dim = len(t.shape) if not isinstance(stride, (list, tuple)): stride = [stride] * (dim - 2) for s in stride: if s == 1 or s is None: slices.append(slice(None)) else: slices.append(slice(None, None, s)) with override_current_trace_graph(graph): new_func = TraceFunction('torch.Tensor.__getitem__', True).parse_args(t, slices) graph.module_unique_name_dict[id(new_func)] = f'rewritten_pool_{idx}' graph.module_original_name_dict[id(new_func)] = f'rewritten_pool_{idx}' graph.replace_node_module(node, new_func) # Rewrite Linear-BatchNorm1d structure to Conv2d-BatchNorm2d is_rewrite_to_fuse = functools.partial( self.is_fusable, current_rules=load_processed_rewrite_to_fuse_rules(), check_node_quantized=False, graph=graph, layerwise_config_default=True, use_original_name=False, ) custom_data = ([], set()) graph.filter_forward_nodes(is_rewrite_to_fuse, custom_data, reverse=True) rewrite_fuse_names_list = custom_data[0] log.debug(f'found names_list that need to rewrite for fusing: {rewrite_fuse_names_list}') for idx, names in enumerate(reversed(rewrite_fuse_names_list)): # case fc-bn1d assert len(names) == 2, 'the rewrite nodes list length != 2' node_fc = graph.nodes_map[names[0]] node_bn1d = graph.nodes_map[names[1]] mod_fc = node_fc.module mod_bn = node_bn1d.module assert type(mod_fc) is nn.Linear and type(mod_bn) is nn.BatchNorm1d, "the rewrite struct is\'t [fc-bn1d]" if len(node_fc.prev_tensors[0].shape) != 2: log.debug('the [fc-bn]\'s input dimension != 2') continue # for fc-bn1d, rewrite [fc-bn1d] to [conv2d-bn2d] new_conv2d = torch.nn.Conv2d( in_channels=mod_fc.in_features, out_channels=mod_fc.out_features, kernel_size=[1, 1], bias=mod_fc.bias is not None, ) fc_weight = mod_fc.weight new_conv2d.weight = nn.Parameter(torch.reshape(fc_weight, [fc_weight.shape[0], fc_weight.shape[1], 1, 1])) if mod_fc.bias is not None: new_conv2d.bias = mod_fc.bias graph.module_unique_name_dict[id(new_conv2d)] = f'rewritten_conv2d_bn2d_conv2d_{idx}' graph.module_original_name_dict[id(new_conv2d)] = f'rewritten_conv2d_bn2d_conv2d_{idx}' new_bn2d = torch.nn.BatchNorm2d( mod_bn.num_features, mod_bn.eps, mod_bn.momentum, affine=mod_bn.affine, track_running_stats=mod_bn.track_running_stats, ) new_bn2d.load_state_dict(mod_bn.state_dict()) graph.module_unique_name_dict[id(new_bn2d)] = f'rewritten_conv2d_bn2d_bn2d_{idx}' graph.module_original_name_dict[id(new_bn2d)] = f'rewritten_conv2d_bn2d_bn2d_{idx}' # replace new node, then insert reshape before new_conv2d and after new_bn2d with override_current_trace_graph(graph): graph.replace_node_module(node_fc, new_conv2d) graph.replace_node_module(node_bn1d, new_bn2d) prev_func = TraceFunction('torch.Tensor.__getitem__', prefix='rewritten_conv2d_bn2d_').parse_args( node_fc.prev_tensors[0], (Ellipsis, None, None) ) next_func = TraceFunction('torch.flatten', prefix='rewritten_conv2d_bn2d_').parse_args( node_bn1d.next_tensors[0], 1 ) # expand the tensor shape between fc new_conv2d and new_bn2d node_fc.next_tensors[0].unsqueeze_(2).unsqueeze_(2) node_bn1d.prev_tensors[0].unsqueeze_(2).unsqueeze_(2) node_bn1d.next_tensors[0].unsqueeze_(2).unsqueeze_(2) prev_out = node_fc.prev_tensors[0][..., None, None] graph.insert_between(node_fc.prev_nodes[0], node_fc, prev_func, [prev_out], True) next_out = torch.flatten(node_bn1d.next_tensors[0], 1) graph.insert_after(node_bn1d, next_func, [next_out]) # Rewrite BatchNorm1d to BatchNorm2d def _is_batch_norm_1d(node, custom_data): if self.layerwise_config.get(node.unique_name, True) is False: return False cur_module = node.module cur_class = type(cur_module) if len(node.prev_nodes) != 1: return False if node.prev_nodes[0].kind() in ('conv1d', nn.Conv1d): return False if cur_class == TraceFunction: return cur_module.kind == 'batch_norm' and node.prev_tensors[0].ndim == 3 else: return cur_class == nn.BatchNorm1d batch_norm_1d_nodes = graph.filter_forward_nodes(_is_batch_norm_1d) for idx, node in enumerate(batch_norm_1d_nodes): mod = node.module if type(mod) is nn.BatchNorm1d: new_bn = torch.nn.BatchNorm2d( mod.num_features, mod.eps, mod.momentum, affine=mod.affine, track_running_stats=mod.track_running_stats, ) new_bn.load_state_dict(mod.state_dict()) graph.module_unique_name_dict[id(new_bn)] = f'rewritten_bn2d_{idx}' graph.module_original_name_dict[id(new_bn)] = f'rewritten_bn2d_{idx}' with override_current_trace_graph(graph): graph.replace_node_module(node, new_bn) prev_func = TraceFunction('torch.unsqueeze', prefix='rewritten_bn2d_').parse_args( node.prev_tensors[0], 2 ) next_func = TraceFunction('torch.squeeze', prefix='rewritten_bn2d_').parse_args( node.next_tensors[0], 2 ) node.next_tensors[0].unsqueeze_(2) prev_out = torch.unsqueeze(node.prev_tensors[0], 2) graph.insert_between(node.prev_nodes[0], node, prev_func, [prev_out], True) next_out = torch.squeeze(node.next_tensors[0], 2) graph.insert_after(node, next_func, [next_out]) type_dict = {} for n in graph.forward_nodes: if n.type() not in (torch_q.QuantStub, torch_q.DeQuantStub): self.layerwise_config.setdefault(n.unique_name, True) kind = n.kind() type_str = kind if isinstance(kind, str) else kind.__name__ type_dict[n.unique_name] = type_str fuse_mapping = {} if self.fused_layerwise_config: # Find all fusable nodes if type(self).__name__ == 'QATQuantizer': processed_all_rules = load_processed_all_qat_rules() else: processed_all_rules = load_processed_all_ptq_rules() is_fusable = functools.partial( self.is_fusable, current_rules=processed_all_rules, check_node_quantized=False, use_original_name=False, ) custom_data = ([], set()) graph.filter_forward_nodes(is_fusable, custom_data, reverse=True) activ_names = custom_data[0] for idx, names in enumerate(reversed(activ_names)): types = [graph.nodes_map[n].kind() for n in names] types = [x.__name__ if not isinstance(x, str) else x for x in types] types_str = ', '.join(types) types_str = f'({types_str})' type_dict[names[0]] = types_str for name in names[1:]: fuse_mapping[name] = names[0] self.layerwise_config[name] = self.layerwise_config.get(names[0], True) for n, t in type_dict.items(): self.layerwise_config.yaml_add_eol_comment(f'type: {t}', n) skip_types = set(k[0] for k in REWRITE_QUANTIZABLE_RULE_LIST if len(k) == 1) for module_cls, action in self.quantize_op_action.items(): if action == 'rewrite': skip_types.add(module_cls) if self.set_quantizable_op_stats: skip_types |= set(KNOWN_QSTATS.keys()) skip_types_prev = skip_types | set(k[-1] for k in REWRITE_QUANTIZABLE_RULE_LIST if len(k) > 1) skip_types_next = skip_types | set(k[0] for k in REWRITE_QUANTIZABLE_RULE_LIST if len(k) > 1) # Add quant/dequant nodes for non-quantizable OPs disable_quantize_op_list = UNSUPPORTED_PYTORCH_QUANTIZATION_OP_LIST.copy() for module_cls, action in self.quantize_op_action.items(): if action in ('disable', 'rewrite'): disable_quantize_op_list[module_cls] = None def _is_rewritable_lstm_node(node, custom_data): cur_module = node.module cur_class = type(cur_module) return cur_class == nn.LSTM if self.quantize_op_action.get(nn.LSTM, 'enable') == 'rewrite': rewritable_lstm_nodes = graph.filter_forward_nodes(_is_rewritable_lstm_node) fake_dequant_cls = torch_q.DeQuantStub for idx, node in enumerate(rewritable_lstm_nodes): assert node.module.num_layers == 1, "Quantization rewrite for multi-layer LSTM is not yet supported" assert not node.module.bidirectional, "Quantization rewrite for bidirectional LSTM is not yet supported" cell_state = node.next_tensors[1][1] fake_dequant = fake_dequant_cls() fake_dequant_name = f'fake_dequant_rewrite_{idx}' graph.module_unique_name_dict[id(fake_dequant)] = fake_dequant_name graph.module_original_name_dict[id(fake_dequant)] = fake_dequant_name module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()' new_node = graph.insert_new_after( node, fake_dequant, [cell_state], [[1, 1]], before_node=node.next_nodes[0] ) with override_current_trace_graph(graph): size_func = TraceFunction( 'torch.Tensor.size', is_class=True, prefix='fake_dequant_rewrite_' ).parse_args(new_node.next_tensors[0], -1) size_node = graph.insert_new_after( new_node, size_func, [new_node.next_tensors[0]], [None], next_tensors=[torch.tensor(new_node.next_tensors[0].size(-1))], before_node=node.next_nodes[0], ) size_len = len(node.next_tensors[0].shape) if node.module.bidirectional: with override_current_trace_graph(graph): size_func = TraceFunction( 'torch.Tensor.__mul__', is_class=True, prefix='fake_dequant_rewrite_' ).parse_args(size_node.next_tensors[0], 2) size_node = graph.insert_new_after( size_node, size_func, [size_node.next_tensors[0]], [None], next_tensors=[size_node.next_tensors[0] * 2], before_node=node.next_nodes[0], ) with override_current_trace_graph(graph): expand_func = TraceFunction( 'torch.Tensor.expand', is_class=True, prefix='fake_dequant_rewrite_' ).parse_args(node.next_tensors[0], *((-1,) * (size_len - 1)), size_node.next_tensors[0]) graph.insert_between( node, node.next_nodes[0], expand_func, tensor_ptrs=[id(node.next_tensors[0])], move_idx=True ) expand_node = graph.nodes_map[expand_func.unique_name] size_node.next_nodes.append(expand_node) expand_node.prev_nodes.append(node) expand_node.prev_tensors.append(size_node.next_tensors[0]) expand_node.prev_indices.append(None) def _is_not_quantizable(node, custom_data): cur_module = node.module cur_class = type(cur_module) if cur_class == ConstantNode: return False elif cur_class == TraceFunction: if node.type() in ('__truediv__', '__itruediv__', 'div', 'div_'): if node.prev_nodes[0].kind() in ('shape', 'size'): return False if node.type() in ('shape', 'device', 'size', 'dtype'): if node.unique_name in self.layerwise_config: del self.layerwise_config[node.unique_name] return False if self.layerwise_config.get(node.unique_name, True) is False: return True supported_version = disable_quantize_op_list.get(cur_module.kind, torch.__version__) return supported_version is None or LooseVersion(torch.__version__) < supported_version else: if isinstance(cur_module, (torch_q.QuantStub, torch_q.DeQuantStub)): return False if self.layerwise_config.get(node.unique_name, True) is False: return True unsupported_types = tuple( k for k, v in disable_quantize_op_list.items() if type(k) is not str and k not in Q_MODULES_MAPPING and (v is None or LooseVersion(torch.__version__) < v) ) return isinstance(cur_module, unsupported_types) unsupported_nodes = graph.filter_forward_nodes(_is_not_quantizable) for idx, node in enumerate(reversed(unsupported_nodes)): if node.unique_name in self.layerwise_config: if node.kind() not in skip_types and self.layerwise_config[node.unique_name]: del self.layerwise_config[node.unique_name] node_map = dict() next_nodes = {n.unique_name: n for n in node.next_nodes}.values() for inner_idx, next_node in enumerate(next_nodes): prev_tensor_ptrs = [] if type(next_node.module) is TraceFunction and next_node.module.is_property: continue for pt in next_node.prev_tensors: for nt in node.next_tensors: if isinstance(nt, (list, tuple)): for k, ntt in enumerate(nt): if id(pt) == id(ntt): if id(pt) not in prev_tensor_ptrs: prev_tensor_ptrs.append(id(pt)) break elif id(pt) == id(nt): if id(pt) not in prev_tensor_ptrs: prev_tensor_ptrs.append(id(pt)) break for ptr_idx, ptr in enumerate(prev_tensor_ptrs): if ptr in node_map: fake_quant = node_map[ptr] graph.insert_between(node, next_node, fake_quant, move_idx=True, tensor_ptrs=[ptr]) else: fake_quant = torch_q.QuantStub() fake_quant_name = f'fake_quant_inner_{idx}_{inner_idx}_{ptr_idx}' graph.module_unique_name_dict[id(fake_quant)] = fake_quant_name graph.module_original_name_dict[id(fake_quant)] = fake_quant_name fake_quant_cls = type(fake_quant) module_constructor_lines[id(fake_quant)] = f'{qualified_name(fake_quant_cls)}()' graph.insert_between(node, next_node, fake_quant, move_idx=True, tensor_ptrs=[ptr]) node_map[ptr] = graph.nodes_map[fake_quant_name] # Insert the DeQuantStub nodes before every input node of the unsupported ops for idx, node in enumerate(unsupported_nodes): fake_dequant_cls = torch_q.DeQuantStub assert node.rev_index is False node_map = dict() prev_nodes = {n.unique_name: n for n in node.prev_nodes}.values() for inner_idx, prev_node in enumerate(prev_nodes): if prev_node.kind() in ('shape', 'device', 'size', 'dtype'): continue prev_tensor_ptrs = [] for pt in node.prev_tensors: for nt in prev_node.next_tensors: if isinstance(nt, (list, tuple)): for k, ntt in enumerate(nt): if id(pt) == id(ntt): if id(pt) not in prev_tensor_ptrs: prev_tensor_ptrs.append(id(pt)) break elif id(pt) == id(nt): if id(pt) not in prev_tensor_ptrs: prev_tensor_ptrs.append(id(pt)) break for ptr_idx, ptr in enumerate(prev_tensor_ptrs): if ptr in node_map: fake_dequant = node_map[ptr] graph.insert_between(prev_node, node, fake_dequant, move_idx=True, tensor_ptrs=set([ptr])) else: fake_dequant = fake_dequant_cls() fake_dequant_name = f'fake_dequant_inner_{idx}_{inner_idx}_{ptr_idx}' graph.module_unique_name_dict[id(fake_dequant)] = fake_dequant_name graph.module_original_name_dict[id(fake_dequant)] = fake_dequant_name module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()' graph.insert_between(prev_node, node, fake_dequant, move_idx=True, tensor_ptrs=set([ptr])) node_map[ptr] = graph.nodes_map[fake_dequant_name] # Remove consecutive dequant quant nodes def _is_consecutive_dequant_quant_nodes(node, custom_data): cur_type = node.type() if cur_type in (torch_q.QuantStub, torch_q.DeQuantStub): for next_node in node.next_nodes: next_type = next_node.type() if next_type in (torch_q.QuantStub, torch_q.DeQuantStub): if cur_type != next_type: if cur_type == torch_q.QuantStub: target_node = None if len(node.prev_nodes) == 1 and node.prev_nodes[0].kind() in skip_types_prev: target_node = node.prev_nodes[0] elif ( len(next_node.next_nodes) == 1 and next_node.next_nodes[0].kind() in skip_types_next ): target_node = next_node.next_nodes[0] if target_node is not None: self.layerwise_config.setdefault(target_node.unique_name, True) if self.layerwise_config[target_node.unique_name]: return False custom_data.append((node, next_node)) return True return False consecutive_dequant_quant_nodes = [] consecutive_branch_dequant_quant_nodes = [] graph.filter_forward_nodes(_is_consecutive_dequant_quant_nodes, consecutive_dequant_quant_nodes) for node, next_node in consecutive_dequant_quant_nodes: if len(node.next_nodes) == 1 and len(next_node.prev_nodes) == 1: graph.remove_node(next_node) graph.remove_node(node) else: consecutive_branch_dequant_quant_nodes.append((node, next_node)) # TODO: Support complex cases for node, next_node in consecutive_branch_dequant_quant_nodes: is_removable = all( ( n.type() in (torch_q.QuantStub, torch_q.DeQuantStub, 'shape', 'device', 'size', 'dtype') and n.type() != node.type() for n in node.next_nodes ) ) if is_removable: graph.remove_node(node) for n in node.next_nodes: if n.type() in (torch_q.QuantStub, torch_q.DeQuantStub): graph.remove_node(n) # Process additional fusable nodes processed_extra_q_rules = load_processed_extra_q_rules() is_extra_fusable = functools.partial( self.is_fusable, current_rules=processed_extra_q_rules, check_node_quantized=False, use_original_name=False, ) custom_data = ([], set()) graph.filter_forward_nodes(is_extra_fusable, custom_data, reverse=True) activ_names = custom_data[0] log.debug(f'found nodes that cannot fuse: {activ_names}') for idx, names in enumerate(reversed(activ_names)): name = names[-1] node = graph.nodes_map[name] fake_quant = torch_q.QuantStub() graph.module_unique_name_dict[id(fake_quant)] = f'fake_activ_quant_{idx}' graph.module_original_name_dict[id(fake_quant)] = f'fake_activ_quant_{idx}' fake_quant_cls = type(fake_quant) module_constructor_lines[id(fake_quant)] = f'{qualified_name(fake_quant_cls)}()' graph.insert_after(node, fake_quant) fake_dequant = torch_q.DeQuantStub() graph.module_unique_name_dict[id(fake_dequant)] = f'fake_activ_dequant_{idx}' graph.module_original_name_dict[id(fake_dequant)] = f'fake_activ_dequant_{idx}' fake_dequant_cls = type(fake_dequant) module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()' graph.insert_after(node, fake_dequant) fused_clamps = [] for names in activ_names: for name in names: node = graph.nodes_map[name] if node.kind() == 'clamp_with_fusion': fused_clamps.append(name) # Rewrite unfused clamp_with_fusion to torch.clamp def _is_unfused_clamp_node(node: TraceNode, custom_data): cur_module = node.module cur_class = type(cur_module) if cur_class == TraceFunction: return ( cur_module.kind == 'clamp_with_fusion' and torch.is_floating_point(node.next_tensors[0]) and node.unique_name not in fused_clamps ) return False unfused_clamp_nodes = graph.filter_forward_nodes(_is_unfused_clamp_node) log.info(f'rewriting unfused_clamp for {[node.unique_name for node in unfused_clamp_nodes]}') for node in unfused_clamp_nodes: node.module.kind = kind node.module.func_type = node.module.func_type.replace('clamp_with_fusion', 'clamp') node.module.full_name = f'torch.{node.module.func_type}' # Optional tensor shape broadcasting for quantized binary ops def _is_broadcastable_binary_quantized_op_node(node: TraceNode, custom_data) -> bool: cur_module = node.module cur_class = type(cur_module) if cur_class != TraceFunction: return False return ( cur_module.full_name.startswith('self.float_functional_simple_') and cur_module.func_type in ('add', 'mul', 'add_relu') and node.prev_tensors[0].shape != node.prev_tensors[1].shape ) broadcastable_binary_quantized_op_nodes = graph.filter_forward_nodes(_is_broadcastable_binary_quantized_op_node) for node in broadcastable_binary_quantized_op_nodes: assert len(node.prev_nodes) == 2 assert len(node.prev_tensors) == 2 l_shape = list(node.prev_tensors[0].shape) r_shape = list(node.prev_tensors[1].shape) ref_index = None if len(l_shape) > len(r_shape): ref_index = 0 r_shape = [1] * (len(l_shape) - len(r_shape)) + r_shape elif len(l_shape) < len(r_shape): ref_index = 1 l_shape = [1] * (len(r_shape) - len(l_shape)) + l_shape for l_dim, r_dim in zip(l_shape, r_shape): if l_dim > r_dim: if ref_index in (None, 0) and r_dim == 1: ref_index = 0 else: ref_index = -1 break elif l_dim < r_dim: if ref_index in (None, 1) and l_dim == 1: ref_index = 1 else: ref_index = -1 break if ref_index >= 0: src_index = 1 - ref_index with override_current_trace_graph(graph): trace_func = TraceFunction('torch.Tensor.expand_as', True, prefix='fuse_').parse_args( node.prev_tensors[src_index], node.prev_tensors[ref_index] ) next_tensors = [node.prev_tensors[src_index].expand_as(node.prev_tensors[ref_index])] graph.insert_between(node.prev_nodes[src_index], node, trace_func, next_tensors, True) new_node = graph.nodes_map[trace_func.unique_name] new_node.prev_nodes.append(node.prev_nodes[ref_index]) new_node.prev_tensors.append(node.prev_tensors[ref_index]) new_node.prev_indices.append(node.prev_indices[ref_index]) node.prev_nodes[ref_index].next_nodes.append(new_node) else: new_indices = [] for i in range(2): if node.prev_indices[i] is None: new_indices.append(i) elif isinstance(node.prev_indices[i], (tuple, list)): new_indices.append(node.prev_indices[i] + [i]) else: new_indices.append([node.prev_indices[i], i]) with override_current_trace_graph(graph): trace_func = TraceFunction('torch.broadcast_tensors', False, prefix='fuse_').parse_args( node.prev_tensors[0], node.prev_tensors[1] ) next_tensors = torch.broadcast_tensors(node.prev_tensors[0], node.prev_tensors[1]) graph.insert_before(node, trace_func, next_tensors, False, new_indices) for name in fuse_mapping: if name in self.layerwise_config: del self.layerwise_config[name] for name in self.layerwise_config: node = graph.nodes_map[name] if isinstance(node.module, nn.Module) and not isinstance( node.module, (nn.Dropout, nn.AdaptiveAvgPool2d, nn.AvgPool2d, nn.MaxPool2d, nn.ReLU, nn.Upsample, nn.ConstantPad2d), ): self.effective_layers.append(name) elif node.kind() in ('add', 'mul', 'cat') or node.kind() in FUNCTIONAL_MODULE_MAPPING: self.effective_layers.append(name) graph.quantized = True graph.recompute_forward_order() def is_fusable( self, node, custom_data, current_rules=None, check_node_quantized=True, use_original_name=True, graph=None, layerwise_config_default=False, ): # Tell whether a TraceNode is fusable with some nearby nodes if current_rules is None: return False cur_node = node names = [] final_names = [] current_state = False while True: if check_node_quantized: if not cur_node.quantized: break else: if self.layerwise_config.get(cur_node.unique_name, layerwise_config_default) is False: break cur_module = cur_node.module cur_class = type(cur_module) if isinstance(cur_module, TraceFunction): cur_class = cur_module.kind prev_nodes = cur_node.prev_nodes log.debug(f'cur: {cur_class}') if cur_class in current_rules or cur_class == nn.Identity: if use_original_name and isinstance(cur_module, nn.Module): cur_name = graph.module_original_name_dict[id(cur_module)] else: cur_name = cur_node.unique_name if cur_name in custom_data[1]: log.debug('found existing nodes, skipping') break if len(prev_nodes) == 0: break if cur_class in current_rules: current_state, current_rules = current_rules[cur_class] log.debug(f'dict: {current_rules}, {current_state}') names.append(cur_name) if current_state is True: log.debug(f'update best: {names}') final_names.clear() final_names.extend(names) if len(cur_node.prev_nodes) != 1: break cur_node = cur_node.prev_nodes[0] else: if cur_class == nn.ConvTranspose2d: if len(final_names) > 0: next_types = [] for _ in range(len(final_names)): next_node = cur_node.next_nodes[0] next_types.append(next_node.kind()) cur_node = next_node if tuple(next_types) == (nn.BatchNorm2d, nn.ReLU): final_names.clear() break if len(final_names) > 0: final_names.reverse() log.debug(f'final: {final_names}') custom_data[0].append(final_names) for name in final_names: custom_data[1].add(name) return True return False def error_analysis(self, qat_model: nn.Module, dummy_input, threshold: float = 20.0): """Generates the QAT error report using the SQNR metric Args: qat_model: The QAT model dummy_input: A viable input to the model threshold (float): The threshold of SQNR. Defaults to 20.0 """ if isinstance(qat_model, DataParallel) or isinstance(qat_model, DistributedDataParallel): model = qat_model.module else: model = qat_model modules_list = {} names_list = {} float_results = {} hooks = [] def forward_hook(module, input, output): name = names_list[module] float_results[name] = input fake_quant_enabled_dict = {} observer_enabled_dict = {} for n, m in model.named_modules(): if isinstance(m, torch.quantization.FakeQuantize): names_list[m] = n modules_list[n] = m fake_quant_enabled_dict[m] = m.fake_quant_enabled.clone() observer_enabled_dict[m] = m.observer_enabled.clone() hooks.append(m.register_forward_hook(forward_hook)) if len(modules_list) == 0: log.warning('No FakeQuantize modules found. Are you sure you passed in a QAT model?') return model.apply(torch.quantization.disable_fake_quant) model.apply(torch.quantization.disable_observer) device = get_module_device(model) if type(dummy_input) is torch.Tensor: actual_input = [dummy_input] elif isinstance(dummy_input, (tuple, list)): actual_input = list(dummy_input) else: log.error(f'Unsupported type {type(dummy_input)} for dummy input') assert False for i in range(len(actual_input)): dummy_input = actual_input[i] if type(dummy_input) is torch.Tensor: if dummy_input.device != device: actual_input[i] = dummy_input.to(device) model.eval() with torch.no_grad(): model(*actual_input) for h in hooks: h.remove() hooks.clear() for m, v in fake_quant_enabled_dict.items(): m.fake_quant_enabled = v def sqnr(x, y): Ps = torch.norm(x) Pn = torch.norm(x - y) return (20 * torch.log10(Ps / Pn)).item() q_errors_weight = [] q_errors_activ = [] while len(float_results) > 0: n, f = float_results.popitem() mod = modules_list[n] with torch.no_grad(): q = mod(*f) loss = sqnr(f[0], q) actual_n = '.'.join(n.split('.')[:-1]) if loss <= threshold: if n.endswith('.weight_fake_quant'): q_errors_weight.append((actual_n, mod, loss)) else: q_errors_activ.append((actual_n, mod, loss)) q_errors_weight = sorted(q_errors_weight, key=lambda x: x[2]) q_errors_activ = sorted(q_errors_activ, key=lambda x: x[2]) logs = [] if len(q_errors_weight) > 0: logs.append('') logs.append(f'Weights (SQNR <= {threshold}):') for n, m, e in q_errors_weight: logs.append(f'{n} SQNR: {e:.4f}, scale: {m.scale.item():.4f}, zero_point: {m.zero_point.item()}') if len(q_errors_activ) > 0: logs.append('') logs.append(f'Activations (SQNR <= {threshold}):') for n, m, e in q_errors_activ: logs.append(f'{n} SQNR: {e:.4f}, scale: {m.scale.item():.4f}, zero_point: {m.zero_point.item()}') if len(q_errors_weight) == 0 and len(q_errors_activ) == 0: logs.append('') logs.append('All good!') if len(logs) > 0: logs.insert(0, 'Quantization error report:') logs.append('') full_log = '\n'.join(logs) log.warning(full_log) for m, v in observer_enabled_dict.items(): m.observer_enabled = v model.train() def restore_to_original(self, q_model: nn.Module): """Restores a QAT/PTQ-converted model to original state Args: qat_model: The QAT/PTQ-converted model """ sub_list = [] for n, m in q_model.named_children(): if isinstance(m, (torch.nn.quantized.Quantize, torch.nn.quantized.DeQuantize)): sub_list.append((n, nn.Identity())) elif isinstance(m, torch.nn.quantized.Linear): fc = nn.Linear(m.in_features, m.out_features, m.bias is not None) fc.weight = torch.nn.Parameter(m.weight().dequantize()) fc.bias = torch.nn.Parameter(m.bias()) sub_list.append((n, fc)) elif isinstance(m, torch.nn.quantized.Conv2d): conv = nn.Conv2d( m.in_channels, m.out_channels, m.kernel_size, m.stride, m.padding, m.dilation, m.groups, m.bias is not None, m.padding_mode, ) conv.weight = torch.nn.Parameter(m.weight().dequantize()) conv.bias = torch.nn.Parameter(m.bias()) sub_list.append((n, conv)) elif isinstance(m, torch.nn.quantized.QFunctional): sub_list.append((n, nn.quantized.FloatFunctional())) elif hasattr(torch.nn.quantized, 'ReLU') and isinstance(m, torch.nn.quantized.ReLU): sub_list.append((n, nn.ReLU())) elif isinstance(m, torch.nn.quantized.ReLU6): sub_list.append((n, nn.ReLU6())) elif type(m).__module__.startswith('torch.nn.modules'): continue else: assert False, f"unsupported type: {type(m).__name__}" for n, m in sub_list: setattr(q_model, n, m) def convert(self, q_model: nn.Module, backend: str = 'tflite') -> nn.Module: """Converts a QAT/PTQ-prepared model to an actual quantized model Args: q_model (nn.Module): The QAT/PTQ-prepared model backend (str): The backend to translate for, including `pytorch` and `tflite`. Defaults to `tflite` Returns: nn.Module: The QAT/PTQ-converted model. When the backend is set to `pytorch`, it is used for validation \ in PyTorch only. """ if self.quantize_op_action.get(nn.LSTM, None) and self.backend == 'qnnpack': self.restore_lstm_weights(q_model) for acp, post_acp, dq_name, q_name, activ_name, activ_type in self.extra_qparams_mappings: if backend != 'pytorch' and activ_type in ('relu', 'relu6', torch.nn.ReLU, torch.nn.ReLU6): continue acp.scale = post_acp.scale acp.zero_point = post_acp.zero_point acp.activation_post_process.min_val = post_acp.activation_post_process.min_val acp.activation_post_process.max_val = post_acp.activation_post_process.max_val setattr(q_model, dq_name, nn.Identity()) setattr(q_model, q_name, nn.Identity()) if activ_name is not None: setattr(q_model, activ_name, nn.Identity()) if type(self).__name__ == 'QATQuantizer': q = queue.Queue() q.put(q_model) while not q.empty(): m = q.get() for n, c in m.named_children(): if isinstance(c, ConvTransposeBn2d): setattr(m, n, c.transform(c)) else: q.put(c) if hasattr(torch_q, 'get_default_static_quant_module_mappings'): mapping = torch_q.get_default_static_quant_module_mappings() elif hasattr(torch_q, 'get_static_quant_module_mappings'): mapping = copy.deepcopy(torch_q.get_static_quant_module_mappings()) else: mapping = copy.deepcopy(torch_q.DEFAULT_MODULE_MAPPING) mapping.update(FUSE_QAT_MODULES_CVT) float_mods = {} for qat_t, q_t in FUSE_QAT_MODULES_CVT.items(): float_mod = getattr(q_t, '_FLOAT_MODULE', None) if float_mod is not None: float_mods[q_t] = float_mod setattr(q_t, '_FLOAT_MODULE', qat_t) q_model = torch.quantization.convert(q_model, mapping) for q_t, orig_t in float_mods.items(): setattr(q_t, '_FLOAT_MODULE', orig_t) float_mods.clear() else: q_model = torch.quantization.convert(q_model) return q_model def optimize_conv_bn_fusion(self, q_model, eps=1e-5): """Optimizes the Conv-BatchNorm fusion pattern. Sometimes, the running_var of the BatchNorm could be near to zero. If the weights of those channels happen to be large, then it may lead to large quantization losses. We choose to ignore those channels when `bn.running_var.abs() < eps`. Args: q_model (nn.Module): The QAT/PTQ-prepared model """ def _pre_hook_func(indices): def _pre_hook(mod, input): max_val = input[0][~indices].max() min_val = input[0][~indices].min() input[0].clamp_(min_val, max_val) return input return _pre_hook for m in q_model.modules(): if type(m).__name__ in ('ConvBnReLU2d', 'ConvBn2d'): if m.in_channels == m.out_channels and m.out_channels == m.groups and m.groups > 1: indices = m.bn.running_var < eps if torch.any(indices): m.weight_fake_quant.register_forward_pre_hook(_pre_hook_func(indices)) def prepare_onnx_export(self, q_model): """Prepares for ONNX model export Args: q_model (nn.Module): The QAT/PTQ-prepared model """ if self.backend == 'onnx': for mod in self.leaf_nodes: torch.quantization.disable_fake_quant(mod.activation_post_process) fq_base_cls = getattr(torch_q, 'FakeQuantizeBase', torch_q.FakeQuantize) for n, m in q_model.named_modules(): if isinstance(m, fq_base_cls): if m.qscheme in (torch.per_channel_symmetric, torch.per_tensor_symmetric): m.zero_point.fill_(0) else: assert m.qscheme == torch.per_tensor_affine assert isinstance(m, torch_q.FixedQParamsFakeQuantize) m.qscheme = torch.per_tensor_symmetric if torch.all(m.zero_point == 128): m.zero_point.fill_(0) else: assert not torch.is_nonzero(m.zero_point) m.scale *= 2 m.quant_min = -128 m.quant_max = 127 m.dtype = torch.int8 if hasattr(m, 'activation_post_process'): m.activation_post_process.quant_min = -128 m.activation_post_process.quant_max = 127 m.activation_post_process.dtype = torch.int8 if LooseVersion(torch.__version__) >= LooseVersion('1.12.0'): def get_wrapper(mod): def wrapper(*args, **kwargs): return torch_q.FakeQuantize.forward(mod, *args, **kwargs) return wrapper fused_fq_cls = getattr(torch_q, 'FusedMovingAvgObsFakeQuantize', None) if fused_fq_cls is not None: for m in q_model.modules(): if isinstance(m, fused_fq_cls): m.forward = get_wrapper(m) elif LooseVersion(torch.__version__) >= LooseVersion('1.10.0'): mod_dict = {} for n, m in q_model.named_modules(): mod_dict[n] = m class _FakeQuantize(nn.Module): def __init__(self, fq: torch_q.FakeQuantize) -> None: super().__init__() self.fake_quant_enabled = fq.fake_quant_enabled self.scale = fq.scale self.zero_point = fq.zero_point self.quant_min = fq.quant_min self.quant_max = fq.quant_max self.is_per_channel = fq.is_per_channel self.ch_axis = fq.ch_axis def forward(self, X): if self.fake_quant_enabled[0] == 1: if self.is_per_channel: X = torch.fake_quantize_per_channel_affine( X, self.scale, self.zero_point, self.ch_axis, self.quant_min, self.quant_max ) else: X = torch.fake_quantize_per_tensor_affine( X, self.scale, self.zero_point, self.quant_min, self.quant_max ) return X action_list = [] for n, m in q_model.named_modules(): if isinstance(m, torch_q.FakeQuantize): if '.' in n: n_split = n.split('.') prop = n_split[-1] pm_name = '.'.join(n_split[:-1]) pm = mod_dict[pm_name] else: prop = n pm = q_model new_m = _FakeQuantize(m) action_list.append((pm, prop, new_m)) for parent, prop, new_mod in action_list: setattr(parent, prop, new_mod) for n, m in q_model.named_modules(): def conv_fused_wrapper(mod, scale_factor): type_name = type(mod).__name__ def new_wrapper(input): weight_shape = [1] * len(mod.weight.shape) weight_shape[0] = -1 y = type(mod)._conv_forward( mod, input, mod.weight_fake_quant(mod.weight * scale_factor.reshape(weight_shape)), mod.bias ) if 'Bn' in type_name: y = mod.bn(y) if 'ReLU' in type_name: y = torch.nn.functional.relu(y) return y return new_wrapper type_name = type(m).__name__ if type_name in ('ConvBn1d', 'ConvBnReLU1d', 'ConvBn2d', 'ConvBnReLU2d', 'ConvBn3d', 'ConvBnReLU3d'): running_std = torch.sqrt(m.bn.running_var + m.bn.eps) scale_factor = m.bn.weight / running_std if m.bias is not None: m.bn.running_mean -= m.bias m.bias = None m.bn.running_mean *= scale_factor m.bn.weight /= scale_factor m.forward = conv_fused_wrapper(m, scale_factor) def get_pre_hook(acps, indices): def pre_hook(module, input): new_input = list(input) for acp, idx in zip(acps, indices): new_input[idx] = acp(new_input[idx]) return tuple(new_input) return pre_hook def get_hook_func(acps, indices, orig_func): def pre_hook(*input, **kwargs): new_input = list(input) for acp, idx in zip(acps, indices): if orig_func.__name__ == 'cat': new_input[0][idx] = acp(new_input[0][idx]) else: new_input[idx] = acp(new_input[idx]) input = tuple(new_input) return orig_func(*input, **kwargs) return pre_hook if self.backend == 'onnx': end_mod_dict = {} for start_mod, end_mod, idx in self.swap_nodes: end_mod_dict.setdefault(end_mod, ([], [])) end_mod_dict[end_mod][0].append(start_mod) end_mod_dict[end_mod][1].append(idx) for end_mod, (start_mods, indices) in end_mod_dict.items(): acps = [] for start_mod in start_mods: if inspect.isroutine(start_mod) and isinstance(start_mod.__self__, nnq.FloatFunctional): acp = start_mod.__self__.activation_post_process else: acp = start_mod.activation_post_process acps.append(acp) if inspect.isroutine(end_mod) and isinstance(end_mod.__self__, nnq.FloatFunctional): ff = end_mod.__self__ for i in range(len(indices)): indices[i] -= 1 ff.cat = get_hook_func(acps, indices, ff.cat) ff.add = get_hook_func(acps, indices, ff.add) ff.mul = get_hook_func(acps, indices, ff.mul) ff.add_scalar = get_hook_func(acps, indices, ff.add_scalar) ff.mul_scalar = get_hook_func(acps, indices, ff.mul_scalar) ff.add_relu = get_hook_func(acps, indices, ff.add_relu) else: assert isinstance( end_mod, nn.Module ), "Only end nodes with `nn.Module` are supported duing module swapping" end_mod.register_forward_pre_hook(get_pre_hook(acps, indices)) q_model.apply(torch_q.disable_observer) def freeze_weights(self, q_model, dummy_input): """Freezes the weights before model export without `quantizer.convert` Args: q_model (nn.Module): The QAT/PTQ-prepared model """ hooks = [] def collect_fake_quantize_hook(mod, inp, outp): mod.register_parameter('_tinynn_frozen_q_weight', nn.Parameter(outp)) for n, m in q_model.named_modules(): if n.endswith('.weight_fake_quant'): hooks.append(m.register_forward_hook(collect_fake_quantize_hook)) if isinstance(dummy_input, (list, tuple)): q_model(*dummy_input) else: q_model(dummy_input) for hook in hooks: hook.remove() hooks.clear() def freeze_fake_quantize_hook(mod, inp): return mod._tinynn_frozen_q_weight for n, m in q_model.named_modules(): if n.endswith('.weight_fake_quant'): hooks.append(m.register_forward_pre_hook(freeze_fake_quantize_hook)) def fake_quantize_lstm_weights(self, asym=False, eps=1e-6): def _lstm_weight_fake_quantize(weight, asym=False, eps=1e-6): if weight.numel() == 0: return weight quant_min, quant_max = -127, 127 weight_parts = torch.chunk(weight, 4) weight_quant_parts = [] for i in range(4): weight_quant_parts.append( fake_quantize(weight_parts[i], asym=asym, eps=eps, quant_min=quant_min, quant_max=quant_max) ) weight_quant = torch.cat(weight_quant_parts) return weight_quant with torch.no_grad(): for name, module in self.model.named_modules(): if isinstance(module, torch.nn.LSTM): for weight_name in ['weight_ih_l0', 'weight_hh_l0']: quantized_weight = _lstm_weight_fake_quantize(getattr(module, weight_name), asym=asym, eps=eps) self.lstm_origin_weight_dict[f"{name}.{weight_name}"] = getattr(module, weight_name).clone() getattr(module, weight_name).data.copy_(quantized_weight) def restore_lstm_weights(self, model): with torch.no_grad(): for name, module in model.named_modules(): if isinstance(module, torch.nn.LSTM): for weight_name in ['weight_ih_l0', 'weight_hh_l0']: full_weight_name = f"{name}.{weight_name}" if full_weight_name in self.lstm_origin_weight_dict: getattr(module, weight_name).data.copy_(self.lstm_origin_weight_dict[full_weight_name]) self.lstm_origin_weight_dict.clear() class BF16Quantizer(QATQuantizer): def __init__(self, model, dummy_input, work_dir: typing.Optional[str] = None, config: typing.Optional[dict] = None): """ Constructs a new BF16Quantizer object Args: model: The model to be quantized dummy_input: A viable input to the model work_dir (typing.Optional[str], optional): The working directory in which the intermediate files will be \ generated. Defaults to None, in which case "output" will be used. config (typing.Optional[dict]): Options for the quantizer """ super().__init__(model, dummy_input, work_dir, config) def parse_config(self, config: dict): super().parse_config(config) self.rewrite_graph = False def quantize(self) -> nn.Module: """Prepare model for BFloat16 training""" self.model.train() qconfig = torch_q.QConfig(activation=FakeQuantizeBFloat16.with_args(), weight=FakeQuantizeBFloat16.with_args()) self.model.qconfig = qconfig torch_q.prepare_qat(self.model, inplace=True) return self.model class PostQuantizer(QATQuantizer): rewrite_graph: bool force_overwrite: bool is_input_quantized: typing.Optional[typing.Tuple[bool]] quantized_input_stats: typing.Optional[typing.List[typing.Optional[typing.Tuple[float, float]]]] backend: str remove_weights_after_load: bool asymmetric: bool per_tensor: bool disable_requantization_for_cat: bool dynamic_lstm_quant: bool algorithm: str def __init__(self, model, dummy_input, work_dir: typing.Optional[str] = None, config: typing.Optional[dict] = None): """ Constructs a new PostQuantizer object Args: model: The model to be quantized dummy_input: A viable input to the model work_dir (typing.Optional[str], optional): The working directory in which the intermediate files will be \ generated. Defaults to None, in which case "output" will be used. config (typing.Optional[dict]): Options for the quantizer """ super().__init__(model, dummy_input, work_dir, config) def parse_config(self, config: typing.Optional[dict]): if config and 'algorithm' not in config.keys(): config['algorithm'] = 'l2' super().parse_config(config) def set_model_mode(self, model: nn.Module = None, preserve: bool = False): """Sets the mode (train/eval) of the model Args: model (nn.Module, optional): A PyTorch model. Defaults to None. preserve (bool, optional): Whether the old state is preserved. Defaults to False. """ if model is None: model = self.model if preserve: for m in model.modules(): training = getattr(m, 'training', None) if isinstance(training, bool): self.train_mode_dict[m] = training model.eval() def prepare_qconfig(self, graph: TraceGraph, backend: str): """Prepare qconfig for various configurations. Args: graph (TraceGraph): The computation graph of the model backend (str, optional): The backend of quantization """ log.info('setting qat backend and call prepare_qat') if not self.legacy_fq or LooseVersion(torch.__version__) < '1.12.0': qconfig = torch_q.get_default_qconfig(backend) else: qconfig = torch_q.get_default_qconfig(backend, 0) qconfig_c = None if self.backend == 'qnnpack': if not self.asymmetric: sym_fq = torch_q.HistogramObserver.with_args( dtype=torch.quint8, qscheme=torch.per_tensor_symmetric, reduce_range=False ) qconfig = torch_q.QConfig(sym_fq, qconfig.weight) if not self.per_tensor: sym_fq = MinMaxObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False ) qconfig = torch_q.QConfig(qconfig.activation, sym_fq) sym_fq = PerChannelMinMaxObserver.with_args( dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False, ch_axis=0 ) qconfig_c = torch_q.QConfig(qconfig.activation, sym_fq) elif self.backend == 'fbgemm': sym_fq = torch_q.MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric) qconfig_c = torch_q.QConfig(qconfig.activation, sym_fq) elif self.backend in ('onnx', 'tensorrt'): if not self.asymmetric: sym_fq = torch_q.HistogramObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False ) qconfig = torch_q.QConfig(sym_fq, qconfig.weight) if not self.per_tensor: sym_fq = torch_q.MinMaxObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False ) qconfig = torch_q.QConfig(qconfig.activation, sym_fq) sym_fq = torch_q.PerChannelMinMaxObserver.with_args( dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False, ch_axis=0 ) qconfig_c = torch_q.QConfig(qconfig.activation, sym_fq) else: log.warning(f'Quantization backend {self.backend} is not tested. Please use at your risk.') if self.algorithm != 'l2': if self.algorithm == 'kl': if self.backend == 'qnnpack': alg_sym_fq = HistogramObserverKL.with_args(qscheme=torch.per_tensor_symmetric, reduce_range=False) alg_asym_fq = HistogramObserverKL.with_args(reduce_range=False) elif self.backend == 'fbgemm': alg_sym_fq = HistogramObserverKL.with_args(qscheme=torch.per_tensor_symmetric, reduce_range=True) alg_asym_fq = HistogramObserverKL.with_args(reduce_range=True) else: alg_sym_fq = qconfig.activation alg_asym_fq = qconfig.activation if not self.asymmetric: qconfig = torch_q.QConfig(alg_sym_fq, qconfig.weight) else: qconfig = torch_q.QConfig(alg_asym_fq, qconfig.weight) torch.backends.quantized.engine = backend graph.module.qconfig = qconfig if self.backend == 'qnnpack': if qconfig_c is not None: q = queue.Queue() q.put(graph.module) while not q.empty(): m = q.get() if type(m).__name__ in ( 'Conv2d', 'ConvBnReLU2d', 'ConvBn2d', 'ConvReLU2d', 'Conv1d', 'ConvBnReLU1d', 'ConvBn1d', ): m.qconfig = qconfig_c else: for c in m.children(): q.put(c) elif self.backend == 'fbgemm': if qconfig_c is not None: q = queue.Queue() q.put(graph.module) while not q.empty(): m = q.get() if type(m).__name__ in ('Linear', 'LinearReLU'): m.qconfig = qconfig_c else: for c in m.children(): q.put(c) def _lstm_node(node, custom_data): return isinstance(node.module, nn.LSTM) if self.dynamic_lstm_quant: lstm_nodes = graph.filter_forward_nodes(_lstm_node) for node in lstm_nodes: node.quantized = True node.module.qconfig = torch_q.default_dynamic_qconfig def prepare_qat( self, graph: TraceGraph, is_input_quantized: typing.Optional[typing.Tuple[bool]] = None, backend: str = 'qnnpack', fuse_only: bool = False, ) -> torch.nn.Module: """Prepare model for QAT training Args: graph (TraceGraph): The computation graph of the model is_input_quantized (typing.Union[typing.Tuple[bool]], optional): Whether the input tensor(s) is (are) \ quantized. Defaults to None. backend (str, optional): The backend of quantization. Defaults to 'qnnpack'. fuse_only (bool, optional): Whether the returned model is only fused in PostQuantizer. Defaults to False. Returns: torch.nn.Module: The QAT-ready model """ graph.module.eval() self.prepare_qat_prep(graph, is_input_quantized, backend) if fuse_only: return graph.module # Unfornately, the suggested way below will try to fuse all the modules # even if some of the nodes are not in a quantized computation graph. # So we wrote some alternatives for the function. # torch.quantization.prepare(graph.module, inplace=True) if hasattr(torch_q, 'get_default_qconfig_propagation_list'): whitelist = torch_q.get_default_qconfig_propagation_list() elif hasattr(torch_q, 'get_qconfig_propagation_list'): whitelist = torch_q.get_qconfig_propagation_list() else: whitelist = torch_q.DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST if LooseVersion(torch.__version__) < LooseVersion("1.7.0"): torch_q.prepare(graph.module, inplace=True) if hasattr(graph.module, 'activation_post_process'): if hasattr(graph.module, "_forward_hooks"): if len(graph.module._forward_hooks) > 0: graph.module._forward_hooks.popitem() delattr(graph.module, "activation_post_process") else: torch_q.propagate_qconfig_(graph.module, qconfig_dict=None) for n in graph.forward_nodes: if not n.quantized: if hasattr(n.module, "qconfig"): delattr(n.module, "qconfig") if hasattr(torch_q, 'add_observer_'): add_observer_func = torch_q.add_observer_ else: add_observer_func = sys.modules['torch.ao.quantization.quantize']._add_observer_ if LooseVersion(torch.__version__) >= LooseVersion("1.8.0"): if LooseVersion(torch.__version__) >= LooseVersion("1.13.0"): prepare_custom_config_dict = torch.ao.quantization.get_default_custom_config_dict() else: prepare_custom_config_dict = {} custom_module_class_mapping = prepare_custom_config_dict.get( "float_to_observed_custom_module_class", {} ) custom_module_class_mapping.update(FUSE_PTQ_MODULES_CUSTOM) def patch_observer_set(orig_func): def new_no_observer_set(): return set(FUSE_PTQ_MODULES_CUSTOM.values()) | orig_func() return new_no_observer_set if LooseVersion(torch.__version__) >= LooseVersion("1.10.0"): orig_no_observer_set = sys.modules['torch.ao.quantization.quantize'].no_observer_set sys.modules['torch.ao.quantization.quantize'].no_observer_set = patch_observer_set( orig_no_observer_set ) elif LooseVersion(torch.__version__) >= LooseVersion("1.9.0"): orig_no_observer_set = torch.quantization.quantization_mappings.no_observer_set torch.quantization.quantization_mappings.no_observer_set = patch_observer_set(orig_no_observer_set) add_observer_func( graph.module, qconfig_propagation_list=whitelist, custom_module_class_mapping=custom_module_class_mapping, ) else: add_observer_func( graph.module, qconfig_propagation_list=whitelist, ) if self.dynamic_lstm_quant: mapping = {nn.LSTM: nnqd.LSTM} torch_q.convert(graph.module, mapping=mapping, inplace=True, remove_qconfig=False) for n in graph.forward_nodes: if not n.quantized: if hasattr(n.module, "_forward_hooks"): if len(n.module._forward_hooks) > 0: n.module._forward_hooks.popitem() if hasattr(n.module, "qconfig"): delattr(n.module, "qconfig") if hasattr(n.module, "activation_post_process"): delattr(n.module, "activation_post_process") if self.disable_requantization_for_cat: self.disable_requantization_for_cat_pass(graph) if self.quantized_input_stats is not None: self.prepare_quantized_inputs_pass(graph) if self.set_quantizable_op_stats: if self.quantized_op_stats is None: self.quantized_op_stats = {} self.quantized_op_stats.update(KNOWN_QSTATS) if self.quantized_op_stats is not None: self.prepare_quantized_ops_pass(graph) if self.quantize_op_action.get(nn.LSTM, None) and self.backend == 'qnnpack': self.fake_quantize_lstm_weights() return graph.module class DynamicQuantizer(QATQuantizer): rewrite_graph: bool force_overwrite: bool is_input_quantized: typing.Optional[typing.Tuple[bool]] backend: str remove_weights_after_load: bool def __init__(self, model, dummy_input, work_dir: typing.Optional[str] = None, config: typing.Optional[dict] = None): """ Constructs a new DynamicQuantizer object Args: model: The model to be quantized dummy_input: A viable input to the model work_dir (typing.Optional[str], optional): The working directory in which the intermediate files will be \ generated. Defaults to None, in which case "output" will be used. config (typing.Optional[dict]): Options for the quantizer """ super().__init__(model, dummy_input, work_dir, config) def parse_config(self, config: dict): super().parse_config(config) self.rewrite_graph = False assert not self.asymmetric, "Asymmetric quantization is not supported for DynamicQuantizer" assert self.per_tensor, "Per-channel quantization is not supported for DynamicQuantizer" def quantize(self) -> nn.Module: """Prepare model for dynamic quantization""" self.model.eval() torch_q.quantize_dynamic(self.model, inplace=True) return self.model class DeQuantizer(object): rewrite_graph: bool force_overwrite: bool remove_weights_after_load: bool extra_tracer_opts: typing.Optional[typing.Dict] target: str def __init__(self, model, dummy_input, work_dir: typing.Optional[str] = None, config: typing.Optional[dict] = None): """ Constructs a new DeQuantizer object Args: model: The model to be dequantized dummy_input: A viable input to the model work_dir (typing.Optional[str], optional): The working directory in which the intermediate files will be \ generated. Defaults to None, in which case "output" will be used. config (typing.Optional[dict]): Options for the dequantizer """ if isinstance(model, DataParallel) or isinstance(model, DistributedDataParallel): self.model = model.module else: self.model = model self.dummy_input = dummy_input self.work_dir = 'out' if work_dir is None else work_dir self.parse_config(config) def dequantize(self) -> nn.Module: # We need a model in training mode so that QAT could take place if self.target == 'qat': self.model.train() else: self.model.eval() # After tracing the model, we will get a TraceGraph object graph = trace(self.model, self.dummy_input, **self.extra_tracer_opts) if self.rewrite_graph: # Retrives the name of the model from type info model_name = type(self.model).__name__ if model_name.endswith('_qat'): model_name = model_name[:-4] model_name_float = f'{model_name}_float' model_name_float_lower = model_name_float.lower() model_ns = f'tinynn_rewritten_models.{model_name_float}' model_code_path = os.path.join(self.work_dir, f'{model_name_float_lower}.py') model_weights_path = os.path.join(self.work_dir, f'{model_name_float_lower}.pth') # Generate the code for the modified model # We will try to do some op rewriting in the `rewrite_quantize_graph` function. if self.force_overwrite or not os.path.exists(model_code_path) or not os.path.exists(model_weights_path): self.rewrite_quantize_graph(graph) graph.generate_code(model_code_path, model_weights_path, model_name_float) # Import the new model rewritten_model = import_from_path(model_ns, model_code_path, model_name_float)() rewritten_model.load_state_dict(torch.load(model_weights_path)) device = get_module_device(self.model) if device is not None: rewritten_model.to(device=device) # Remove the weights file to save space if self.remove_weights_after_load: os.unlink(model_weights_path) # Set the model to training mode before tracing again, since we need to perform QAT if self.target == 'qat': rewritten_model.train() else: rewritten_model.eval() # Update the model so that the original one can be released self.model = rewritten_model else: rewritten_model = self.model return rewritten_model def parse_config(self, config: typing.Optional[dict]): default_values = { 'rewrite_graph': True, 'force_overwrite': True, 'extra_tracer_opts': {}, 'remove_weights_after_load': False, 'target': 'qat', } if config is None: config = dict() for k, v in default_values.items(): actual_v = config.get(k, v) setattr(self, k, actual_v) def rewrite_quantize_graph(self, graph: TraceGraph) -> None: """Rewrites the computation graph for quantization""" if graph.quantized: return graph_quantized = False for n in graph.forward_nodes: if n.type() in (torch_q.QuantStub, torch_q.DeQuantStub): graph_quantized = True break for n in graph.other_init_nodes: if n.type() is nnq.FloatFunctional: graph_quantized = True break if not graph_quantized: log.warning( 'Graph is not quantized. No need to rewrite. Please pass in `config={"rewrite_graph": False}` to' ' suppress this warning' ) return # Remove QuantStubs and DeQuantStubs names = [n.unique_name for n in graph.forward_nodes] for name in names: n = graph.nodes_map[name] if n.type() in (torch_q.QuantStub, torch_q.DeQuantStub): graph.remove_node(n) def _is_add_relu_node(node: TraceNode, custom_data): cur_module = node.module cur_class = type(cur_module) if cur_class == TraceFunction: return ( cur_module.kind == 'add_relu' and len(node.prev_nodes) > 1 and node.prev_nodes[0].type() is nnq.FloatFunctional ) # Split FloatFunctional.add_relu to FloatFunctional.add and torch.relu add_relu_nodes = graph.filter_forward_nodes(_is_add_relu_node) for n in add_relu_nodes: n.module.kind = 'add' n.module.func_type = 'add' parts = n.module.full_name.split('.')[:-1] + [n.module.func_type] n.module.full_name = '.'.join(parts) with override_current_trace_graph(graph): next_func = TraceFunction('torch.relu').parse_args(n.next_tensors[0]) next_out = torch.relu(n.next_tensors[0]) graph.insert_after(n, next_func, [next_out]) # Replace FloatFunctional.{add_scalar, mul_scalar} with torch.{add, mul} def _is_add_mul_scalar_node(node: TraceNode, custom_data): cur_module = node.module cur_class = type(cur_module) if cur_class == TraceFunction: return ( cur_module.kind in ('add_scalar', 'mul_scalar') and len(node.prev_nodes) > 1 and node.prev_nodes[0].type() is nnq.FloatFunctional ) add_mul_scalar_nodes = graph.filter_forward_nodes(_is_add_mul_scalar_node) for n in add_mul_scalar_nodes: n.module.kind = n.module.kind.split('_')[0] n.module.func_type = n.module.kind parts = n.module.full_name.split('.')[:-1] + [n.module.func_type] n.module.full_name = '.'.join(parts) # Replace FloatFunctional.{add, mul, cat} with torch.{add, mul, cat} def _is_add_mul_cat_node(node: TraceNode, custom_data): cur_module = node.module cur_class = type(cur_module) if cur_class == TraceFunction: return ( cur_module.kind in ('add', 'mul', 'cat') and len(node.prev_nodes) > 1 and node.prev_nodes[0].type() is nnq.FloatFunctional ) add_mul_cat_nodes = graph.filter_forward_nodes(_is_add_mul_cat_node) for n in add_mul_cat_nodes: parts = ['torch'] + [n.module.func_type] n.module.full_name = '.'.join(parts) n.module.is_class = False n.prev_nodes[0].next_nodes.remove(n) n.prev_nodes.pop(0) n.prev_indices.pop(0) n.module.tensor_names.pop(0) n.module.args_parsed.pop(0) n.module.args_to_string(n.module.args_parsed) # Remove FloatFunctional nodes names = [n.unique_name for n in graph.other_init_nodes] for name in names: n = graph.nodes_map[name] if n.type() is nnq.FloatFunctional: graph.other_init_nodes.remove(n) del graph.nodes_map[n.unique_name] graph.quantized = False graph.recompute_forward_order() def load_creation_func_names(): if len(creation_func_names) == 0: funcs_d = load_creation_funcs() for ns, funcs_v in funcs_d.items(): ns_str = qualified_name(ns) creation_func_names.extend([f'{ns_str}.{x}' for x in funcs_v]) return creation_func_names def get_dict_from_rules(rules): rule_dict = {} for rule in rules: base_rule_dict = rule_dict for module_cls in reversed(rule): # Node properties (has_key, child_nodes) base_rule_dict.setdefault(module_cls, [False, {}]) base_rule_pair = base_rule_dict[module_cls] base_rule_dict = base_rule_pair[1] base_rule_pair[0] = True return rule_dict def load_processed_qat_rules(): if len(processed_qat_rules) == 0: # Constructor a prefix tree for the QAT rules fuse_rules = sorted(FUSE_RULE_LIST, key=lambda x: len(x), reverse=True) rule_dict = get_dict_from_rules(fuse_rules) processed_qat_rules.update(rule_dict) return processed_qat_rules def load_processed_ptq_rules(): if len(processed_ptq_rules) == 0: # Constructor a prefix tree for the PTQ rules filtered_ptq_rules = { k for k, v in FUSE_RULE_LIST_PTQ_ONLY.items() if v is None or LooseVersion(torch.__version__) >= LooseVersion(v) } ptq_rules = set(FUSE_RULE_LIST).union(set(filtered_ptq_rules)) fuse_rules = sorted(ptq_rules, key=lambda x: len(x), reverse=True) rule_dict = get_dict_from_rules(fuse_rules) processed_ptq_rules.update(rule_dict) return processed_ptq_rules def load_processed_extra_q_rules(): if len(processed_extra_q_rules) == 0: # Constructor a prefix tree for the QAT rules fuse_rules = sorted(FUSE_RULE_LIST_EXTRA, key=lambda x: len(x), reverse=True) rule_dict = get_dict_from_rules(fuse_rules) processed_extra_q_rules.update(rule_dict) return processed_extra_q_rules def load_processed_rewrite_to_fuse_rules(): if len(processed_rewrite_to_fuse_rules) == 0: # Constructor a prefix tree for the rewrite rules fuse_rules = sorted(REWRITE_TO_FUSE_RULE_LIST, key=lambda x: len(x), reverse=True) rule_dict = get_dict_from_rules(fuse_rules) processed_rewrite_to_fuse_rules.update(rule_dict) return processed_rewrite_to_fuse_rules def load_processed_all_qat_rules(): if len(processed_all_qat_rules) == 0: # Constructor a prefix tree for the QAT rules qat_rules = set(FUSE_RULE_LIST).union(set(FUSE_RULE_LIST_EXTRA)) fuse_rules = sorted(qat_rules, key=lambda x: len(x), reverse=True) rule_dict = get_dict_from_rules(fuse_rules) processed_all_qat_rules.update(rule_dict) return processed_all_qat_rules def load_processed_all_ptq_rules(): if len(processed_all_ptq_rules) == 0: # Constructor a prefix tree for the QAT rules filtered_ptq_rules = { k for k, v in FUSE_RULE_LIST_PTQ_ONLY.items() if v is None or LooseVersion(torch.__version__) >= LooseVersion(v) } ptq_rules = set(FUSE_RULE_LIST).union(set(filtered_ptq_rules)).union(set(FUSE_RULE_LIST_EXTRA)) fuse_rules = sorted(ptq_rules, key=lambda x: len(x), reverse=True) rule_dict = get_dict_from_rules(fuse_rules) processed_all_ptq_rules.update(rule_dict) return processed_all_ptq_rules