# tinynn/converter/operators/tflite/transformable.py
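"""Transformable TFLite operators.

Each class in this module is a virtual op (batch norm, generic conv, generic
transposed conv) that has no single TFLite builtin counterpart. Its `transform`
method lowers it into a sequence of real TFLite ops (TRANSPOSE, RESHAPE, PAD,
CONV_2D/CONV_3D/TRANSPOSE_CONV, MUL/ADD, ...) inside the graph converter.
"""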

from abc import abstractmethod

from .base import BaseOperator, QuantizationParameters, Tensor
from .custom import MTKTransposeConvOperator
from . import generated_ops as tfl_ops
from ..base import ExtendedOperator
from ...schemas.tflite import schema_generated as tflite

import typing

import torch
import warnings
import numpy as np


class TransformableOperator(BaseOperator):
    def __init__(self, op: int, inputs: typing.List['Tensor'], outputs: typing.List['Tensor'], op_version: int):
        super().__init__(op, inputs, outputs, op_version=op_version)
        self.attr_count = 0
        self.transform_count = 0

    @abstractmethod
    def transform(self):
        pass

    def create_attr_tensor(self, tensor, name=None, quantization=None):
        if name is None:
            if self.attr_count == 0:
                name = self.outputs[0].name + '_te_attr'
            else:
                name = self.outputs[0].name + f'_te_attr_{self.attr_count}'
        self.attr_count += 1
        return Tensor(tensor, name, has_buffer=True, quantization=quantization)

    def create_transform_tensor(self, tensor, name=None, quantization=None):
        if name is None:
            if self.transform_count == 0:
                name = self.outputs[0].name + '_te_transform'
            else:
                name = self.outputs[0].name + f'_te_transform_{self.transform_count}'
        self.transform_count += 1
        return Tensor(tensor, name, has_buffer=False, quantization=quantization)

    def wrap_ops_with_nhwc_nchw_transposes(
        self, ops: typing.List[tfl_ops.BaseOperator], input_idx: int = 0, output_idx: int = 0
    ) -> typing.List[tfl_ops.BaseOperator]:
        orig_input = ops[0].inputs[input_idx]
        orig_output = ops[-1].outputs[output_idx]

        if orig_input.tensor.ndim == 4:
            nhwc2nchw_perm = np.array([0, 3, 1, 2], dtype='int32')
            nchw2nhwc_perm = np.array([0, 2, 3, 1], dtype='int32')
        elif orig_input.tensor.ndim == 5:
            nhwc2nchw_perm = np.array([0, 4, 1, 2, 3], dtype='int32')
            nchw2nhwc_perm = np.array([0, 2, 3, 4, 1], dtype='int32')
        else:
            assert False, f'Don\'t know how to wrap transposes for {orig_input.tensor.ndim}d tensors'

        nhwc2nchw_perm_tensor = self.create_attr_tensor(nhwc2nchw_perm)
        nchw2nhwc_perm_tensor = self.create_attr_tensor(nchw2nhwc_perm)

        new_input = self.create_transform_tensor(
            np.transpose(orig_input.tensor, nchw2nhwc_perm), quantization=orig_input.quantization
        )
        new_output = self.create_transform_tensor(
            np.transpose(orig_output.tensor, nchw2nhwc_perm), quantization=orig_output.quantization
        )

        nchw2nhwc_transpose = tfl_ops.TransposeOperator([orig_input, nchw2nhwc_perm_tensor], [new_input])
        nhwc2nchw_transpose = tfl_ops.TransposeOperator([new_output, nhwc2nchw_perm_tensor], [orig_output])

        nchw2nhwc_transpose.extra_hints['direction'] = 'up'
        nhwc2nchw_transpose.extra_hints['direction'] = 'down'

        ops[0].inputs[input_idx] = new_input
        ops[-1].outputs[output_idx] = new_output

        return [nchw2nhwc_transpose] + ops + [nhwc2nchw_transpose]


class BatchNormOperator(TransformableOperator):
    input_index = 0
    weight_index = 1
    bias_index = 2
    running_mean_index = 3
    running_variance_index = 4
    output_index = 0

    def __init__(
        self,
        inputs: typing.List['Tensor'],
        outputs: typing.List['Tensor'],
        eps: float,
        quantization: typing.Optional[QuantizationParameters] = None,
        fusedActivationFunction=tflite.ActivationFunctionType.NONE,
    ):
        super().__init__(ExtendedOperator.BATCH_NORM, inputs, outputs, 1)
        self.eps = eps
        self.fusedActivationFunction = fusedActivationFunction

    def transform(self, graph_converter, mapping):
        assert all((x.buffer is not None for x in self.inputs[1:]))
        w, b, mean, var = [
            self.inputs[i]
            for i in (self.weight_index, self.bias_index, self.running_mean_index, self.running_variance_index)
        ]

        inv = 1 / np.sqrt(var.tensor + self.eps)
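        # Fold the frozen batch-norm statistics into a single affine transform:
        #   y = (x - mean) / sqrt(var + eps) * w + b  ==>  y = x * new_w + new_b
        # with new_w = w / sqrt(var + eps) and new_b = b - mean * new_w.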
        new_w = inv * w.tensor
        new_b = b.tensor - mean.tensor * new_w

        inp = self.inputs[0]

        new_shape = [1] + [new_w.shape[0]] + [1] * (inp.tensor.ndim - 2)
        new_w = new_w.reshape(new_shape)
        new_b = new_b.reshape(new_shape)

        weight = self.create_attr_tensor(new_w)
        bias = self.create_attr_tensor(new_b)

        new_inp = inp
        if inp.quantization is not None:
            new_inp = self.create_transform_tensor(inp.tensor)
            graph_converter.add_operator(tfl_ops.DequantizeOperator([inp], [new_inp]))

        mul_out = self.create_transform_tensor(new_inp.tensor * weight.tensor)
        graph_converter.add_operator(tfl_ops.MulOperator([new_inp, weight], [mul_out]))

        if inp.quantization is not None:
            add_out = self.create_transform_tensor(mul_out.tensor + bias.tensor)
        else:
            add_out = self.outputs[self.output_index]

        graph_converter.add_operator(
            tfl_ops.AddOperator([mul_out, bias], [add_out], fusedActivationFunction=self.fusedActivationFunction),
            transform=True,
        )

        if inp.quantization is not None:
            quant_out = self.outputs[self.output_index]
            graph_converter.add_operator(tfl_ops.QuantizeOperator([add_out], [quant_out]), transform=True)

        graph_converter.try_restore_edges(mapping)


class GenericConvOperator(TransformableOperator):
    input_index = 0
    weight_index = 1
    bias_index = 2
    output_index = 0

    stride: typing.List[int]
    padding: typing.List[int]
    dilation: typing.List[int]
    transpose: bool
    output_padding: typing.List[int]
    groups: int
    fusedActivationFunction: tflite.ActivationFunctionType

    def __init__(
        self,
        inputs: typing.List['Tensor'],
        outputs: typing.List['Tensor'],
        stride: typing.List[int],
        padding: typing.List[int],
        dialation: typing.List[int],
        output_padding: typing.List[int],
        groups: int,
        fusedActivationFunction=tflite.ActivationFunctionType.NONE,
    ):
        super().__init__(ExtendedOperator.GENERIC_CONV, inputs, outputs, 1)
        self.stride = stride
        self.padding = padding
        self.dilation = dialation
        self.output_padding = output_padding
        self.groups = groups
        self.fusedActivationFunction = fusedActivationFunction

    def transform(self, graph_converter, mapping):
        input_tensor = self.inputs[0]
        weight_tensor = self.inputs[1]
        input_dim = len(input_tensor.shape)
        weight_dim = len(weight_tensor.shape)

        prev_ops = []
        next_ops = []

        if weight_dim == 3 or input_dim == 3:
            reshape_input_size = 1
            reshape_output_size = 1
            if weight_dim == 3:
                self.stride.insert(0, 1)
                self.padding.insert(0, 0)
                self.dilation.insert(0, 1)
                self.output_padding.insert(0, 0)
                reshape_input_size = 2

            reshape_outputs = [
                self.create_transform_tensor(
                    np.expand_dims(t.tensor, 2),
                    name=f'{self.outputs[0].name}_{t.name}_4d_input',
                    quantization=t.quantization,
                )
                for t in self.inputs[:reshape_input_size]
            ]
            reshape_attrs = [self.create_attr_tensor(np.array(t.shape, dtype='int32')) for t in reshape_outputs]
            reshape_ops = [
                tfl_ops.ReshapeOperator([old, attr], [new], attr.tensor)
                for old, new, attr in zip(self.inputs[:reshape_input_size], reshape_outputs, reshape_attrs)
            ]
            for op in reshape_ops:
                op.extra_hints['direction'] = 'up'
            prev_ops.extend(reshape_ops)

            conv_outputs = [
                self.create_transform_tensor(
                    np.expand_dims(self.outputs[i].tensor, 2),
                    name=f'{self.outputs[i].name}_4d_output',
                    quantization=self.outputs[i].quantization,
                )
                for i in range(reshape_output_size)
            ]
            conv_attrs = [
                self.create_attr_tensor(np.array(t.shape, dtype='int32')) for t in self.outputs[:reshape_output_size]
            ]
            conv_ops = [
                tfl_ops.ReshapeOperator([old, attr], [new], attr.tensor)
                for old, new, attr in zip(conv_outputs, self.outputs[:reshape_output_size], conv_attrs)
            ]
            for op in conv_ops:
                op.extra_hints['direction'] = 'down'
            next_ops.extend(conv_ops)
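
            # Conv1d is emulated with Conv2d: a dummy spatial axis of size 1 was inserted
            # above, so swap in the reshaped 4D inputs and outputs for the rest of the lowering.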
            self.inputs = reshape_outputs + self.inputs[reshape_input_size:]
            self.outputs = conv_outputs + self.outputs[reshape_output_size:]

            weight_tensor = self.inputs[1]
        elif weight_dim not in (4, 5):
            assert False, "Only Conv[Transpose]1d/2d/3d is supported"

        if weight_tensor.shape[1] == 1 and weight_tensor.shape[0] == self.groups:
            if weight_dim in (3, 4):
                conv_op = tfl_ops.DepthwiseConv2dOperator(
                    self.inputs,
                    self.outputs,
                    strideH=self.stride[0],
                    strideW=self.stride[1],
                    depthMultiplier=1,
                    dilationHFactor=self.dilation[0],
                    dilationWFactor=self.dilation[1],
                    fusedActivationFunction=self.fusedActivationFunction,
                    padding=tflite.Padding.VALID,
                )
            else:
                assert False, "Only DepthwiseConv1d/2d is supported"
        else:
            if input_tensor.shape[1] != weight_tensor.shape[1]:
                warnings.warn(
                    'Group conv is not supported if the official TFLite interpreter is used. If that is the case for'
                    ' you, please pass in `group_conv_rewrite=True`. If you want to run the model with TFLite Micro,'
                    ' then you may also need to pass in `tflite_micro_rewrite=True`'
                )
            if weight_dim in (3, 4):
                conv_op = tfl_ops.Conv2dOperator(
                    self.inputs,
                    self.outputs,
                    strideH=self.stride[0],
                    strideW=self.stride[1],
                    dilationHFactor=self.dilation[0],
                    dilationWFactor=self.dilation[1],
                    fusedActivationFunction=self.fusedActivationFunction,
                    padding=tflite.Padding.VALID,
                )
            else:
                conv_op = tfl_ops.Conv3dOperator(
                    self.inputs,
                    self.outputs,
                    strideD=self.stride[0],
                    strideH=self.stride[1],
                    strideW=self.stride[2],
                    dilationDFactor=self.dilation[0],
                    dilationHFactor=self.dilation[1],
                    dilationWFactor=self.dilation[2],
                    fusedActivationFunction=self.fusedActivationFunction,
                    padding=tflite.Padding.VALID,
                )

        ops = self.wrap_ops_with_nhwc_nchw_transposes([conv_op])
        conv_op = ops[1]

        # Pad handling
        if sum(self.padding) > 0:
            if weight_dim in (3, 4):
                pad_h = self.padding[0]
                pad_w = self.padding[1]
                pad = [[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]]
            else:
                pad_d = self.padding[0]
                pad_h = self.padding[1]
                pad_w = self.padding[2]
                pad = [[0, 0], [pad_d, pad_d], [pad_h, pad_h], [pad_w, pad_w], [0, 0]]

            pad_tensor = self.create_attr_tensor(np.array(pad, dtype='int32'))

            pad_input = ops[0].outputs[0]
            pad_array = np.pad(pad_input.tensor, pad)
            pad_out = self.create_transform_tensor(pad_array, quantization=pad_input.quantization)
            ops[1].inputs[0] = pad_out

            pad_op = tfl_ops.PadOperator([pad_input, pad_tensor], [pad_out])
            ops.insert(1, pad_op)

        # Weight handling
        weight = conv_op.inputs[1]
        if conv_op.op.code == tflite.BuiltinOperator.DEPTHWISE_CONV_2D:
            nchw2chwn_perm = np.array([1, 2, 3, 0], dtype='int32')
            nchw2chwn_perm_tensor = self.create_attr_tensor(nchw2chwn_perm)

            weight_q = weight.quantization
            if weight_q is not None and weight_q.dim is not None:
                new_dim = np.nonzero(nchw2chwn_perm == weight_q.dim)[0][0]
                weight_q = QuantizationParameters(weight_q.scale, weight_q.zero_point, new_dim)

            reordered_weight = self.create_transform_tensor(
                np.transpose(weight.tensor, nchw2chwn_perm), quantization=weight_q
            )
            conv_op.inputs[1] = reordered_weight

            reorder_op = tfl_ops.TransposeOperator([weight, nchw2chwn_perm_tensor], [reordered_weight])
        else:
            if weight_dim in (3, 4):
                nchw2nhwc_perm = np.array([0, 2, 3, 1], dtype='int32')
                nchw2nhwc_perm_tensor = self.create_attr_tensor(nchw2nhwc_perm)
            else:
                nchw2nhwc_perm = np.array([2, 3, 4, 1, 0], dtype='int32')
                nchw2nhwc_perm_tensor = self.create_attr_tensor(nchw2nhwc_perm)

            weight_q = weight.quantization
            if weight_q is not None and weight_q.dim is not None:
                new_dim = np.nonzero(nchw2nhwc_perm == weight_q.dim)[0][0]
                weight_q = QuantizationParameters(weight_q.scale, weight_q.zero_point, new_dim)
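
            # Conv2d/Conv3d: transpose the weights from PyTorch's OIHW / OIDHW layout to
            # TFLite's OHWI / DHWIO layout; the per-channel quantization axis (if any)
            # has been remapped above to follow the same permutation.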
            reordered_weight = self.create_transform_tensor(
                np.transpose(weight.tensor, nchw2nhwc_perm), quantization=weight_q
            )
            conv_op.inputs[1] = reordered_weight

            reorder_op = tfl_ops.TransposeOperator([weight, nchw2nhwc_perm_tensor], [reordered_weight])

        ops.insert(1, reorder_op)

        # Bias handling
        kernel_num = self.inputs[1].shape[0]
        if conv_op.op.code in (tflite.BuiltinOperator.DEPTHWISE_CONV_2D, tflite.BuiltinOperator.CONV_3D):
            kernel_num = self.inputs[1].shape[-1]

        if len(conv_op.inputs) == 2 or conv_op.inputs[2] is None:
            if conv_op.inputs[0].dtype == np.dtype('float32'):
                bias = np.zeros((kernel_num,), dtype='float32')
                q_args = None
            else:
                bias = np.zeros((kernel_num,), dtype='int32')
                per_tensor = weight_tensor.quantization.dim is None
                # The bias scale must be input_scale * weight_scale (per tensor or per channel)
                if per_tensor:
                    bias_scale = input_tensor.quantization.scale * weight_tensor.quantization.scale
                    bias_zero_point = 0
                    bias_dim = None
                else:
                    bias_scale = [input_tensor.quantization.scale * s for s in weight_tensor.quantization.scale]
                    bias_zero_point = [0] * len(bias_scale)
                    bias_dim = 0
                q_args = QuantizationParameters(bias_scale, bias_zero_point, bias_dim)
            conv_op.inputs.append(self.create_attr_tensor(bias, quantization=q_args))
        elif conv_op.inputs[2].shape[0] != kernel_num and conv_op.inputs[2].shape[0] == 1:
            if conv_op.inputs[0].dtype == np.float32:
                bias = torch.tensor([conv_op.inputs[2][0]] * kernel_num, dtype=torch.float32)
            else:
                bias = torch.tensor([conv_op.inputs[2][0]] * kernel_num, dtype=torch.int32)
            conv_op.inputs[2] = self.create_attr_tensor(bias)

        ops = prev_ops + ops + next_ops
        for op in ops:
            graph_converter.add_operator(op, transform=True)

        graph_converter.try_restore_edges(mapping)

        for op in ops[:-1]:
            output_name = op.outputs[0].name
            node_name = graph_converter.tensor_node_map[output_name]
            node = graph_converter.graph.vs.find(name=node_name)
            assert node.outdegree() > 0, (
                'The following node should be a part of the transformable node, but its outdegree'
                f' is zero. {node}'
            )
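            # The first consumer of this intermediate tensor must be a real op, never a constant node.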
            next_node = graph_converter.graph.vs[node.out_edges()[0].target]
            assert next_node['node_type'] != ExtendedOperator.CONSTANT_NODE


class GenericTransposeConvOperator(TransformableOperator):
    input_index = 0
    weight_index = 1
    bias_index = 2
    output_index = 0

    stride: typing.List[int]
    padding: typing.List[int]
    dilation: typing.List[int]
    transpose: bool
    output_padding: typing.List[int]
    groups: int
    enable_mtk_ops: bool
    conv_transpose_with_bias: bool
    fusedActivationFunction: tflite.ActivationFunctionType

    def __init__(
        self,
        inputs: typing.List['Tensor'],
        outputs: typing.List['Tensor'],
        stride: typing.List[int],
        padding: typing.List[int],
        dilation: typing.List[int],
        output_padding: typing.List[int],
        groups: int,
        enable_mtk_ops: bool = False,
        conv_transpose_with_bias: bool = True,
        fusedActivationFunction=tflite.ActivationFunctionType.NONE,
    ):
        super().__init__(ExtendedOperator.GENERIC_DECONV, inputs, outputs, 1)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.output_padding = output_padding
        self.groups = groups
        self.enable_mtk_ops = enable_mtk_ops
        self.conv_transpose_with_bias = conv_transpose_with_bias
        self.fusedActivationFunction = fusedActivationFunction

    def transform(self, graph_converter, mapping):
        input_tensor = self.inputs[0]
        weight_tensor = self.inputs[1]
        output_tensor = self.outputs[0]
        input_dim = len(input_tensor.shape)
        weight_dim = len(weight_tensor.shape)

        prev_ops = []
        next_ops = []

        if weight_dim == 3 or input_dim == 3:
            self.stride.insert(0, 1)
            self.padding.insert(0, 0)
            self.dilation.insert(0, 1)
            self.output_padding.insert(0, 0)

            reshape_outputs = [
                self.create_transform_tensor(
                    np.expand_dims(t.tensor, 2),
                    name=f'{self.outputs[0].name}_{t.name}_4d_input',
                    quantization=t.quantization,
                )
                for t in self.inputs[:2]
            ]
            reshape_attrs = [self.create_attr_tensor(np.array(t.shape, dtype='int32')) for t in reshape_outputs]
            reshape_ops = [
                tfl_ops.ReshapeOperator([old, attr], [new], attr.tensor)
                for old, new, attr in zip(self.inputs[:2], reshape_outputs, reshape_attrs)
            ]
            for op in reshape_ops:
                op.extra_hints['direction'] = 'up'
            if weight_dim == 3 and input_dim == 3:
                prev_ops.extend(reshape_ops)
            elif weight_dim == 3:
                prev_ops.append(reshape_ops[1])
            else:
                prev_ops.append(reshape_ops[0])

            conv_outputs = [
                self.create_transform_tensor(
                    np.expand_dims(self.outputs[0].tensor, 2),
                    name=f'{self.outputs[0].name}_4d_output',
                    quantization=self.outputs[0].quantization,
                )
            ]
            conv_attrs = [self.create_attr_tensor(np.array(t.shape, dtype='int32')) for t in self.outputs[:1]]
            conv_ops = [
                tfl_ops.ReshapeOperator([old, attr], [new], attr.tensor)
                for old, new, attr in zip(conv_outputs, self.outputs[:1], conv_attrs)
            ]
            for op in conv_ops:
                op.extra_hints['direction'] = 'down'
            next_ops.extend(conv_ops)

            if weight_dim == 3 and input_dim == 3:
                self.inputs = reshape_outputs + self.inputs[2:]
            elif weight_dim == 3:
                self.inputs = self.inputs[0:1] + reshape_outputs[1:2] + self.inputs[1:]
            else:
                self.inputs = reshape_outputs[0:1] + self.inputs[1:]
            self.outputs = conv_outputs + self.outputs[1:]

            weight_tensor = self.inputs[1]
        elif weight_dim not in (4, 5):
            assert False, "Only Conv[Transpose]1d/2d/3d is supported"

        if output_tensor.shape[1] != weight_tensor.shape[1]:
            warnings.warn(
                'Group transposed conv is not supported if the official TFLite interpreter is used. If that is the'
                ' case for you, please pass in `group_conv_rewrite=True`. If you want to run the model with TFLite'
                ' Micro, then you may also need to pass in `tflite_micro_rewrite=True`'
            )
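
        # Build the core deconvolution op from the (weight, input) pair only. The bias, if any,
        # is attached later: either appended as an extra input or emitted as a trailing ADD.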
        if weight_dim in (3, 4):
            assert all((x == 1 for x in self.dilation)), "Only dilation=1 is supported for conv_transpose2d"
            if self.enable_mtk_ops:
                conv_op = MTKTransposeConvOperator(
                    self.inputs[:2][::-1],
                    self.outputs,
                    depth_multiplier=1,
                    dilation_height_factor=self.dilation[0],
                    dilation_width_factor=self.dilation[1],
                    padding_type=tflite.Padding.VALID,
                    stride_height=self.stride[0],
                    stride_width=self.stride[1],
                )
            else:
                conv_op = tfl_ops.TransposeConvOperator(
                    self.inputs[:2][::-1],
                    self.outputs,
                    strideH=self.stride[0],
                    strideW=self.stride[1],
                    padding=tflite.Padding.VALID,
                    fusedActivationFunction=self.fusedActivationFunction,
                )
        else:
            conv_op = tfl_ops.Conv3dTransposeOperator(
                self.inputs[:2][::-1],
                self.outputs,
                strideD=self.stride[0],
                strideH=self.stride[1],
                strideW=self.stride[2],
                dilationDFactor=self.dilation[0],
                dilationHFactor=self.dilation[1],
                dilationWFactor=self.dilation[2],
                padding=tflite.Padding.VALID,
                fusedActivationFunction=self.fusedActivationFunction,
            )

        ops = self.wrap_ops_with_nhwc_nchw_transposes([conv_op], input_idx=1)

        # Pad handling
        output_shape = conv_op.outputs[0].shape
        if sum(self.padding) > 0:
            if weight_dim in (3, 4):
                pad_h = self.padding[0]
                pad_w = self.padding[1]
                start = np.array([0, pad_h, pad_w, 0], dtype='int32')
                pad_sizes = ((0, 0), (pad_h, pad_h), (pad_w, pad_w), (0, 0))
            else:
                pad_d = self.padding[0]
                pad_h = self.padding[1]
                pad_w = self.padding[2]
                start = np.array([0, pad_d, pad_h, pad_w, 0], dtype='int32')
                pad_sizes = ((0, 0), (pad_d, pad_d), (pad_h, pad_h), (pad_w, pad_w), (0, 0))

            size = np.array(ops[1].outputs[0].shape, dtype='int32')

            start_tensor = self.create_attr_tensor(start)
            size_tensor = self.create_attr_tensor(size)

            slice_out = ops[1].outputs[0]
            pad_array = np.pad(self.outputs[0].tensor, pad_sizes)
            slice_input = self.create_transform_tensor(pad_array, quantization=self.outputs[0].quantization)
            ops[1].outputs[0] = slice_input

            slice_op = tfl_ops.SliceOperator([slice_input, start_tensor, size_tensor], [slice_out])
            output_shape = slice_input.shape
            ops.insert(2, slice_op)

        # Output shape handling
        output_shape_tensor = self.create_attr_tensor(np.array(output_shape, dtype='int32'))
        conv_op.inputs.insert(0, output_shape_tensor)

        # Weight handling
        weight = conv_op.inputs[1]
        if weight_dim in (3, 4):
            nchw2chwn_perm = np.array([1, 2, 3, 0], dtype='int32')
        else:
            nchw2chwn_perm = np.array([2, 3, 4, 1, 0], dtype='int32')
        nchw2chwn_perm_tensor = self.create_attr_tensor(nchw2chwn_perm)

        reordered_weight = self.create_transform_tensor(
            np.transpose(weight.tensor, nchw2chwn_perm), quantization=weight.quantization
        )
        conv_op.inputs[1] = reordered_weight

        reorder_op = tfl_ops.TransposeOperator([weight, nchw2chwn_perm_tensor], [reordered_weight])
        ops.insert(1, reorder_op)

        # Bias handling
        if self.enable_mtk_ops or self.conv_transpose_with_bias:
            kernel_num = output_tensor.shape[1]
            if len(self.inputs) > 2 and self.inputs[2].shape[0] != kernel_num and self.inputs[2].shape[0] == 1:
                if conv_op.inputs[-1].dtype == np.float32:
                    bias = torch.tensor([self.inputs[2][0]] * kernel_num, dtype=torch.float32)
                else:
                    bias = torch.tensor([self.inputs[2][0]] * kernel_num, dtype=torch.int32)
                conv_op.inputs.append(self.create_attr_tensor(bias))
            else:
                if len(self.inputs) == 2 or self.inputs[2] is None:
                    if conv_op.inputs[-1].dtype == np.dtype('float32'):
                        bias = np.zeros((kernel_num,), dtype='float32')
                        q_args = None
                    else:
                        bias = np.zeros((kernel_num,), dtype='int32')
                else:
                    bias = self.inputs[2].tensor
                    q_args = None
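
                # For quantized graphs the bias must be int32 with
                # scale = input_scale * weight_scale (per tensor or per channel),
                # so its quantization parameters are derived below.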
                if bias.dtype != np.dtype('float32'):
                    per_tensor = weight_tensor.quantization.dim is None
                    if per_tensor:
                        bias_scale = input_tensor.quantization.scale * weight_tensor.quantization.scale
                        bias_zero_point = 0
                        bias_dim = None
                    else:
                        bias_scale = [input_tensor.quantization.scale * s for s in weight_tensor.quantization.scale]
                        bias_zero_point = [0] * len(bias_scale)
                        bias_dim = 0
                    q_args = QuantizationParameters(bias_scale, bias_zero_point, bias_dim)
                conv_op.inputs.append(self.create_attr_tensor(bias, quantization=q_args))
        else:
            if len(self.inputs) > 2 and self.inputs[2] is not None:
                bias_tensor = self.inputs[2]
                add_out = ops[-2].outputs[0]
                bias_transform = self.create_transform_tensor(
                    add_out.tensor.copy(), quantization=self.outputs[0].quantization
                )
                ops[-2].outputs[0] = bias_transform
                ops.insert(len(ops) - 1, tfl_ops.AddOperator([bias_transform, bias_tensor], [add_out]))

        ops = prev_ops + ops + next_ops
        for op in ops:
            graph_converter.add_operator(op)

        graph_converter.try_restore_edges(mapping)
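

# Rough usage sketch (illustrative only). The converter front end is expected to create
# these virtual ops while translating the source graph and later call `transform` on each
# of them; `graph_converter` and `mapping` below are hypothetical stand-ins for the real
# objects that the converter passes in.
#
#     bn = BatchNormOperator([inp, weight, bias, running_mean, running_var], [out], eps=1e-5)
#     bn.transform(graph_converter, mapping)   # rewrites BATCH_NORM into MUL + ADD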