tinynn/converter/operators/tflite/base.py (378 lines of code) (raw):

import copy
import typing

import flatbuffers
import numpy as np
import torch

from ...schemas.tflite import schema_generated as tflite
from tinynn.util.util import get_logger

log = get_logger(__name__)

# FlatBuffers offsets returned by the builder are plain integers.
Offset = int


class OpCode(object):
    """An entry of the model's operator-code table: which (builtin or custom) op is used, and its version."""

    code: int
    version: int
    index: int
    tfl_opcode: Offset

    def __init__(self, code: int, version: int, custom_code: typing.Optional[str] = None):
        """
        Args:
            code: A ``tflite.BuiltinOperator`` value.
            version: Operator version written into the op code.
            custom_code: Name of the custom operator, if any.
        """
        self.code = code
        self.version = version
        # Position in the model's op-code table, read by ``BaseOperator.build``. Starts at 0; presumably
        # assigned by the graph builder before serialization -- TODO confirm.
        self.index = 0
        self.custom_code = custom_code
        self.tfl_opcode = 0

    def build(self, builder: flatbuffers.Builder) -> Offset:
        """Serialize this op code into ``builder`` and return (and remember) its table offset.

        Both the legacy ``deprecated_builtin_code`` field and the extended ``builtin_code`` field are written.
        Codes that do not fit the legacy field are recorded there as ``PLACEHOLDER_FOR_GREATER_OP_CODES``
        (as TFLite's own schema conversion does) instead of leaving it at 0, which denotes ADD.
        """
        custom_code = None
        if self.custom_code is not None:
            # Nested objects (strings, vectors) must be created before the table is started.
            custom_code = create_string(builder, tflite.OperatorCode.CustomCode, self.custom_code)

        placeholder = tflite.BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES
        deprecated_code = self.code if self.code < placeholder else placeholder

        tflite.OperatorCodeStart(builder)
        tflite.OperatorCodeAddDeprecatedBuiltinCode(builder, deprecated_code)
        tflite.OperatorCodeAddBuiltinCode(builder, self.code)
        tflite.OperatorCodeAddVersion(builder, self.version)
        if custom_code is not None:
            tflite.OperatorCodeAddCustomCode(builder, custom_code)
        self.tfl_opcode = tflite.OperatorCodeEnd(builder)
        return self.tfl_opcode


class BaseOperator(object):
    """Base class of a TFLite operator node: an op code plus its input, output and intermediate tensors."""

    inputs: typing.List['Tensor']
    outputs: typing.List['Tensor']
    intermediates: typing.List['Tensor']
    op: OpCode
    tfl_op: Offset
    tfl_inputs_idx: typing.Iterable[int]
    tfl_outputs_idx: typing.Iterable[int]
    tfl_intermediates_idx: typing.Iterable[int]
    extra_hints: typing.Dict[str, typing.Any]

    def __init__(self, op: int, inputs: typing.List['Tensor'], outputs: typing.List['Tensor'], op_version: int = 1):
        """
        Args:
            op: A ``tflite.BuiltinOperator`` value.
            inputs: Input tensors of the operator.
            outputs: Output tensors of the operator.
            op_version: Version of the operator written into its op code.
        """
        self.inputs = inputs
        self.outputs = outputs
        self.intermediates = []
        self.op = OpCode(op, op_version)
        self.tfl_op = 0
        # Tensor indices within the subgraph; left empty here and expected to be filled in before ``build``.
        self.tfl_inputs_idx = []
        self.tfl_outputs_idx = []
        self.tfl_intermediates_idx = []
        self.extra_hints = {}

    def build(self, builder: flatbuffers.Builder) -> Offset:
        """Serialize the operator (op-code index and tensor index lists) and return its table offset."""
        tfl_inputs_idx = create_numpy_array(builder, tflite.Operator.Inputs, self.tfl_inputs_idx)
        tfl_outputs_idx = create_numpy_array(builder, tflite.Operator.Outputs, self.tfl_outputs_idx)
        tfl_intermediates_idx = create_numpy_array(builder, tflite.Operator.Intermediates, self.tfl_intermediates_idx)

        tflite.OperatorStart(builder)
        tflite.OperatorAddOpcodeIndex(builder, self.op.index)
        tflite.OperatorAddInputs(builder, tfl_inputs_idx)
        tflite.OperatorAddOutputs(builder, tfl_outputs_idx)
        tflite.OperatorAddIntermediates(builder, tfl_intermediates_idx)
        self.tfl_op = tflite.OperatorEnd(builder)
        return self.tfl_op

    def type_name(self) -> str:
        """Return the class name with the ``Operator`` suffix/infix removed (e.g. ``AddOperator`` -> ``Add``)."""
        return type(self).__name__.replace('Operator', '')


