def quantize_input_output_type_pass()

in tinynn/converter/operators/optimize.py [0:0]


    def quantize_input_output_type_pass(self):
        """Requantize graph inputs and outputs to ``self.quantize_input_output_type``.

        Only the int8 <-> uint8 pair is supported. For each selected graph
        input whose dtype differs from the requested type, the input tensor is
        rewritten in place in the new type (values shifted by +/-128 and the
        zero point adjusted by the same amount, so the represented real values
        are unchanged) and a Quantize op is inserted that converts back to the
        original quantization for the existing consumers. For each selected
        graph output, a Quantize op is appended that converts the produced
        tensor into the requested type, and ``self.graph.outputs`` is remapped
        to the requantized tensor names.

        When ``self.fuse_input_indices`` / ``self.fuse_output_indices`` are
        set, only the listed input/output positions are processed.

        Raises:
            AssertionError: if a tensor's dtype and the requested type are not
                the int8/uint8 pair, or an input node has an unexpected type.
        """
        remove_edges = []
        remove_vertices = []
        for i, name in enumerate(self.graph.inputs):
            # Honor the optional input filter: untouched indices keep their dtype.
            if self.fuse_input_indices is not None:
                if i not in self.fuse_input_indices:
                    continue

            node_name = self.graph.tensor_node_map[name]
            node = self.graph.graph.vs.find(name=node_name)
            assert node['node_type'] == ExtendedOperator.INPUT_NODE

            # Update input tensor
            input_tensor = self.graph.tensor_map[node['outputs'][0]]
            input_type = str(input_tensor.dtype)
            if input_type == self.quantize_input_output_type:
                # Already in the requested type; nothing to do.
                continue

            # Snapshot the original data/quantization before mutating the tensor:
            # the inserted Quantize op must restore exactly this representation.
            input_arr = input_tensor.tensor.copy()
            input_quantization = copy.deepcopy(input_tensor.quantization)
            if input_type == 'int8' and self.quantize_input_output_type == 'uint8':
                # int8 -> uint8: shift values and zero point by +128 (same real values).
                # Widen to int32 first so the shift cannot wrap.
                input_tensor.tensor = (input_tensor.tensor.astype('int32') + 128).astype('uint8')
                input_tensor.quantization.zero_point += 128
                input_tensor.dtype = input_tensor.tensor.dtype
            elif input_type == 'uint8' and self.quantize_input_output_type == 'int8':
                # uint8 -> int8: the symmetric shift by -128.
                input_tensor.tensor = (input_tensor.tensor.astype('int32') - 128).astype('int8')
                input_tensor.quantization.zero_point -= 128
                input_tensor.dtype = input_tensor.tensor.dtype
            else:
                raise AssertionError(
                    f'Unsupported types: input_type: {input_type}, quantize_input_output_type:'
                    f' {self.quantize_input_output_type}'
                )

            # Create new quantize op that converts the retyped input back to the
            # original quantization, so downstream ops see the same values as before.
            requantized = self.create_transform_tensor(input_arr, quantization=input_quantization)
            quantize_op = tfl.QuantizeOperator([input_tensor], [requantized])
            self.graph.add_operator(quantize_op)

            # Get the newly-generated node
            new_node_name = self.graph.tensor_node_map[requantized.name]
            new_node = self.graph.graph.vs.find(name=new_node_name)

            # Reroute the original consumers to read the requantized tensor instead.
            self.graph.replace_next_tensors(node, new_node, requantized.name, [new_node_name])

            # Collect the now-dangling edges from the input node to old consumers;
            # deletion is deferred so edge indices stay valid during iteration.
            for edge in node.out_edges():
                target_vertex = edge.target_vertex
                if target_vertex['name'] != new_node_name:
                    remove_edges.append(edge.index)

        output_mapping = {}
        for i, name in enumerate(self.graph.outputs):
            # Honor the optional output filter.
            if self.fuse_output_indices is not None:
                if i not in self.fuse_output_indices:
                    continue

            output_tensor = self.graph.tensor_map[name]
            output_type = str(output_tensor.dtype)
            if output_type == self.quantize_input_output_type:
                continue

            node_name = self.graph.tensor_node_map[name]
            node = self.graph.graph.vs.find(name=node_name)

            # The old OUTPUT_NODE sinks become stale once the output is remapped;
            # collect them for deferred deletion (incident edges go with them).
            for edge in node.out_edges():
                next_node = edge.target_vertex

                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    remove_vertices.append(next_node.index)

            # Update output tensor: build the requested-type representation on a
            # copy (the producer's tensor itself is left untouched).
            output_arr = output_tensor.tensor.copy()
            output_quantization = copy.deepcopy(output_tensor.quantization)
            if output_type == 'int8' and self.quantize_input_output_type == 'uint8':
                output_arr = (output_arr.astype('int32') + 128).astype('uint8')
                output_quantization.zero_point += 128
            elif output_type == 'uint8' and self.quantize_input_output_type == 'int8':
                output_arr = (output_arr.astype('int32') - 128).astype('int8')
                output_quantization.zero_point -= 128
            else:
                raise AssertionError(
                    f'Unsupported types: output_type: {output_type}, quantize_input_output_type:'
                    f' {self.quantize_input_output_type}'
                )

            # Append a Quantize op producing the output in the requested type.
            requantized = self.create_transform_tensor(output_arr, quantization=output_quantization)
            quantize_op = tfl.QuantizeOperator([output_tensor], [requantized])
            self.graph.add_operator(quantize_op)

            output_mapping[name] = requantized.name

        if len(output_mapping) > 0:
            # Rewrite the output list in order, swapping in the requantized names,
            # and register fresh OUTPUT_NODE sinks for them.
            new_outputs = []
            output_names = []
            for name in self.graph.outputs:
                if name in output_mapping:
                    new_outputs.append(output_mapping[name])
                    output_names.append(output_mapping[name])
                else:
                    new_outputs.append(name)

            self.graph.outputs.clear()
            self.graph.outputs.extend(new_outputs)
            self.graph.add_outputs(output_names)

        # Remove the collected edges & vertices in one batch (igraph invalidates
        # indices on deletion, so this must happen after all loops finish).
        self.graph.graph.delete_edges(remove_edges)
        self.graph.graph.delete_vertices(remove_vertices)