tinynn/converter/base.py (475 lines of code) (raw):

import collections
import io
import os
import typing

import torch
import numpy as np

from .operators import CommonGraph, ExtendedOperator, GraphOptimizer, HybridQuantizer, HalfQuantizer
from .operators.op_version import OPVersioner
from .operators.tflite import Tensor
from .operators.torch import OPERATOR_CONVERTER_DICT
from .operators.torch.base import NoTrackOperator, TrackRevQParamsOperator, TrackQParamsOperator
from .operators.torch.aten import ATenDequantizeOperator, ATenQuantizePerTensorOperator
from ..util.converter_util import generate_converter_config
from ..util.util import get_logger

log = get_logger(__name__, 'INFO')


class TFLiteConverter(object):
    def __init__(
        self,
        model: typing.Union[torch.jit.ScriptFunction, torch.jit.ScriptModule, torch.nn.Module],
        dummy_input: typing.Union[torch.Tensor, typing.Iterable[torch.Tensor]],
        tflite_path: str,
        input_transpose: typing.Optional[typing.Union[bool, typing.Iterable[bool]]] = None,
        output_transpose: typing.Optional[typing.Union[bool, typing.Iterable[bool]]] = None,
        nchw_transpose: bool = True,
        dump_jit_model_path: typing.Optional[str] = None,
        dump_dummy_input_path: typing.Optional[str] = None,
        dump_config_path: typing.Optional[str] = None,
        strict_symmetric_check: bool = False,
        preserve_tensors: bool = False,
        optimize: int = GraphOptimizer.ALL_OPTIMIZE,
        quantize_target_type: str = 'uint8',
        quantize_input_output_type: typing.Optional[str] = None,
        hybrid_quantization_from_float: bool = False,
        hybrid_per_channel: bool = False,
        hybrid_asymmetric_inputs: bool = True,
        hybrid_quantize_weight_type: typing.Optional[str] = None,
        fuse_quant_dequant: bool = False,
        fuse_input_indices: typing.Optional[typing.List[int]] = None,
        fuse_output_indices: typing.Optional[typing.List[int]] = None,
        gc_when_reload: bool = False,
        group_conv_rewrite: bool = False,
        rewrite_quantizable: bool = False,
        tflite_micro_rewrite: bool = False,
        map_bilstm_to_lstm: bool = False,
        float16_quantization: bool = False,
        enable_mtk_ops: bool = False,
        conv_transpose_with_bias: bool = True,
        max_transpose_dims: int = -1,
        hybrid_conv: bool = True,
        hybrid_int16_lstm: bool = False,
        unroll_rnn: bool = False,
        separated_rnn_gate_calc: bool = False,
        bypass_elementwise_passthrough_constraint: bool = False,
        hybrid_gen_single_op_models: bool = False,
        hybrid_config: typing.Optional[typing.Dict[str, bool]] = None,
        group_tensors: bool = False,
        missing_outputs_as_constants: bool = False,
        legacy_gelu: bool = False,
    ) -> None:
        """The TFLiteConverter class

        Args:
            model (typing.Union[torch.jit.ScriptFunction, torch.jit.ScriptModule, torch.nn.Module]): The input model \
                (either traced or non-traced)
            dummy_input (typing.Union[torch.Tensor, typing.Iterable[torch.Tensor]]): A viable input to the model
            tflite_path (str): Path to use for exporting
            input_transpose (typing.Optional[typing.Union[bool, typing.Iterable[bool]]], optional): Whether to \
                transpose the input(s). Defaults to None(True for 4d-input, False otherwise).
            output_transpose (typing.Optional[typing.Union[bool, typing.Iterable[bool]]], optional): Whether to \
                transpose the output(s). Defaults to None(True for 4d-input, False otherwise).
            nchw_transpose (bool): Whether to perform nchw->nhwc transposes on input and output tensors. \
                If `False` is specified, the arguments `input_transpose` and `output_transpose` will be ignored.
            dump_jit_model_path (typing.Optional[str]): The path for dumping the jit model. Defaults to None
            dump_dummy_input_path (typing.Optional[str]): The path for dumping the dummy input. Defaults to None
            dump_config_path (typing.Optional[str]): The path for dumping the json config. Requires \
                `dump_jit_model_path` to be set. Defaults to None
            strict_symmetric_check (bool): Strict symmetric quantization checks. Defaults to False
            preserve_tensors (bool): Preserve the copies of the intermediate tensors. Defaults to False
            optimize (int): The level of graph optimization. Defaults to `GraphOptimizer.ALL_OPTIMIZE`
            quantize_target_type (str): Target type for quantization (uint8, int8 or int16). Defaults to 'uint8'
            quantize_input_output_type (str): Input and output type for quantization. Defaults to None (inferred)
            hybrid_quantization_from_float (bool): Direct hybrid quantization from a float model. Defaults to False
            hybrid_per_channel (bool): Prefer per-channel kernels in hybrid quantization. Defaults to False
            hybrid_asymmetric_inputs (bool): Prefer asymmetric inputs while performing hybrid quantization. \
                Defaults to True
            hybrid_quantize_weight_type (typing.Optional[str]): Quantized weight type for hybrid quantization \
                (uint8, int8 or int16). If it is unset, then the value of `quantize_target_type` will be used. \
                Defaults to None
            fuse_quant_dequant (bool): Remove quant and dequant nodes directly connected to i/o nodes. Defaults to False
            fuse_input_indices (typing.Optional[typing.List[int]]): Used together with `fuse_quant_dequant`. Indices \
                of input nodes to fuse with `Quantize`. Defaults to None (which fuses all inputs available)
            fuse_output_indices (typing.Optional[typing.List[int]]): Used together with `fuse_quant_dequant`. Indices \
                of output nodes to fuse with `Dequantize`. Defaults to None (which fuses all outputs available)
            gc_when_reload (bool): Apply GC when reloading the torchscript into memory. Only effective when \
                `dump_jit_model_path` is set. Defaults to False
            group_conv_rewrite (bool): Rewriting for group [de]convolution. Defaults to False
            rewrite_quantizable (bool): Rewriting quantizable ops (e.g. BATCH_MATMUL, SOFTMAX, LOG_SOFTMAX) \
                to use quantized kernels. Defaults to False
            tflite_micro_rewrite (bool): Rewriting for running on TFLite-micro. Defaults to False
            map_bilstm_to_lstm (bool): Translating bidirectional LSTM to TFLite ops with `UnidirectionalLSTM`. \
                Defaults to False
            float16_quantization (bool): Quantize constants with float32 dtype to float16 dtype. Defaults to False
            enable_mtk_ops (bool): Translating with custom MTK operators. Defaults to False
            conv_transpose_with_bias (bool): ConvTranspose ops with bias. Defaults to True
            max_transpose_dims (int): Max dimensions for the `Transpose` op. Defaults to -1, which means unlimited
            hybrid_conv (bool): Enable hybrid quantization for Conv2d and DepthwiseConv2d. Defaults to True
            hybrid_int16_lstm (bool): Enable hybrid int16 quantization for LSTM. Defaults to False
            unroll_rnn (bool): Unrolling LSTM (translate LSTM to separate ops). Defaults to False
            separated_rnn_gate_calc (bool): Separated calculation for every gate in RNN. Effective only when \
                `unroll_rnn=True`. Defaults to False
            bypass_elementwise_passthrough_constraint (bool): Bypass constraints in elementwise passthrough passes. \
                Defaults to False
            hybrid_gen_single_op_models (bool): Generate both floating point and quantized version of the model for \
                hybrid quantizable ops. Defaults to False
            hybrid_config (typing.Optional[typing.Dict[str, bool]]): Extra configuration forwarded to \
                `HybridQuantizer`. Defaults to None
            group_tensors (bool): Group tensors to save space. Defaults to False
            missing_outputs_as_constants (bool): View missing outputs as constants. Defaults to False
            legacy_gelu (bool): Fallback to the legacy behaviour for translating gelu. Defaults to False

        Raises:
            AttributeError: If an unsupported quantization type (or type combination) is requested
            AssertionError: If `dump_config_path` is set without `dump_jit_model_path`, or if \
                `quantize_input_output_type` is invalid or set without `fuse_quant_dequant`
        """
        self.model = model
        self.lower_model = None
        self.graph = None
        self.tensor_map = {}
        self.tensor_map_copies = {}
        self.common_graph = CommonGraph()

        if type(dummy_input) in (tuple, list):
            self.dummy_input = dummy_input
        else:
            self.dummy_input = [dummy_input]

        self.flatten_inputs = []
        self.tflite_path = tflite_path

        self.nchw_transpose = nchw_transpose
        if self.nchw_transpose:
            self.input_transpose = input_transpose
            self.output_transpose = output_transpose
        else:
            self.input_transpose = False
            self.output_transpose = False

        self.strict_symmetric_check = strict_symmetric_check

        self.dump_jit_model_path = dump_jit_model_path
        self.dump_dummy_input_path = dump_dummy_input_path
        self.dump_config_path = dump_config_path
        self.preserve_tensors = preserve_tensors
        self.optimize = optimize
        self.hybrid = hybrid_quantization_from_float
        self.hybrid_per_channel = hybrid_per_channel
        self.hybrid_asymmetric_inputs = hybrid_asymmetric_inputs
        self.fuse_quant_dequant = fuse_quant_dequant
        self.fuse_input_indices = fuse_input_indices
        self.fuse_output_indices = fuse_output_indices
        self.gc_when_reload = gc_when_reload
        self.group_conv_rewrite = group_conv_rewrite
        self.rewrite_quantizable = rewrite_quantizable
        self.tflite_micro_rewrite = tflite_micro_rewrite
        self.map_bilstm_to_lstm = map_bilstm_to_lstm
        self.float16_quantization = float16_quantization
        self.enable_mtk_ops = enable_mtk_ops
        self.conv_transpose_with_bias = conv_transpose_with_bias
        self.max_transpose_dims = max_transpose_dims
        self.hybrid_conv = hybrid_conv
        self.hybrid_int16_lstm = hybrid_int16_lstm
        self.unroll_rnn = unroll_rnn
        self.separated_rnn_gate_calc = separated_rnn_gate_calc
        self.bypass_elementwise_passthrough_constraint = bypass_elementwise_passthrough_constraint
        self.hybrid_gen_single_op_models = hybrid_gen_single_op_models
        self.hybrid_config = hybrid_config
        self.group_tensors = group_tensors
        self.missing_outputs_as_constants = missing_outputs_as_constants
        self.legacy_gelu = legacy_gelu

        if quantize_target_type == 'uint8':
            self.q_type = np.uint8
            if self.strict_symmetric_check:
                log.warning('Symmetric quantized model with uint8 is unsupported in most backends of TFLite')
        elif quantize_target_type == 'int8':
            self.q_type = np.int8
        elif quantize_target_type == 'int16':
            if not self.strict_symmetric_check:
                raise AttributeError('Int16 quantization requires strict_symmetric_check=True')
            self.q_type = np.int16
        else:
            raise AttributeError(f'unknown quantize_target_type: {quantize_target_type}, expected: uint8, int8, int16')

        if quantize_input_output_type is not None:
            assert fuse_quant_dequant, 'Please set fuse_quant_dequant=True, otherwise quantize_input_type is ignored'
            assert quantize_input_output_type in (
                'int8',
                'uint8',
                'int16',
            ), f'unknown quantize_input_output_type: {quantize_input_output_type}, expected: uint8, int8, int16'

            if quantize_input_output_type == 'int16' and quantize_target_type != 'int16':
                raise AttributeError(
                    'quantize_input_output_type == \'int16\' and quantize_target_type != \'int16\' is not supported'
                )

        self.quantize_input_output_type = quantize_input_output_type

        if hybrid_quantize_weight_type is None:
            hybrid_quantize_weight_type = quantize_target_type

        if hybrid_quantize_weight_type == 'uint8':
            if self.hybrid:
                if self.hybrid_per_channel:
                    raise AttributeError('Per-channel kernels supports int8 only')
                log.warning(
                    'Unless you are using legacy TFLite (<1.14), please set quantize_target_type to int8 instead'
                )
            self.hybrid_q_type = np.uint8
        elif hybrid_quantize_weight_type == 'int8':
            self.hybrid_q_type = np.int8
        elif hybrid_quantize_weight_type == 'int16':
            self.hybrid_q_type = np.int16
            if self.hybrid:
                raise AttributeError('Hybrid kernels supports int8 and uint8 only')
        else:
            # Fail fast: without this, `self.hybrid_q_type` would stay unset and conversion would break later
            # with an unrelated-looking AttributeError inside `init_operations`.
            raise AttributeError(
                f'unknown hybrid_quantize_weight_type: {hybrid_quantize_weight_type}, expected: uint8, int8, int16'
            )

        if dump_config_path and not dump_jit_model_path:
            raise AssertionError("when dump_config_path is set, dump_jit_model_path is required to be set")

        # Index of the first "real" graph input; a ScriptModule graph carries `self` as input 0
        # (reset to 0 for ScriptFunctions in `init_jit_graph`).
        self.input_offset = 1

    def init_jit_graph(self):
        """Turns `self.model` into a TorchScript object (tracing it if needed) and dumps the optional artifacts."""
        # Multi-GPU modules doesn't support JIT tracing
        if isinstance(self.model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            self.model = self.model.module

        if not isinstance(self.model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
            if hasattr(self.model, 'cpu'):
                self.model.cpu()

            if hasattr(self.model, 'eval'):
                self.model.eval()

            with torch.no_grad():
                script = torch.jit.trace(self.model, self.dummy_input)

                # Remove reference to original model to save memory
                self.model = None

                # Have to save it once, otherwise something weird happens
                if self.dump_jit_model_path is None:
                    with io.BytesIO() as f:
                        torch.jit.save(script, f)
                        f.seek(0)
                        script = torch.jit.load(f)
                else:
                    jit_model_dir = os.path.abspath(os.path.dirname(self.dump_jit_model_path))
                    os.makedirs(jit_model_dir, exist_ok=True)
                    torch.jit.save(script, self.dump_jit_model_path)
                    if self.gc_when_reload:
                        import gc

                        # Drop the traced copy before reloading so both never live in memory at once
                        script = None
                        gc.collect()

                    script = torch.jit.load(self.dump_jit_model_path)

            self.model = script

        # A ScriptFunction graph has no leading `self` input
        if isinstance(self.model, torch.jit.ScriptFunction):
            self.input_offset = 0

        if self.dump_dummy_input_path is not None:
            dummy_arrs = list(map(lambda x: x.detach().cpu().numpy(), self.dummy_input))
            np.savez(self.dump_dummy_input_path, *dummy_arrs)

        if self.dump_config_path is not None:
            generate_converter_config(
                self.dummy_input,
                [],
                self.input_transpose,
                [],
                self.dump_jit_model_path,
                self.tflite_path,
                self.dump_config_path,
            )

    def init_lowered_module(self):
        """Runs a series of TorchScript passes (inlining, constant propagation, DCE, ...) on the model graph."""
        assert (
            isinstance(self.model, torch.jit.ScriptFunction)
            or self.model.training is False
            or str(next(self.model.graph.inputs()).type()) == '__torch__.PlaceholderModule'
        ), (
            'Model is in training mode. Please run `model.eval()` before model conversion. If you are passing in a'
            ' TorchScript model, make sure you use `torch.jit.save` to dump the model to disk and then load it using'
            ' `torch.jit.load`.'
        )

        graph = self.model.graph

        # Inline everything
        torch._C._jit_pass_inline(graph)

        # Remove fork/wait nodes
        torch._C._jit_pass_inline_fork_wait(graph)
        torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_lower_all_tuples(graph)

        # we record now record some ops like ones/zeros
        # into a trace where we previously recorded constants
        # use constant prop to maintain our current level of onnx support
        # without implementing symbolics for all of them
        torch._C._jit_pass_constant_propagation(graph)

        # _split_tensor_list_constants(graph, graph)
        # run dce to eliminate dead parts of the graph that might have been
        # left behind by things like symbolic_override
        torch._C._jit_pass_dce(graph)
        torch._C._jit_pass_lint(graph)

        torch._C._jit_pass_canonicalize_graph_fuser_ops(graph)
        torch._C._jit_pass_lint(graph)

        torch._C._jit_pass_peephole(graph, True)
        torch._C._jit_pass_fuse_addmm(graph)
        torch._C._jit_pass_lint(graph)

        torch._C._jit_pass_peephole(graph, True)
        torch._C._jit_pass_lower_all_tuples(graph)

        self.graph = graph

        log.debug('Lowered graph:')
        log.debug(self.graph)

    def init_flatten_inputs(self):
        """Flattens one level of list/tuple nesting in `self.dummy_input` into `self.flatten_inputs`."""
        self.flatten_inputs.clear()
        for t in self.dummy_input:
            if isinstance(t, (list, tuple)):
                for rt in t:
                    self.flatten_inputs.append(rt)
            else:
                self.flatten_inputs.append(t)

    def init_input_transpose(self):
        """Resolves `self.input_transpose` into one bool per flattened input.

        A scalar setting is broadcast to every input; `None` entries default to `True` for 4-D tensors and
        `False` otherwise.
        """
        input_transpose = self.input_transpose
        if type(input_transpose) not in (tuple, list):
            input_transpose = [input_transpose] * len(self.flatten_inputs)
        else:
            # Work on a copy: a tuple cannot be filled in, and a caller-owned list must not be mutated
            input_transpose = list(input_transpose)
        for i, t in enumerate(self.flatten_inputs):
            if input_transpose[i] is None:
                if isinstance(t, torch.Tensor):
                    input_transpose[i] = t.dim() == 4
                else:
                    input_transpose[i] = False
        self.input_transpose = input_transpose

    def init_common_graph(self):
        """Registers the graph inputs/outputs and the input tensors with `self.common_graph`."""
        graph_inputs = [x.debugName() for x in list(self.graph.inputs())][self.input_offset :]
        graph_outputs = [x.debugName() for x in list(self.graph.outputs())]
        self.common_graph.inputs.extend(graph_inputs)
        self.common_graph.outputs.extend(graph_outputs)
        self.common_graph.input_transpose.extend(self.input_transpose)
        self.common_graph.output_transpose = self.output_transpose
        tensors = []
        for i, node in enumerate(graph_inputs):
            tensors.append(
                Tensor(
                    self.flatten_inputs[i],
                    node,
                    has_buffer=False,
                    asymmetric=not self.strict_symmetric_check,
                    q_type=self.q_type,
                )
            )
        self.common_graph.add_nodes(tensors, ExtendedOperator.INPUT_NODE)

    def init_inputs(self):
        """Seeds `self.tensor_map` with the values bound to the graph inputs (the module itself for `self`)."""
        graph_inputs = [x.debugName() for x in list(self.graph.inputs())]
        for i, node in enumerate(graph_inputs):
            if self.input_offset > 0 and i == 0:
                self.tensor_map[graph_inputs[i]] = self.model
            else:
                self.tensor_map[graph_inputs[i]] = self.flatten_inputs[i - self.input_offset]

    def unsupported_operations(self, unique=True) -> typing.List[str]:
        """Returns unsupported operations in the graph

        Args:
            unique (bool, optional): Remove duplicates (keeping first-seen order). Defaults to True.
        """
        if self.graph is None:
            self.init_lowered_module()

        all_nodes = list(self.graph.nodes())
        ops = []
        for node in all_nodes:
            k = node.kind()
            converter_type = OPERATOR_CONVERTER_DICT.get(k, None)
            if converter_type is None:
                ops.append(k)

        if unique:
            # `dict.fromkeys` dedups while keeping graph order, so error messages are deterministic
            return list(dict.fromkeys(ops))
        else:
            return ops

    def init_operations(self):
        """Walks the lowered graph, running each node's converter and recording its output tensors."""
        log.debug('Initialize operators...')
        node_queue = collections.deque(self.graph.nodes())
        scope_map = {}
        current_scope = None
        while node_queue:
            node = node_queue.popleft()

            k = node.kind()
            output_tensors = []

            converter_type = OPERATOR_CONVERTER_DICT.get(k, NoTrackOperator)
            converter = converter_type(
                node,
                self.tensor_map,
                current_scope,
                not self.strict_symmetric_check,
                self.q_type,
                self.hybrid_q_type,
                self.map_bilstm_to_lstm,
                self.enable_mtk_ops,
                self.hybrid_asymmetric_inputs,
                self.unroll_rnn,
                self.separated_rnn_gate_calc,
                self.conv_transpose_with_bias,
                self.legacy_gelu,
            )
            # Don't track the operator if all the input nodes are not tracked unless it has custom implementation
            # (e.g prim::* ops)
            if converter_type.run == NoTrackOperator.run and converter_type != NoTrackOperator:
                no_track_flag = True
                for n in converter.input_names:
                    if self.common_graph.has_nested_names(n):
                        nested_names = self.common_graph.get_list_expanded_names(n)
                        for x in nested_names:
                            if x in self.common_graph.tensor_map and self.common_graph.tensor_map[x].buffer is None:
                                no_track_flag = False
                                break
                    elif n in self.common_graph.tensor_map and self.common_graph.tensor_map[n].buffer is None:
                        no_track_flag = False
                        break
                if no_track_flag:
                    if converter_type == ATenDequantizeOperator:
                        converter_type = TrackQParamsOperator
                    elif converter_type == ATenQuantizePerTensorOperator:
                        converter_type = TrackRevQParamsOperator
                    else:
                        converter_type = NoTrackOperator
                    converter = converter_type(
                        node,
                        self.tensor_map,
                        current_scope,
                        not self.strict_symmetric_check,
                        self.q_type,
                        self.hybrid_q_type,
                        self.map_bilstm_to_lstm,
                        self.enable_mtk_ops,
                        self.hybrid_asymmetric_inputs,
                        self.unroll_rnn,
                        self.separated_rnn_gate_calc,
                        self.conv_transpose_with_bias,
                        self.legacy_gelu,
                    )
            if k != 'prim::Constant':
                log.debug(f'{k} {converter.input_names} -> {converter.output_names} {converter_type.__name__}')

            # Don't fetch attrs and schemas for non-tracking nodes
            if converter_type not in (NoTrackOperator, TrackRevQParamsOperator, TrackQParamsOperator):
                try:
                    attrs = converter.fetch_all_attrs(node)
                except StopIteration:
                    attrs = None
                args = converter.fetch_annotated_args(node)
            else:
                attrs = None
                args = None

            converter.parse(node, attrs, args, self.common_graph)
            outputs = converter.output_names
            new_nodes = converter.output_nodes
            if output_tensors is not None:
                output_tensors.extend(converter.get_output_tensors())
            if len(new_nodes) > 0:
                # Nodes emitted by the converter are processed next, in their original order
                node_queue.extendleft(reversed(new_nodes))

            if k == 'prim::PythonOp':
                s = node.scopeName()
                scope_map.setdefault(s, 0)
                scope_map[s] += 1
                current_scope = f'{s}_{scope_map[s]}'
                converter.prepare_scope_tensors(node, attrs, args, self.common_graph, current_scope)
            elif k == 'prim::Return':
                current_scope = None

            assert len(output_tensors) == len(outputs)
            for t, name in zip(output_tensors, outputs):
                self.tensor_map[name] = t
                if self.preserve_tensors and isinstance(t, torch.Tensor):
                    self.tensor_map_copies[name] = t.detach().clone()

    def __try_infer_type(self, params):
        """Returns the TorchScript annotation string for the type of `params`, or `str()` of the inference result."""
        # No `finally: return ...` here: returning from `finally` would discard the annotation string below and
        # would raise UnboundLocalError (masking the real error) if the inference call itself failed.
        inferred = torch._C._jit_try_infer_type(params)
        if hasattr(inferred, 'type'):
            # `type()` is only valid when inference succeeded; fall back gracefully if `success()` is unavailable
            success = getattr(inferred, 'success', None)
            if success is None or success():
                return inferred.type().annotation_str
        return str(inferred)

    def __unpack_params(self, params):
        """Unpacks packed (quantized) params via `NoTrackOperator.unpack_params`."""
        return NoTrackOperator.unpack_params(None, params)

    def convert(self):
        """Converts the model to the TFLite format

        Raises:
            Exception: If unsupported ops are found, an Exception will be raised
        """
        self.init_flatten_inputs()
        self.init_input_transpose()
        self.init_jit_graph()
        self.init_lowered_module()
        self.init_common_graph()
        self.init_inputs()
        self.init_operations()

        unsupported_ops = self.unsupported_operations()
        if len(unsupported_ops) > 0:
            log.error(f'Unsupported ops: {", ".join(unsupported_ops)}')
            raise Exception("Cannot continue due to fatal error")
        else:
            optimizer = GraphOptimizer(
                self.common_graph,
                self.optimize,
                self.fuse_quant_dequant,
                self.group_conv_rewrite,
                self.rewrite_quantizable,
                self.tflite_micro_rewrite,
                self.quantize_input_output_type,
                self.fuse_input_indices,
                self.fuse_output_indices,
                self.max_transpose_dims,
                self.bypass_elementwise_passthrough_constraint,
                self.group_tensors,
                self.conv_transpose_with_bias,
                self.hybrid_int16_lstm,
            )
            optimizer.optimize()

            # The optimizer may rewrite the output transposes; keep our copy in sync
            self.output_transpose = self.common_graph.output_transpose

            if self.hybrid:
                quantizer = HybridQuantizer(
                    self.common_graph,
                    self.hybrid_asymmetric_inputs,
                    self.hybrid_q_type,
                    self.hybrid_per_channel,
                    self.hybrid_conv,
                    self.hybrid_int16_lstm,
                    self.hybrid_gen_single_op_models,
                    self.hybrid_config,
                )
                quantizer.quantize()
                optimizer.cleanup_dead_nodes()

            if self.float16_quantization:
                quantizer = HalfQuantizer(self.common_graph)
                quantizer.quantize()
                optimizer.cleanup_dead_nodes()

            versioner = OPVersioner(self.common_graph)
            versioner.process()

            if self.missing_outputs_as_constants:
                tensors = []
                for output_name in self.common_graph.outputs:
                    if output_name not in self.common_graph.tensor_map:
                        tensors.append(
                            Tensor(
                                self.tensor_map[output_name],
                                output_name,
                                has_buffer=True,
                                asymmetric=not self.strict_symmetric_check,
                                q_type=self.q_type,
                            )
                        )
                self.common_graph.add_nodes(tensors, ExtendedOperator.CONSTANT_NODE)
                self.common_graph.add_outputs([t.name for t in tensors])

            self.common_graph.convert(self.tflite_path)

        log.info(f'Generated model saved to {self.tflite_path}')

    def visualize(self, hide_constants=True):
        """Visualize the TinyNeuralNetwork Graph

        Args:
            hide_constants (bool, optional): Hide the constant nodes in the graph. Defaults to True.
        """
        self.common_graph.visualize(hide_constants)

    def get_outputs(self):
        """Returns the output of the model, which is evaluated via tracing nodes one by one"""
        outputs = []
        for name in self.common_graph.outputs:
            outputs.append(self.tensor_map[name])
        return outputs

    def get_value(self, name, default_val=None):
        """Returns the output according to the name of the node. If the name doesn't exist, `default_val` is returned

        Values whose inferred TorchScript type ends with `PackedParamsBase` are unpacked before being returned.
        """
        if self.preserve_tensors:
            val = self.tensor_map_copies.get(name, default_val)
        else:
            val = self.tensor_map.get(name, default_val)

        type_ = self.__try_infer_type(val)
        if type_.endswith('PackedParamsBase'):
            return self.__unpack_params(val)

        return val

    def tensor_names(self) -> typing.List[str]:
        """Returns the all the names of the intermediate tensors

        Returns:
            typing.List[str]: The names of the intermediate tensors
        """
        if self.preserve_tensors:
            return list(self.tensor_map_copies.keys())
        else:
            return list(self.tensor_map.keys())

    def inputs_for_tflite(self) -> typing.List[np.ndarray]:
        """Prepare inputs for the TFLite backend

        One array is produced per flattened model input (matching the TFLite graph inputs); inputs whose
        transpose flag is set are converted from NCHW to NHWC.

        Returns:
            typing.List[np.ndarray]: The input tensors
        """
        # Allow calling this before `convert()`: resolve the flattened inputs and their transpose flags lazily
        if not self.flatten_inputs:
            self.init_flatten_inputs()
        if type(self.input_transpose) not in (tuple, list) or None in self.input_transpose:
            self.init_input_transpose()

        arrs = []
        # `input_transpose` is indexed per flattened input, so pair it with `flatten_inputs`, not `dummy_input`
        for t, trans in zip(self.flatten_inputs, self.input_transpose):
            # `.cpu()` so CUDA inputs work; `.clone()` so the array never aliases the caller's tensor
            arr = t.detach().cpu().clone().numpy()
            if trans:
                arr = np.transpose(arr, (0, 2, 3, 1))
            arrs.append(arr)
        return arrs