class QuantizationParameters:
    """Per-tensor or per-channel affine quantization parameters of a tensor."""

    scale: typing.Union[float, typing.List[float]]
    zero_point: typing.Union[int, typing.List[int]]
    tfl_quant_args: Offset

    def __init__(
        self,
        scale: typing.Union[float, typing.List[float]],
        zero_point: typing.Union[int, typing.List[int]],
        dim: typing.Optional[int] = None,
    ):
        """
        Args:
            scale: A single scale (per-tensor) or a list of scales (per-channel).
            zero_point: A single zero point (per-tensor) or a list of zero points (per-channel).
            dim: The quantized dimension for per-channel quantization; ``None`` leaves it unset.
        """
        self.scale = scale
        self.zero_point = zero_point
        self.dim = dim
        self.tfl_quant_args = 0

    def build(self, builder: flatbuffers.Builder) -> Offset:
        """Serialize the parameters (scales as float32, zero points as int64) and return the table offset."""
        # Scalars are wrapped into one-element vectors; the schema stores both fields as vectors.
        if isinstance(self.scale, float):
            scale = create_numpy_array(builder, tflite.QuantizationParameters.Scale, [self.scale], 'float32')
        else:
            scale = create_numpy_array(builder, tflite.QuantizationParameters.Scale, self.scale, 'float32')
        if isinstance(self.zero_point, int):
            zero_point = create_numpy_array(
                builder, tflite.QuantizationParameters.ZeroPoint, [self.zero_point], 'int64'
            )
        else:
            zero_point = create_numpy_array(builder, tflite.QuantizationParameters.ZeroPoint, self.zero_point, 'int64')

        tflite.QuantizationParametersStart(builder)
        # Offset 0: no min/max vectors are provided.
        tflite.QuantizationParametersAddMin(builder, 0)
        tflite.QuantizationParametersAddMax(builder, 0)
        tflite.QuantizationParametersAddScale(builder, scale)
        tflite.QuantizationParametersAddZeroPoint(builder, zero_point)
        if self.dim is not None:
            tflite.QuantizationParametersAddQuantizedDimension(builder, self.dim)
        self.tfl_quant_args = tflite.QuantizationParametersEnd(builder)
        return self.tfl_quant_args

    def __repr__(self) -> str:
        return f'scale={self.scale}, zero_point={self.zero_point}'


class Buffer(object):
    """Raw data of a constant tensor, stored in the model's buffer table."""

    data: typing.Union[bytearray, bytes]
    index: int
    tfl_buffer: Offset

    def __init__(self, data: typing.Union[bytearray, bytes]):
        self.data = data
        # Position in the model's buffer table; ``Tensor.build`` writes it as the tensor's buffer index.
        self.index = 0
        self.tfl_buffer = 0

    def build(self, builder: flatbuffers.Builder) -> Offset:
        """Serialize the buffer and return its table offset; empty data produces a buffer without a data vector."""
        if len(self.data) != 0:
            data = create_byte_array(builder, tflite.Buffer.Data, self.data)
        else:
            data = 0

        tflite.BufferStart(builder)
        tflite.BufferAddData(builder, data)
        self.tfl_buffer = tflite.BufferEnd(builder)
        return self.tfl_buffer


