# tinynn/converter/base.py
import collections
import io
import os
import typing
import torch
import numpy as np
from .operators import CommonGraph, ExtendedOperator, GraphOptimizer, HybridQuantizer, HalfQuantizer
from .operators.op_version import OPVersioner
from .operators.tflite import Tensor
from .operators.torch import OPERATOR_CONVERTER_DICT
from .operators.torch.base import NoTrackOperator, TrackRevQParamsOperator, TrackQParamsOperator
from .operators.torch.aten import ATenDequantizeOperator, ATenQuantizePerTensorOperator
from ..util.converter_util import generate_converter_config
from ..util.util import get_logger

log = get_logger(__name__, 'INFO')


class TFLiteConverter(object):
def __init__(
self,
model: typing.Union[torch.jit.ScriptFunction, torch.jit.ScriptModule, torch.nn.Module],
dummy_input: typing.Union[torch.Tensor, typing.Iterable[torch.Tensor]],
tflite_path: str,
input_transpose: typing.Optional[typing.Union[bool, typing.Iterable[bool]]] = None,
output_transpose: typing.Optional[typing.Union[bool, typing.Iterable[bool]]] = None,
nchw_transpose: bool = True,
dump_jit_model_path: typing.Optional[str] = None,
dump_dummy_input_path: typing.Optional[str] = None,
dump_config_path: typing.Optional[str] = None,
strict_symmetric_check: bool = False,
preserve_tensors: bool = False,
optimize: int = GraphOptimizer.ALL_OPTIMIZE,
quantize_target_type: str = 'uint8',
quantize_input_output_type: typing.Optional[str] = None,
hybrid_quantization_from_float: bool = False,
hybrid_per_channel: bool = False,
hybrid_asymmetric_inputs: bool = True,
hybrid_quantize_weight_type: typing.Optional[str] = None,
fuse_quant_dequant: bool = False,
fuse_input_indices: typing.Optional[typing.List[int]] = None,
fuse_output_indices: typing.Optional[typing.List[int]] = None,
gc_when_reload: bool = False,
group_conv_rewrite: bool = False,
rewrite_quantizable: bool = False,
tflite_micro_rewrite: bool = False,
map_bilstm_to_lstm: bool = False,
float16_quantization: bool = False,
enable_mtk_ops: bool = False,
conv_transpose_with_bias: bool = True,
max_transpose_dims: int = -1,
hybrid_conv: bool = True,
hybrid_int16_lstm: bool = False,
unroll_rnn: bool = False,
separated_rnn_gate_calc: bool = False,
bypass_elementwise_passthrough_constraint: bool = False,
hybrid_gen_single_op_models: bool = False,
hybrid_config: typing.Optional[typing.Dict[str, bool]] = None,
group_tensors: bool = False,
missing_outputs_as_constants: bool = False,
legacy_gelu: bool = False,
) -> None:
""" The TFLiteConverter class
Args:
model (typing.Union[torch.jit.ScriptFunction, torch.jit.ScriptModule, torch.nn.Module]): The input model \
(either traced or non-traced)
dummy_input (typing.Union[torch.Tensor, typing.Iterable[torch.Tensor]]): A viable input to the model
tflite_path (str): Path to use for exporting
            input_transpose (typing.Optional[typing.Union[bool, typing.Iterable[bool]]], optional): Whether to \
                transpose the input(s). Defaults to None (True for 4-d inputs, False otherwise).
            output_transpose (typing.Optional[typing.Union[bool, typing.Iterable[bool]]], optional): Whether to \
                transpose the output(s). Defaults to None (True for 4-d outputs, False otherwise).
            nchw_transpose (bool): Whether to perform nchw->nhwc transposes on input and output tensors. When \
                `False` is specified, the arguments `input_transpose` and `output_transpose` are ignored.
dump_jit_model_path (typing.Optional[str]): The path for dumping the jit model. Defaults to None
dump_dummy_input_path (typing.Optional[str]): The path for dumping the dummy input. Defaults to None
dump_config_path (typing.Optional[str]): The path for dumping the json config. Defaults to None
strict_symmetric_check (bool): Strict symmetric quantization checks. Defaults to False
preserve_tensors (bool): Preserve the copies of the intermediate tensors. Defaults to False
optimize (int): The level of graph optimization. Defaults to `GraphOptimizer.ALL_OPTIMIZE`
quantize_target_type (str): Target type for quantization. Defaults to 'uint8'
            quantize_input_output_type (typing.Optional[str]): Input and output type for quantization. \
                Defaults to None (inferred)
hybrid_quantization_from_float (bool): Direct hybrid quantization from a float model. Defaults to False
hybrid_per_channel (bool): Prefer per-channel kernels in hybrid quantization. Defaults to False
hybrid_asymmetric_inputs (bool): Prefer asymmetric inputs while performing hybrid quantization
hybrid_quantize_weight_type (typing.Optional[str]): Quantized weight type for hybrid quantization. \
If it is unset, then the value of `quantize_target_type` will be used. Defaults to None
fuse_quant_dequant (bool): Remove quant and dequant nodes directly connected to i/o nodes. Defaults to False
fuse_input_indices (typing.Optional[typing.List[int]]): Used together with `fuse_quant_dequant`. Indices \
of input nodes to fuse with `Quantize`. Defaults to None (which fuses all inputs available)
fuse_output_indices (typing.Optional[typing.List[int]]): Used together with `fuse_quant_dequant`. Indices \
of output nodes to fuse with `Dequantize`. Defaults to None (which fuses all outputs available)
gc_when_reload (bool): Apply GC when reloading the torchscript into memory
group_conv_rewrite (bool): Rewriting for group [de]convolution. Defaults to False
rewrite_quantizable (bool): Rewriting quantizable ops (e.g. BATCH_MATMUL, SOFTMAX, LOG_SOFTMAX) \
to use quantized kernels. Defaults to False
tflite_micro_rewrite (bool): Rewriting for running on TFLite-micro. Defaults to False
            map_bilstm_to_lstm (bool): Translate bidirectional LSTMs to TFLite ops using `UnidirectionalLSTM`. \
                Defaults to False
            float16_quantization (bool): Quantize constants with float32 dtype to float16 dtype. Defaults to False
enable_mtk_ops (bool): Translating with custom MTK operators. Defaults to False
conv_transpose_with_bias (bool): ConvTranspose ops with bias. Defaults to True
max_transpose_dims (int): Max dimensions for the `Transpose` op. Defaults to -1, which means unlimited
hybrid_conv (bool): Enable hybrid quantization for Conv2d and DepthwiseConv2d. Defaults to True
hybrid_int16_lstm (bool): Enable hybrid int16 quantization for LSTM. Defaults to False
            unroll_rnn (bool): Unrolling LSTM (translate LSTM to separate ops). Defaults to False
separated_rnn_gate_calc (bool): Separated calculation for every gate in RNN. Effective only when \
`unroll_rnn=True`. Defaults to False
bypass_elementwise_passthrough_constraint (bool): Bypass constraints in elementwise passthrough passes. \
Defaults to False
            hybrid_gen_single_op_models (bool): Generate both floating-point and quantized versions of the model \
                for hybrid quantizable ops. Defaults to False
            hybrid_config (typing.Optional[typing.Dict[str, bool]]): Fine-grained control over which kinds of ops \
                are hybrid-quantized. Defaults to None
group_tensors (bool): Group tensors to save space. Defaults to False
            missing_outputs_as_constants (bool): Treat missing outputs as constants. Defaults to False
legacy_gelu (bool): Fallback to the legacy behaviour for translating gelu. Defaults to False
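
        Example:
            A minimal usage sketch; the torchvision model, input shape and output path below are \
                illustrative and not part of this module::

                import torch
                import torchvision

                # Any float model in eval mode with a viable dummy input works here
                # (assumption: torchvision is available in the environment)
                model = torchvision.models.mobilenet_v2()
                model.eval()
                dummy_input = torch.rand(1, 3, 224, 224)

                converter = TFLiteConverter(model, dummy_input, 'mobilenet_v2.tflite')
                converter.convert()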
"""
self.model = model
self.lower_model = None
self.graph = None
self.tensor_map = {}
self.tensor_map_copies = {}
self.common_graph = CommonGraph()
if type(dummy_input) in (tuple, list):
self.dummy_input = dummy_input
else:
self.dummy_input = [dummy_input]
self.flatten_inputs = []
self.tflite_path = tflite_path
self.nchw_transpose = nchw_transpose
if self.nchw_transpose:
self.input_transpose = input_transpose
self.output_transpose = output_transpose
else:
self.input_transpose = False
self.output_transpose = False
self.strict_symmetric_check = strict_symmetric_check
self.dump_jit_model_path = dump_jit_model_path
self.dump_dummy_input_path = dump_dummy_input_path
self.dump_config_path = dump_config_path
self.preserve_tensors = preserve_tensors
self.optimize = optimize
self.hybrid = hybrid_quantization_from_float
self.hybrid_per_channel = hybrid_per_channel
self.hybrid_asymmetric_inputs = hybrid_asymmetric_inputs
self.fuse_quant_dequant = fuse_quant_dequant
self.fuse_input_indices = fuse_input_indices
self.fuse_output_indices = fuse_output_indices
self.gc_when_reload = gc_when_reload
self.group_conv_rewrite = group_conv_rewrite
self.rewrite_quantizable = rewrite_quantizable
self.tflite_micro_rewrite = tflite_micro_rewrite
self.map_bilstm_to_lstm = map_bilstm_to_lstm
self.float16_quantization = float16_quantization
self.enable_mtk_ops = enable_mtk_ops
self.conv_transpose_with_bias = conv_transpose_with_bias
self.max_transpose_dims = max_transpose_dims
self.hybrid_conv = hybrid_conv
self.hybrid_int16_lstm = hybrid_int16_lstm
self.unroll_rnn = unroll_rnn
self.separated_rnn_gate_calc = separated_rnn_gate_calc
self.bypass_elementwise_passthrough_constraint = bypass_elementwise_passthrough_constraint
self.hybrid_gen_single_op_models = hybrid_gen_single_op_models
self.hybrid_config = hybrid_config
self.group_tensors = group_tensors
self.missing_outputs_as_constants = missing_outputs_as_constants
self.legacy_gelu = legacy_gelu
if quantize_target_type == 'uint8':
self.q_type = np.uint8
if self.strict_symmetric_check:
log.warning('Symmetric quantized model with uint8 is unsupported in most backends of TFLite')
elif quantize_target_type == 'int8':
self.q_type = np.int8
elif quantize_target_type == 'int16':
if not self.strict_symmetric_check:
raise AttributeError('Int16 quantization requires strict_symmetric_check=True')
self.q_type = np.int16
else:
raise AttributeError(f'unknown quantize_target_type: {quantize_target_type}, expected: uint8, int8, int16')
if quantize_input_output_type is not None:
            assert fuse_quant_dequant, 'Please set fuse_quant_dequant=True, otherwise quantize_input_output_type is ignored'
assert quantize_input_output_type in (
'int8',
'uint8',
'int16',
), f'unknown quantize_input_output_type: {quantize_input_output_type}, expected: uint8, int8, int16'
if quantize_input_output_type == 'int16' and quantize_target_type != 'int16':
raise AttributeError(
'quantize_input_output_type == \'int16\' and quantize_target_type != \'int16\' is not supported'
)
self.quantize_input_output_type = quantize_input_output_type
if hybrid_quantize_weight_type is None:
hybrid_quantize_weight_type = quantize_target_type
if hybrid_quantize_weight_type == 'uint8':
if self.hybrid:
if self.hybrid_per_channel:
                    raise AttributeError('Per-channel kernels support int8 only')
log.warning(
'Unless you are using legacy TFLite (<1.14), please set quantize_target_type to int8 instead'
)
self.hybrid_q_type = np.uint8
elif hybrid_quantize_weight_type == 'int8':
self.hybrid_q_type = np.int8
elif hybrid_quantize_weight_type == 'int16':
self.hybrid_q_type = np.int16
if self.hybrid:
                raise AttributeError('Hybrid kernels support int8 and uint8 only')
if dump_config_path and not dump_jit_model_path:
raise AssertionError("when dump_config_path is set, dump_jit_model_path is required to be set")
self.input_offset = 1

    def init_jit_graph(self):
        # Multi-GPU wrapper modules (DataParallel/DistributedDataParallel) don't support JIT tracing
if isinstance(self.model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
self.model = self.model.module
if not isinstance(self.model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
if hasattr(self.model, 'cpu'):
self.model.cpu()
if hasattr(self.model, 'eval'):
self.model.eval()
with torch.no_grad():
script = torch.jit.trace(self.model, self.dummy_input)
# Remove reference to original model to save memory
self.model = None
# Have to save it once, otherwise something weird happens
if self.dump_jit_model_path is None:
with io.BytesIO() as f:
torch.jit.save(script, f)
f.seek(0)
script = torch.jit.load(f)
else:
jit_model_dir = os.path.abspath(os.path.dirname(self.dump_jit_model_path))
os.makedirs(jit_model_dir, exist_ok=True)
torch.jit.save(script, self.dump_jit_model_path)
if self.gc_when_reload:
import gc
script = None
gc.collect()
script = torch.jit.load(self.dump_jit_model_path)
self.model = script
if isinstance(self.model, torch.jit.ScriptFunction):
self.input_offset = 0
if self.dump_dummy_input_path is not None:
dummy_arrs = list(map(lambda x: x.detach().cpu().numpy(), self.dummy_input))
np.savez(self.dump_dummy_input_path, *dummy_arrs)
if self.dump_config_path is not None:
generate_converter_config(
self.dummy_input,
[],
self.input_transpose,
[],
self.dump_jit_model_path,
self.tflite_path,
self.dump_config_path,
)

    def init_lowered_module(self):
assert (
isinstance(self.model, torch.jit.ScriptFunction)
or self.model.training is False
or str(next(self.model.graph.inputs()).type()) == '__torch__.PlaceholderModule'
), (
'Model is in training mode. Please run `model.eval()` before model conversion. If you are passing in a'
' TorchScript model, make sure you use `torch.jit.save` to dump the model to disk and then load it using'
' `torch.jit.load`.'
)
graph = self.model.graph
# Inline everything
torch._C._jit_pass_inline(graph)
# Remove fork/wait nodes
torch._C._jit_pass_inline_fork_wait(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_lower_all_tuples(graph)
        # We now record some ops like ones/zeros into a trace where we previously
        # recorded constants. Use constant prop to maintain our current level of
        # ONNX support without implementing symbolics for all of them.
torch._C._jit_pass_constant_propagation(graph)
# _split_tensor_list_constants(graph, graph)
# run dce to eliminate dead parts of the graph that might have been
# left behind by things like symbolic_override
torch._C._jit_pass_dce(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_canonicalize_graph_fuser_ops(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_peephole(graph, True)
torch._C._jit_pass_fuse_addmm(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_peephole(graph, True)
torch._C._jit_pass_lower_all_tuples(graph)
self.graph = graph
log.debug('Lowered graph:')
log.debug(self.graph)

    def init_flatten_inputs(self):
self.flatten_inputs.clear()
for t in self.dummy_input:
if isinstance(t, (list, tuple)):
for rt in t:
self.flatten_inputs.append(rt)
else:
self.flatten_inputs.append(t)

    def init_input_transpose(self):
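        # When `input_transpose` is None for a given input, fall back to a heuristic:
        # transpose 4-d tensor inputs (assumed to be NCHW) and leave everything else untouched.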
input_transpose = self.input_transpose
if type(input_transpose) not in (tuple, list):
input_transpose = [input_transpose] * len(self.flatten_inputs)
for i, t in enumerate(self.flatten_inputs):
if input_transpose[i] is None:
if isinstance(t, torch.Tensor):
input_transpose[i] = t.dim() == 4
else:
input_transpose[i] = False
self.input_transpose = input_transpose

    def init_common_graph(self):
graph_inputs = [x.debugName() for x in list(self.graph.inputs())][self.input_offset :]
graph_outputs = [x.debugName() for x in list(self.graph.outputs())]
self.common_graph.inputs.extend(graph_inputs)
self.common_graph.outputs.extend(graph_outputs)
self.common_graph.input_transpose.extend(self.input_transpose)
self.common_graph.output_transpose = self.output_transpose
tensors = []
for i, node in enumerate(graph_inputs):
tensors.append(
Tensor(
self.flatten_inputs[i],
node,
has_buffer=False,
asymmetric=not self.strict_symmetric_check,
q_type=self.q_type,
)
)
self.common_graph.add_nodes(tensors, ExtendedOperator.INPUT_NODE)

    def init_inputs(self):
graph_inputs = [x.debugName() for x in list(self.graph.inputs())]
for i, node in enumerate(graph_inputs):
if self.input_offset > 0 and i == 0:
self.tensor_map[graph_inputs[i]] = self.model
else:
self.tensor_map[graph_inputs[i]] = self.flatten_inputs[i - self.input_offset]

    def unsupported_operations(self, unique=True) -> typing.List[str]:
"""Returns unsupported operations in the graph"""
if self.graph is None:
self.init_lowered_module()
all_nodes = list(self.graph.nodes())
ops = []
for node in all_nodes:
k = node.kind()
converter_type = OPERATOR_CONVERTER_DICT.get(k, None)
if converter_type is None:
ops.append(k)
if unique:
return list(set(ops))
else:
return ops

    def init_operations(self):
log.debug('Initialize operators...')
node_queue = collections.deque(self.graph.nodes())
scope_map = {}
current_scope = None
while node_queue:
node = node_queue.popleft()
k = node.kind()
output_tensors = []
converter_type = OPERATOR_CONVERTER_DICT.get(k, NoTrackOperator)
converter = converter_type(
node,
self.tensor_map,
current_scope,
not self.strict_symmetric_check,
self.q_type,
self.hybrid_q_type,
self.map_bilstm_to_lstm,
self.enable_mtk_ops,
self.hybrid_asymmetric_inputs,
self.unroll_rnn,
self.separated_rnn_gate_calc,
self.conv_transpose_with_bias,
self.legacy_gelu,
)
            # Don't track the operator if none of its input nodes are tracked (i.e. all of its
            # inputs are constant), unless it has a custom implementation (e.g. prim::* ops)
if converter_type.run == NoTrackOperator.run and converter_type != NoTrackOperator:
no_track_flag = True
for n in converter.input_names:
if self.common_graph.has_nested_names(n):
nested_names = self.common_graph.get_list_expanded_names(n)
for x in nested_names:
if x in self.common_graph.tensor_map and self.common_graph.tensor_map[x].buffer is None:
no_track_flag = False
break
elif n in self.common_graph.tensor_map and self.common_graph.tensor_map[n].buffer is None:
no_track_flag = False
break
if no_track_flag:
if converter_type == ATenDequantizeOperator:
converter_type = TrackQParamsOperator
elif converter_type == ATenQuantizePerTensorOperator:
converter_type = TrackRevQParamsOperator
else:
converter_type = NoTrackOperator
converter = converter_type(
node,
self.tensor_map,
current_scope,
not self.strict_symmetric_check,
self.q_type,
self.hybrid_q_type,
self.map_bilstm_to_lstm,
self.enable_mtk_ops,
self.hybrid_asymmetric_inputs,
self.unroll_rnn,
self.separated_rnn_gate_calc,
self.conv_transpose_with_bias,
self.legacy_gelu,
)
if k != 'prim::Constant':
log.debug(f'{k} {converter.input_names} -> {converter.output_names} {converter_type.__name__}')
# Don't fetch attrs and schemas for non-tracking nodes
if converter_type not in (NoTrackOperator, TrackRevQParamsOperator, TrackQParamsOperator):
try:
attrs = converter.fetch_all_attrs(node)
except StopIteration:
attrs = None
args = converter.fetch_annotated_args(node)
else:
attrs = None
args = None
converter.parse(node, attrs, args, self.common_graph)
outputs = converter.output_names
new_nodes = converter.output_nodes
if output_tensors is not None:
output_tensors.extend(converter.get_output_tensors())
if len(new_nodes) > 0:
node_queue.extendleft(reversed(new_nodes))
if k == 'prim::PythonOp':
s = node.scopeName()
scope_map.setdefault(s, 0)
scope_map[s] += 1
current_scope = f'{s}_{scope_map[s]}'
converter.prepare_scope_tensors(node, attrs, args, self.common_graph, current_scope)
elif k == 'prim::Return':
current_scope = None
assert len(output_tensors) == len(outputs)
for t, name in zip(output_tensors, outputs):
self.tensor_map[name] = t
if self.preserve_tensors and isinstance(t, torch.Tensor):
self.tensor_map_copies[name] = t.detach().clone()

    def __try_infer_type(self, params):
        # Returns the annotated type name of `params` when PyTorch is able to infer it,
        # falling back to the string representation of the inferred type otherwise
        inferred = torch._C._jit_try_infer_type(params)
        if hasattr(inferred, 'type'):
            return inferred.type().annotation_str
        return str(inferred)

    def __unpack_params(self, params):
return NoTrackOperator.unpack_params(None, params)

    def convert(self):
"""Converts the model to the TFLite format
Raises:
Exception: If unsupported ops are found, an Exception will be raised
"""
self.init_flatten_inputs()
self.init_input_transpose()
self.init_jit_graph()
self.init_lowered_module()
self.init_common_graph()
self.init_inputs()
self.init_operations()
unsupported_ops = self.unsupported_operations()
if len(unsupported_ops) > 0:
log.error(f'Unsupported ops: {", ".join(unsupported_ops)}')
raise Exception("Cannot continue due to fatal error")
else:
optimizer = GraphOptimizer(
self.common_graph,
self.optimize,
self.fuse_quant_dequant,
self.group_conv_rewrite,
self.rewrite_quantizable,
self.tflite_micro_rewrite,
self.quantize_input_output_type,
self.fuse_input_indices,
self.fuse_output_indices,
self.max_transpose_dims,
self.bypass_elementwise_passthrough_constraint,
self.group_tensors,
self.conv_transpose_with_bias,
self.hybrid_int16_lstm,
)
optimizer.optimize()
self.output_transpose = self.common_graph.output_transpose
if self.hybrid:
quantizer = HybridQuantizer(
self.common_graph,
self.hybrid_asymmetric_inputs,
self.hybrid_q_type,
self.hybrid_per_channel,
self.hybrid_conv,
self.hybrid_int16_lstm,
self.hybrid_gen_single_op_models,
self.hybrid_config,
)
quantizer.quantize()
optimizer.cleanup_dead_nodes()
if self.float16_quantization:
quantizer = HalfQuantizer(self.common_graph)
quantizer.quantize()
optimizer.cleanup_dead_nodes()
versioner = OPVersioner(self.common_graph)
versioner.process()
if self.missing_outputs_as_constants:
tensors = []
for output_name in self.common_graph.outputs:
if output_name not in self.common_graph.tensor_map:
tensors.append(
Tensor(
self.tensor_map[output_name],
output_name,
has_buffer=True,
asymmetric=not self.strict_symmetric_check,
q_type=self.q_type,
)
)
self.common_graph.add_nodes(tensors, ExtendedOperator.CONSTANT_NODE)
self.common_graph.add_outputs([t.name for t in tensors])
self.common_graph.convert(self.tflite_path)
log.info(f'Generated model saved to {self.tflite_path}')

    def visualize(self, hide_constants=True):
"""Visualize the TinyNeuralNetwork Graph
Args:
hide_constants (bool, optional): Hide the constant nodes in the graph. Defaults to True.
"""
self.common_graph.visualize(hide_constants)

    def get_outputs(self):
"""Returns the output of the model, which is evaluated via tracing nodes one by one"""
outputs = []
for name in self.common_graph.outputs:
outputs.append(self.tensor_map[name])
return outputs

    def get_value(self, name, default_val=None):
"""Returns the output according to the name of the node. If the name doesn't exist, `default_val` is returned"""
if self.preserve_tensors:
val = self.tensor_map_copies.get(name, default_val)
else:
val = self.tensor_map.get(name, default_val)
type_ = self.__try_infer_type(val)
if type_.endswith('PackedParamsBase'):
return self.__unpack_params(val)
return val

    def tensor_names(self) -> typing.List[str]:
        """Returns all the names of the intermediate tensors
Returns:
typing.List[str]: The names of the intermediate tensors
"""
if self.preserve_tensors:
return list(self.tensor_map_copies.keys())
else:
return list(self.tensor_map.keys())

    def inputs_for_tflite(self) -> typing.List[np.ndarray]:
"""Prepare inputs for the TFLite backend
Returns:
typing.List[np.ndarray]: The input tensors
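
        Example:
            A sketch of feeding the prepared arrays to the TFLite interpreter; it assumes TensorFlow \
                is installed, `converter` is this instance, and `convert()` has already produced a \
                floating-point model::

                import tensorflow as tf

                interpreter = tf.lite.Interpreter(model_path=converter.tflite_path)
                interpreter.allocate_tensors()
                input_details = interpreter.get_input_details()
                for detail, arr in zip(input_details, converter.inputs_for_tflite()):
                    interpreter.set_tensor(detail['index'], arr)
                interpreter.invoke()
                outputs = [interpreter.get_tensor(d['index']) for d in interpreter.get_output_details()]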
"""
arrs = []
for t, trans in zip(self.dummy_input, self.input_transpose):
arr = t.detach().clone().numpy()
if trans:
arr = np.transpose(arr, (0, 2, 3, 1))
arrs.append(arr)
return arrs