tinynn/converter/operators/tflite/custom.py (86 lines of code) (raw):

import typing

from .generated_ops import CustomOperator
from ...schemas.tflite import schema_generated as tflite

# FlexBuffer support lives in ``flatbuffers.flexbuffers``, which only exists in
# flatbuffers>=2. The import is optional so this module can still be loaded
# without it; operators that serialize FlexBuffer options check the flag.
HAS_FLEXBUFFER = False
try:
    from flatbuffers import flexbuffers

    HAS_FLEXBUFFER = True
except ImportError:
    pass


def _require_flexbuffer(custom_code: str) -> None:
    """Assert that FlexBuffer is importable, naming ``custom_code`` in the failure message.

    Raises:
        AssertionError: if ``flatbuffers.flexbuffers`` could not be imported.
    """
    # NOTE(review): ``assert`` is stripped under ``python -O``; it is kept as an
    # assert so the exception type matches what callers already observe.
    assert HAS_FLEXBUFFER, f"{custom_code} relies on FlexBuffer, which requires flatbuffers>=2"


class TFLiteDetectionPostprocessOperator(CustomOperator):
    """The ``TFLITE_DETECTION_POSTPROCESS`` custom operator.

    Every attribute is serialized by :meth:`build` into ``custom_options`` as a
    FlexBuffer map whose keys are the attribute names.

    Args:
        inputs: Input tensors, forwarded to ``CustomOperator``.
        outputs: Output tensors, forwarded to ``CustomOperator``.
        max_detections: Stored under the int key ``max_detections``.
        max_classes_per_detection: Stored under the int key ``max_classes_per_detection``.
        nms_score_threshold: Stored under the float key ``nms_score_threshold``.
        nms_iou_threshold: Stored under the float key ``nms_iou_threshold``.
        num_classes: Stored under the int key ``num_classes``.
        y_scale, x_scale, h_scale, w_scale: Stored under float keys of the same names.
        detections_per_class: Optional int key ``detections_per_class``. When
            ``None`` (the default) the key is omitted, so the runtime kernel's
            own default applies and the options match previous output.
        use_regular_nms: Optional bool key ``use_regular_nms``; omitted when ``None``.

    Raises:
        AssertionError: if FlexBuffer (flatbuffers>=2) is not available.
    """

    def __init__(
        self,
        inputs,
        outputs,
        max_detections: int,
        max_classes_per_detection: int,
        nms_score_threshold: float,
        nms_iou_threshold: float,
        num_classes: int,
        y_scale: float,
        x_scale: float,
        h_scale: float,
        w_scale: float,
        detections_per_class: typing.Optional[int] = None,
        use_regular_nms: typing.Optional[bool] = None,
    ) -> None:
        super().__init__(inputs, outputs)

        _require_flexbuffer("TFLITE_DETECTION_POSTPROCESS")
        self.op.custom_code = "TFLITE_DETECTION_POSTPROCESS"

        self.max_detections = max_detections
        self.max_classes_per_detection = max_classes_per_detection
        self.nms_score_threshold = nms_score_threshold
        self.nms_iou_threshold = nms_iou_threshold
        self.num_classes = num_classes
        self.y_scale = y_scale
        self.x_scale = x_scale
        self.h_scale = h_scale
        self.w_scale = w_scale
        self.detections_per_class = detections_per_class
        self.use_regular_nms = use_regular_nms

    def build(self, builder):
        """Serialize the attributes into ``custom_options``, then delegate to ``CustomOperator.build``."""
        fbb = flexbuffers.Builder()
        with fbb.Map():
            fbb.Int('max_detections', self.max_detections)
            fbb.Int('max_classes_per_detection', self.max_classes_per_detection)
            fbb.Float('nms_score_threshold', self.nms_score_threshold)
            fbb.Float('nms_iou_threshold', self.nms_iou_threshold)
            fbb.Int('num_classes', self.num_classes)
            fbb.Float('y_scale', self.y_scale)
            fbb.Float('x_scale', self.x_scale)
            fbb.Float('h_scale', self.h_scale)
            fbb.Float('w_scale', self.w_scale)
            # Optional keys: only written when explicitly set, so callers that
            # do not pass them get exactly the same options as before.
            if self.detections_per_class is not None:
                fbb.Int('detections_per_class', self.detections_per_class)
            if self.use_regular_nms is not None:
                fbb.Bool('use_regular_nms', self.use_regular_nms)
        self.custom_options = fbb.Finish()

        return super().build(builder)


class MTKTransposeConvOperator(CustomOperator):
    """The ``MTK_TRANSPOSE_CONV`` custom operator.

    Every attribute is serialized by :meth:`build` into ``custom_options`` as a
    FlexBuffer map.

    Args:
        inputs: Input tensors, forwarded to ``CustomOperator``.
        outputs: Output tensors, forwarded to ``CustomOperator``.
        activation: A ``tflite.ActivationFunctionType`` value (key ``activation``).
        depth_multiplier: Key ``depth_multiplier``.
        dilation_height_factor: Key ``dilation_height_factor``.
        dilation_width_factor: Key ``dilation_width_factor``.
        padding_type: A ``tflite.Padding`` value (key ``PaddingType``).
        stride_height: Key ``stride_height``.
        stride_width: Key ``stride_width``.

    Raises:
        AssertionError: if FlexBuffer (flatbuffers>=2) is not available. Without
            this check, ``build`` would fail later with a ``NameError``.
    """

    def __init__(
        self,
        inputs,
        outputs,
        activation: int = tflite.ActivationFunctionType.NONE,
        depth_multiplier: int = 0,
        dilation_height_factor: int = 0,
        dilation_width_factor: int = 0,
        padding_type: int = tflite.Padding.SAME,
        stride_height: int = 0,
        stride_width: int = 0,
    ) -> None:
        super().__init__(inputs, outputs)

        _require_flexbuffer("MTK_TRANSPOSE_CONV")
        self.op.custom_code = "MTK_TRANSPOSE_CONV"

        self.activation = activation
        self.depth_multiplier = depth_multiplier
        self.dilation_height_factor = dilation_height_factor
        self.dilation_width_factor = dilation_width_factor
        self.padding_type = padding_type
        self.stride_height = stride_height
        self.stride_width = stride_width

    def build(self, builder):
        """Serialize the attributes into ``custom_options``, then delegate to ``CustomOperator.build``."""
        fbb = flexbuffers.Builder()
        fbb.MapFromElements(
            {
                'activation': self.activation,
                'depth_multiplier': self.depth_multiplier,
                'dilation_height_factor': self.dilation_height_factor,
                'dilation_width_factor': self.dilation_width_factor,
                # NOTE(review): 'PaddingType' breaks the snake_case pattern of the
                # other keys; presumably this is the spelling the MTK runtime
                # reads. Confirm before renaming.
                'PaddingType': self.padding_type,
                'stride_height': self.stride_height,
                'stride_width': self.stride_width,
            }
        )
        self.custom_options = fbb.Finish()

        return super().build(builder)