class FakeQuantTensor(object):
    """A (float or plain) tensor bundled with quantization parameters to attach when wrapped by ``Tensor``."""

    def __init__(self, tensor, scale, zero_point, dim=None) -> None:
        self.tensor = tensor
        self.scale = scale
        self.zero_point = zero_point
        self.dim = dim


class Tensor(object):
    """A TFLite tensor: numpy data, name, optional quantization parameters and optional constant buffer."""

    tensor: np.ndarray
    name: str
    quantization: typing.Optional[QuantizationParameters]
    buffer: typing.Optional[Buffer]
    dtype: np.dtype
    shape: typing.Iterable[int]
    tfl_tensor: int

    def __init__(
        self,
        tensor: typing.Iterable,
        name: str,
        quantization: QuantizationParameters = None,
        has_buffer: bool = True,
        dtype: str = None,
        is_variable: bool = False,
        asymmetric: bool = True,
        q_type: type = np.uint8,
    ):
        """Wrap ``tensor`` as a TFLite tensor, deriving quantization parameters from quantized torch tensors.

        Args:
            tensor: The data. Accepted: ``FakeQuantTensor``, ``torch.nn.Parameter``, ``torch.Tensor`` (plain,
                ``quint8`` or ``qint8``), numpy arrays/scalars, ``torch.Size`` (stored as int32), tuples and lists.
            name: Tensor name written into the model.
            quantization: If given, a deep copy of it overrides any quantization parameters derived from ``tensor``.
            has_buffer: Whether to attach a ``Buffer`` holding the tensor's bytes.
            dtype: numpy dtype used when ``tensor`` is a tuple or list.
            is_variable: Value written to the TFLite ``is_variable`` flag.
            asymmetric: Whether quantized torch tensors are treated as asymmetric (affine) rather than symmetric.
            q_type: Target integer representation for quantized torch tensors; compared against ``np.uint8`` and
                ``np.int16`` below.

        Raises:
            AssertionError: On unsupported input types or quantization parameters that cannot be represented.
        """
        self.quantization = None
        self.name = name
        self.index = 0
        self.is_variable = is_variable

        if type(tensor) is FakeQuantTensor:
            self.quantization = QuantizationParameters(tensor.scale, tensor.zero_point, tensor.dim)
            tensor = tensor.tensor

        if isinstance(tensor, torch.nn.Parameter):
            tensor = tensor.data

        if type(tensor).__module__ == 'numpy':
            self.tensor = tensor
        elif type(tensor) is torch.Tensor:
            # `is_contiguous` is a method, so asserting on the bare attribute could never fail. Normalize the
            # layout instead, so the numpy data is C-ordered and safe to reinterpret with `.view()` later.
            tensor = tensor.contiguous()
            if tensor.dtype == torch.quint8:
                self.tensor = torch.int_repr(tensor.detach()).numpy()
                if q_type == np.uint8:
                    self.quantization = QuantizationParameters(tensor.q_scale(), tensor.q_zero_point())
                else:
                    if not asymmetric:
                        sym_u8_offset = 128
                        if tensor.q_zero_point() != sym_u8_offset:
                            log.warning(
                                "As for symmetric quantization, the zero point of the u8 tensors should be"
                                f" {sym_u8_offset}, but got {tensor.q_zero_point()}. This could happen if you didn't"
                                " train the model after QAT preparation, or the OP is not supported in symmetric"
                                " quantization (e.g. sigmoid)"
                            )
                    else:
                        sym_u8_offset = tensor.q_zero_point()
                    scale = tensor.q_scale()
                    zero_point = sym_u8_offset - 128
                    # u8 -> s8: shift the stored values by -128 (the zero point shifts with them).
                    self.tensor = (self.tensor.astype(np.int32) - 128).astype(np.int8)
                    if q_type == np.int16:
                        # s8 -> s16: values are multiplied by 256, so the scale is divided by 256 and the zero point
                        # multiplied by 256 to keep scale * (q - zero_point) unchanged.
                        scale = scale * 256 / 65536
                        zero_point = zero_point * 256
                        self.tensor = np.round(self.tensor.astype(np.float32) / 256 * 65536).astype(np.int16)
                    self.quantization = QuantizationParameters(scale, zero_point)
            elif tensor.dtype == torch.qint8:
                self.tensor = torch.int_repr(tensor.detach()).numpy()
                if q_type == np.uint8:
                    if asymmetric:
                        asym_s8_offset = 0
                        assert tensor.q_zero_point() == asym_s8_offset, (
                            "As for asymmetric quantization, the zero point of the s8 tensors should be"
                            f" {asym_s8_offset}, but got {tensor.q_zero_point()}. "
                        )
                    else:
                        asym_s8_offset = tensor.q_zero_point()
                    # s8 -> u8: reinterpret the bytes and add 128 (uint8 arithmetic wraps around).
                    self.tensor = self.tensor.view(np.uint8) + 128
                    self.quantization = QuantizationParameters(tensor.q_scale(), asym_s8_offset + 128)
                else:
                    if tensor.qscheme() in (torch.per_tensor_symmetric, torch.per_tensor_affine):
                        self.quantization = QuantizationParameters(tensor.q_scale(), tensor.q_zero_point())
                    else:
                        assert tensor.qscheme() in (torch.per_channel_symmetric, torch.per_channel_affine)
                        scales = tensor.q_per_channel_scales().tolist()
                        zero_points = tensor.q_per_channel_zero_points().tolist()
                        dim = tensor.q_per_channel_axis()
                        if dim < 0:
                            dim += tensor.dim()
                        assert all(t == 0 for t in zero_points), (
                            'As for per-channel quantization, the zero point of the s8'
                            f' tensors should be 0, but got {zero_points}'
                        )
                        self.quantization = QuantizationParameters(scales, zero_points, dim)
            else:
                self.tensor = tensor.detach().numpy()
        elif type(tensor) is torch.Size:
            self.tensor = np.asarray(tensor, dtype='int32')
        elif type(tensor) in (tuple, list):
            self.tensor = np.asarray(tensor, dtype=dtype)
        else:
            assert False, f"unrecognized tensor type {type(tensor).__name__}"

        if has_buffer:
            self.buffer = Buffer(self.tensor.tobytes())
        else:
            self.buffer = None

        self.dtype = self.tensor.dtype
        self.shape = self.tensor.shape

        if quantization is not None:
            self.quantization = copy.deepcopy(quantization)

        self.tfl_tensor = 0

    def __repr__(self) -> str:
        return f'{self.name}: {self.dtype}{self.shape}'

    def reinterpret_as(self, new_type: typing.Union[type, np.dtype]):
        """Reinterpret the tensor's bytes as ``new_type`` without converting values.

        The underlying bytes (and thus any ``Buffer`` created from them) are unchanged; ``dtype`` and ``shape``
        are refreshed from the new view. The shape only changes when the item size of ``new_type`` differs.
        """
        self.tensor = self.tensor.view(new_type)
        self.dtype = self.tensor.dtype
        self.shape = self.tensor.shape

    def build(self, builder: flatbuffers.Builder) -> Offset:
        """Serialize the tensor (name, shape, type, buffer index, quantization) and return its table offset.

        Raises:
            KeyError: If the numpy dtype has no entry in ``numpy_tflite_dtype_mappings``.
        """
        name = create_string(builder, tflite.Tensor.Name, self.name)
        shape = create_numpy_array(builder, tflite.Tensor.Shape, self.shape)
        dtype = numpy_tflite_dtype_mappings[str(self.dtype)]

        # Buffer index 0 is written when the tensor has no buffer of its own.
        buffer = 0
        if self.buffer is not None:
            buffer = self.buffer.index

        quantization = 0
        if self.quantization is not None:
            quantization = self.quantization.build(builder)

        tflite.TensorStart(builder)
        tflite.TensorAddBuffer(builder, buffer)
        tflite.TensorAddIsVariable(builder, self.is_variable)
        tflite.TensorAddName(builder, name)
        tflite.TensorAddShape(builder, shape)
        tflite.TensorAddType(builder, dtype)
        tflite.TensorAddQuantization(builder, quantization)
        self.tfl_tensor = tflite.TensorEnd(builder)
        return self.tfl_tensor


