tinynn/converter/operators/op_version.py (294 lines of code) (raw):

from ..schemas.tflite import schema_generated as tfl_schema
from . import tflite as tfl
from .base import ExtendedOperator
from .graph import CommonGraph

# Operators whose version only depends on whether their first input is int8
# (version 2) or not (version 1).
_INT8_VERSIONED_OPS = (
    ExtendedOperator.AVERAGE_POOL_2D,
    ExtendedOperator.ADD,
    ExtendedOperator.CONCATENATION,
    ExtendedOperator.MAX_POOL_2D,
    ExtendedOperator.PAD,
    ExtendedOperator.PADV2,
    ExtendedOperator.SOFTMAX,
    ExtendedOperator.SPACE_TO_DEPTH,
    ExtendedOperator.SPLIT_V,
    ExtendedOperator.MEAN,
    ExtendedOperator.SUM,
    ExtendedOperator.REDUCE_MAX,
    ExtendedOperator.REDUCE_MIN,
    ExtendedOperator.RELU6,
    ExtendedOperator.RESIZE_NEAREST_NEIGHBOR,
    ExtendedOperator.TANH,
    ExtendedOperator.LOGISTIC,
    ExtendedOperator.LOG_SOFTMAX,
    ExtendedOperator.TOPK_V2,
    ExtendedOperator.ARG_MAX,
    ExtendedOperator.ARG_MIN,
    ExtendedOperator.EQUAL,
    ExtendedOperator.NOT_EQUAL,
    ExtendedOperator.GREATER,
    ExtendedOperator.GREATER_EQUAL,
    ExtendedOperator.LESS,
    ExtendedOperator.LESS_EQUAL,
    ExtendedOperator.SELECT,
)


def _dtype(tensor) -> str:
    """Returns the dtype of a tensor as a string (e.g. 'int8', 'float32')"""
    return str(tensor.dtype)


def _is_int8_quantized(op: tfl.BaseOperator, weight_index: int = 1) -> bool:
    """Whether the first input, the input at `weight_index` and the first output are all int8"""
    return (
        _dtype(op.inputs[0]) == 'int8'
        and _dtype(op.inputs[weight_index]) == 'int8'
        and _dtype(op.outputs[0]) == 'int8'
    )


def _is_hybrid(op: tfl.BaseOperator, weight_index: int = 1) -> bool:
    """Whether the op is "hybrid" in TFLite terms: float32 first input and first output,
    with an int8 input at `weight_index`"""
    return (
        _dtype(op.inputs[0]) == 'float32'
        and _dtype(op.inputs[weight_index]) == 'int8'
        and _dtype(op.outputs[0]) == 'float32'
    )


def _needs_broadcast_above_4d(op: tfl.BaseOperator) -> bool:
    """Whether the first two inputs have different shapes and one of them has more than 4 dims

    Mirrors `NeedBroadcastForBinaryInputs(op_sig) && GetInputMaxDims(op_sig) > 4` in TFLite,
    which compares the full shapes (not only the ranks) of the two inputs.
    """
    # `tuple()` so that list- and tuple-typed shapes compare by value
    shape_a = tuple(op.inputs[0].shape)
    shape_b = tuple(op.inputs[1].shape)
    return shape_a != shape_b and max(len(shape_a), len(shape_b)) > 4