class OptionalTensor(Tensor):
    """Placeholder for an omitted optional operator input; it has no data and can never be serialized."""

    def __init__(self):
        # Deliberately does not call Tensor.__init__: there is no data to wrap.
        # Index -1 presumably marks the input as omitted in the operator's input list -- TODO confirm.
        self.index = -1
        self.quantization = None
        self.name = '__tinynn_optional_tensor__'
        self.is_variable = False
        self.tensor = None
        self.shape = None
        self.dtype = None
        self.buffer = None

    def __repr__(self) -> str:
        return 'OptionalTensor'

    def build(self, builder: flatbuffers.Builder) -> Offset:
        raise Exception('Could not build an optional tensor')


class SubGraph(object):
    """A subgraph: already-built tensor and operator offsets plus the indices of its input/output tensors."""

    tensors: typing.List[Offset]
    inputs: typing.List[int]
    outputs: typing.List[int]
    operators: typing.List[Offset]
    tfl_subgraph: int

    def __init__(self):
        self.tensors = []
        self.inputs = []
        self.outputs = []
        self.operators = []
        self.tfl_subgraph = 0

    def build(self, builder: flatbuffers.Builder) -> Offset:
        """Serialize the subgraph (always named ``main_graph``) and return its table offset."""
        inputs = create_numpy_array(builder, tflite.SubGraph.Inputs, self.inputs)
        outputs = create_numpy_array(builder, tflite.SubGraph.Outputs, self.outputs)
        operators = create_offset_vector(builder, tflite.SubGraph.Operators, self.operators)
        tensors = create_offset_vector(builder, tflite.SubGraph.Tensors, self.tensors)
        name = create_string(builder, tflite.SubGraph.Name, "main_graph")

        tflite.SubGraphStart(builder)
        tflite.SubGraphAddInputs(builder, inputs)
        tflite.SubGraphAddOutputs(builder, outputs)
        tflite.SubGraphAddName(builder, name)
        tflite.SubGraphAddTensors(builder, tensors)
        tflite.SubGraphAddOperators(builder, operators)
        self.tfl_subgraph = tflite.SubGraphEnd(builder)
        return self.tfl_subgraph


class Model(object):
    """The root table: already-built buffer, op-code and subgraph offsets."""

    buffers: typing.List[Offset]
    opcodes: typing.List[Offset]
    subgraphs: typing.List[Offset]
    tfl_model: Offset

    def __init__(self):
        self.buffers = []
        self.opcodes = []
        self.subgraphs = []
        self.tfl_model = 0

    def build(self, builder: flatbuffers.Builder) -> Offset:
        """Serialize the model table (schema version 3) and return its offset."""
        buffers = create_offset_vector(builder, tflite.Model.Buffers, self.buffers)
        opcodes = create_offset_vector(builder, tflite.Model.OperatorCodes, self.opcodes)
        subgraphs = create_offset_vector(builder, tflite.Model.Subgraphs, self.subgraphs)
        description = create_string(builder, tflite.Model.Description, "TinyNeuralNetwork Converted.")
        version = 3

        tflite.ModelStart(builder)
        tflite.ModelAddBuffers(builder, buffers)
        tflite.ModelAddDescription(builder, description)
        tflite.ModelAddVersion(builder, version)
        tflite.ModelAddOperatorCodes(builder, opcodes)
        tflite.ModelAddSubgraphs(builder, subgraphs)
        self.tfl_model = tflite.ModelEnd(builder)
        return self.tfl_model


def _check_prop(prop: typing.Callable, func_template: str) -> str:
    """Assert that ``prop`` (a generated table accessor such as ``tflite.Model.Buffers``) has a matching builder
    function in the ``tflite`` module, and return that function's name.

    ``func_template`` is formatted with ``cls`` (the table class name) and ``prop`` (the accessor name), e.g.
    ``'{cls}Start{prop}Vector'`` -> ``'ModelStartBuffersVector'``.
    """
    cls_name = prop.__qualname__.split('.')[0]
    func_name = func_template.format(cls=cls_name, prop=prop.__name__)
    assert hasattr(tflite, func_name), f"invalid prop is given, {prop.__qualname__}"
    return func_name