class OPVersioner(object):
    """Sets the version of the OPs in the computation graph"""

    def __init__(self, graph: CommonGraph) -> None:
        """Constructs an OPVersioner object

        Args:
            graph (CommonGraph): The computation graph
        """
        self.graph = graph

    def process(self):
        """The main process function for the whole graph

        Calls `process_op` on `node['op']` for every vertex of the graph whose
        `node_type` is non-negative; vertices with a negative `node_type` are skipped.
        """
        for node in self.graph.graph.vs:
            if node['node_type'] >= 0:
                self.process_op(node['op'])

    def process_op(self, op: tfl.BaseOperator):
        """Sets the version of the OP

        The version is derived from the operator code, the dtypes/shapes of its tensors and
        some of its options, and is written to `op.op.version`. Operators that are not
        handled explicitly get version 1.

        Args:
            op (tfl.BaseOperator): The operator to be processed
        """
        # Translated from `GetBuiltinOperatorVersion` in
        # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/versioning/op_version.cc
        code = op.op.code

        if code == ExtendedOperator.CONV_2D:
            if _is_int8_quantized(op):
                op.op.version = 3
            elif _is_hybrid(op):
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.DEPTHWISE_CONV_2D:
            if _is_hybrid(op):
                op.op.version = 4
            elif _is_int8_quantized(op):
                op.op.version = 3
            elif op.dilationHFactor != 1 or op.dilationWFactor != 1:
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.FAKE_QUANT:
            op.op.version = 2 if op.narrowRange else 1
        elif code == ExtendedOperator.FULLY_CONNECTED:
            if len(op.inputs) == 2:
                # Only two inputs, i.e. (presumably) the optional bias tensor is omitted
                op.op.version = 6
            elif op.keepNumDims:
                op.op.version = 5
            elif _is_int8_quantized(op):
                op.op.version = 4
            elif _is_hybrid(op):
                op.op.version = 3
            elif op.weightsFormat == tfl_schema.FullyConnectedOptionsWeightsFormat.SHUFFLED4x16INT8:
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.GATHER:
            if _dtype(op.inputs[0]) == 'bool':
                op.op.version = 3
            elif _dtype(op.inputs[0]) == 'int8':
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.SVDF:
            if _dtype(op.inputs[0]) == 'int8':
                op.op.version = 3
            elif _is_hybrid(op):
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.MUL:
            q_a = op.inputs[0].quantization
            q_b = op.inputs[1].quantization
            q_out = op.outputs[0].quantization
            if (
                q_a is not None
                and q_b is not None
                and q_out is not None
                and q_a.scale != 0.0
                and q_b.scale != 0.0
                and q_out.scale != 0.0
                and q_a.scale * q_b.scale / q_out.scale >= 1.0
            ):
                # Quantized MUL whose effective rescale factor (s_a * s_b / s_out) is >= 1
                op.op.version = 3
            elif _dtype(op.inputs[0]) == 'int8':
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.TRANSPOSE:
            if len(op.inputs[0].shape) > 4:
                op.op.version = 4
            elif _dtype(op.inputs[0]) == 'bool':
                op.op.version = 3
            elif _dtype(op.inputs[0]) == 'int8':
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.TRANSPOSE_CONV:
            # TFLite's TRANSPOSE_CONV inputs are (output_shape, weights, input[, bias]). The
            # output_shape tensor is int32, so the int8 check has to look at the weights.
            op.op.version = 2 if _dtype(op.inputs[1]) == 'int8' else 1
        elif code == ExtendedOperator.LSTM:
            if op.kernelType == tfl_schema.LSTMKernelType.FULL and _is_hybrid(op, weight_index=2):
                op.op.version = 3
            elif op.kernelType == tfl_schema.LSTMKernelType.BASIC:
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM:
            op.op.version = 2 if _is_hybrid(op, weight_index=2) else 1
        elif code == ExtendedOperator.SPLIT:
            # inputs[0] is the axis; inputs[1] is the data tensor
            if _dtype(op.inputs[1]) == 'int32':
                op.op.version = 3
            elif _dtype(op.inputs[1]) == 'int8':
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.SPARSE_TO_DENSE:
            if _dtype(op.inputs[2]) in ('int8', 'uint8'):
                op.op.version = 3
            elif _dtype(op.inputs[2]) == 'int64':
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.SLICE:
            # '<U…' is how NumPy prints unicode-string dtypes (presumably string tensors)
            if _dtype(op.inputs[0]).startswith('<U'):
                op.op.version = 3
            elif _dtype(op.inputs[0]) == 'int8':
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.UNPACK:
            # `and`, as in TFLite and in the PACK branch below
            if _dtype(op.inputs[0]) == 'int16' and _dtype(op.outputs[0]) == 'int16':
                op.op.version = 4
            elif _dtype(op.inputs[0]) == 'bool':
                op.op.version = 3
            elif _dtype(op.inputs[0]) in ('int8', 'uint8'):
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.DEQUANTIZE:
            if _dtype(op.inputs[0]) in ('int16', 'float16'):
                op.op.version = 3
            elif _dtype(op.inputs[0]) == 'int8':
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.FLOOR_DIV:
            op.op.version = 2 if _dtype(op.inputs[0]) == 'float32' else 1
        elif code == ExtendedOperator.L2_NORMALIZATION:
            op.op.version = 2 if _dtype(op.outputs[0]) == 'int8' else 1
        elif code in (ExtendedOperator.RELU, ExtendedOperator.GELU):
            op.op.version = 2 if _dtype(op.inputs[0]) in ('int8', 'uint8') else 1
        elif code == ExtendedOperator.STRIDED_SLICE:
            if len(op.inputs[0].shape) > 4:
                op.op.version = 4
            elif _dtype(op.inputs[0]) == 'bool':
                op.op.version = 3
            elif _dtype(op.inputs[0]) == 'int8':
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.REVERSE_V2:
            op.op.version = 2 if _dtype(op.outputs[0]) == 'bool' else 1
        elif code == ExtendedOperator.RESIZE_BILINEAR:
            if op.halfPixelCenters:
                op.op.version = 3
            elif _dtype(op.inputs[0]) == 'int8':
                op.op.version = 2
            else:
                op.op.version = 1
        elif code in (ExtendedOperator.MINIMUM, ExtendedOperator.MAXIMUM):
            if _dtype(op.inputs[0]) == 'int16' and _dtype(op.outputs[0]) == 'int16':
                op.op.version = 4
            elif _needs_broadcast_above_4d(op):
                op.op.version = 3
            elif _dtype(op.inputs[0]) == 'int8':
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.PACK:
            if _dtype(op.inputs[0]) == 'int16' and _dtype(op.outputs[0]) == 'int16':
                op.op.version = 3
            elif _dtype(op.inputs[0]) == 'int8':
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.TILE:
            # Fixed: this used to compare `op.op.version` with the operator code
            op.op.version = 2 if _dtype(op.inputs[0]).startswith('<U') else 1
        elif code in (ExtendedOperator.SPACE_TO_BATCH_ND, ExtendedOperator.BATCH_TO_SPACE_ND):
            # Fixed: this used to compare `op.op.version` with the operator codes
            if len(op.inputs[0].shape) != 4:
                op.op.version = 3
            elif _dtype(op.inputs[0]) == 'int8':
                op.op.version = 2
            else:
                op.op.version = 1
        elif code == ExtendedOperator.SUB:
            # Fixed: this used to compare `op.op.version` with the operator code
            if _needs_broadcast_above_4d(op):
                op.op.version = 3
            elif _dtype(op.inputs[0]) == 'int8':
                op.op.version = 2
            else:
                op.op.version = 1
        elif code in _INT8_VERSIONED_OPS:
            op.op.version = 2 if _dtype(op.inputs[0]) == 'int8' else 1
        else:
            op.op.version = 1