def create_offset_vector(builder: flatbuffers.Builder, prop: typing.Callable, vec: typing.Iterable):
    """Serialize a vector of table offsets for field ``prop`` and return the vector's offset.

    Args:
        builder: The FlatBuffers builder.
        prop: Generated accessor of the destination field; used to find ``<Table>Start<Field>Vector``.
        vec: A list or tuple of offsets, in order.
    """
    assert type(vec) in (tuple, list), "type of vec unexpected, expected: list or tuple"
    start_vec_func = getattr(tflite, _check_prop(prop, '{cls}Start{prop}Vector'))

    start_vec_func(builder, len(vec))
    # FlatBuffers builds back to front, so prepend in reverse to preserve the original order.
    for item in reversed(vec):
        builder.PrependUOffsetTRelative(item)
    try:
        # Older flatbuffers releases take the element count ...
        end = builder.EndVector(len(vec))
    except TypeError:
        # ... newer ones track it internally.
        end = builder.EndVector()
    return end


def create_numpy_array(builder: flatbuffers.Builder, prop: typing.Callable, vec: typing.Iterable, dtype: str = 'int32'):
    """Serialize ``vec`` (list, tuple, ``torch.Size`` or numpy value) as a scalar vector of ``dtype`` for field
    ``prop`` and return the vector's offset."""
    assert type(vec) in (tuple, list, torch.Size) or type(vec).__module__ == 'numpy', (
        "type of vec unexpected, expected: list or tuple or ndarray"
    )
    _check_prop(prop, '{cls}Start{prop}Vector')
    arr = np.asarray(vec, dtype=dtype)
    return builder.CreateNumpyVector(arr)


def create_string(builder: flatbuffers.Builder, prop: typing.Callable, val: str):
    """Serialize the string ``val`` for field ``prop`` and return its offset."""
    assert type(val) is str, "type of val unexpected, expected: str"
    _check_prop(prop, '{cls}Add{prop}')
    return builder.CreateString(val)


def create_byte_array(builder: flatbuffers.Builder, prop: typing.Callable, val: typing.Union[bytes, bytearray]):
    """Serialize the bytes ``val`` as a byte vector for field ``prop`` and return its offset."""
    assert type(val) in (bytearray, bytes), "type of val unexpected, expected: bytes or bytearray"
    _check_prop(prop, '{cls}Start{prop}Vector')
    return builder.CreateByteVector(val)


# numpy dtype name -> TFLite tensor type (used by Tensor.build).
numpy_tflite_dtype_mappings = {
    'bool': tflite.TensorType.BOOL,
    'int16': tflite.TensorType.INT16,
    'int32': tflite.TensorType.INT32,
    'int64': tflite.TensorType.INT64,
    'int8': tflite.TensorType.INT8,
    'uint8': tflite.TensorType.UINT8,
    'float16': tflite.TensorType.FLOAT16,
    'float32': tflite.TensorType.FLOAT32,
    'float64': tflite.TensorType.FLOAT64,
}

# torch dtype -> TFLite tensor type.
torch_tflite_dtype_mappings = {
    torch.bool: tflite.TensorType.BOOL,
    torch.int16: tflite.TensorType.INT16,
    torch.int32: tflite.TensorType.INT32,
    torch.int64: tflite.TensorType.INT64,
    torch.qint8: tflite.TensorType.INT8,
    torch.quint8: tflite.TensorType.UINT8,
    torch.float16: tflite.TensorType.FLOAT16,
    torch.float32: tflite.TensorType.FLOAT32,
    torch.float64: tflite.TensorType.FLOAT64,
}

# Shared placeholder instance for omitted optional inputs.
OptionalTensorInstance = OptionalTensor()