def elementwise_op_quantize_passthrough_pass()

in tinynn/converter/operators/optimize.py

This pass pushes quantization through supported elementwise operators: an op that is fed by DEQUANTIZE nodes and followed by QUANTIZE nodes is rewired to consume and produce quantized tensors directly. The rewrite is only applied when it does not increase the number of [de]quantize operators in the graph.


    def elementwise_op_quantize_passthrough_pass(self):
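        # Select edges that run from a DEQUANTIZE node into a supported
        # elementwise op (optionally including hybrid int16 LSTM nodes)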
        edges = self.graph.graph.es.select(
            functools.partial(
                is_quantize_elementwise_op_edge, graph_converter=self.graph.graph, with_lstm=self.hybrid_int16_lstm
            )
        )
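        # Keep the elementwise endpoint of each edge and deduplicate, since a
        # node may be reached through several DEQUANTIZE inputs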
        pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
        filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.DEQUANTIZE else k[1] for k in pairs)
        unique_nodes = list(set(filtered_nodes))

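        # Rewrites are collected first and applied after the scan, so that
        # edge and vertex indices stay valid while iterating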
        actions = []
        remove_edges = []
        remove_vertices = []
        for node in unique_nodes:
            op = node['op']
            input_indices = op_input_indices(op)

            prev_nodes = []
            q_tensors = dict()
            prev_output_indices = []
            skip_names = []
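            # Walk the op's inputs, remembering the quantized tensor behind
            # every preceding DEQUANTIZE node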
            for i in input_indices:
                prev_node_name = op.inputs[i].name
                prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
                prev_nodes.append(prev_node)
                prev_output_indices.append(prev_node['outputs'].index(prev_node_name))

                if prev_node['node_type'] == ExtendedOperator.DEQUANTIZE:
                    q_tensors[prev_node_name] = prev_node['op'].inputs[0]

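                # Constant operands of MINIMUM/MAXIMUM are quantized on the
                # fly, reusing the quantization parameters of the first input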
                if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
                    if (
                        node['node_type'] in (ExtendedOperator.MINIMUM, ExtendedOperator.MAXIMUM)
                        and i != 0
                        and prev_node_name not in self.graph.q_mapping
                    ):
                        f_tensor = self.graph.tensor_map[prev_node_name]
                        r_tensor = q_tensors[op.inputs[0].name]
                        q_arr = np.rint(
                            f_tensor.tensor / r_tensor.quantization.scale + r_tensor.quantization.zero_point
                        )
                        i_type = np.iinfo(r_tensor.tensor.dtype)

                        if np.any(q_arr > i_type.max):
                            warnings.warn('Overflow while quantizing the tensor')
                            q_arr = np.minimum(q_arr, i_type.max)

                        if np.any(q_arr < i_type.min):
                            warnings.warn('Underflow while quantizing the tensor')
                            q_arr = np.maximum(q_arr, i_type.min)

                        q_arr = q_arr.astype(r_tensor.dtype)
                        q_tensor = self.create_attr_tensor(q_arr, quantization=r_tensor.quantization)
                        self.graph.q_mapping[prev_node_name] = q_tensor

                    if prev_node_name in self.graph.q_mapping:
                        skip_names.append(prev_node_name)

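            # Partition the consumers of this op: graph output nodes are
            # detached, other consumers are reconnected after the rewrite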
            next_nodes = []
            next_edges = []
            out_nodes = []
            for edge in node.out_edges():
                if edge.index in remove_edges:
                    continue
                next_node = self.graph.graph.vs[edge.target]

                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    out_nodes.append(next_node)
                else:
                    next_nodes.append(next_node)
                    next_edges.append(edge)

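                # A following QUANTIZE only provides usable output parameters
                # if its dtype, scale and zero-point satisfy the restrictions
                # of this particular op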
                if next_node['node_type'] == ExtendedOperator.QUANTIZE:
                    skip = False
                    name = next_node['op'].inputs[0].name
                    q_tensor = next_node['op'].outputs[0]
                    assert q_tensor.quantization is not None
                    if node['node_type'] in (
                        ExtendedOperator.BATCH_MATMUL,
                        ExtendedOperator.ABS,
                        ExtendedOperator.RSQRT,
                    ):
                        if q_tensor.dtype not in (np.dtype('int8'), np.dtype('int16')):
                            skip = True
                    elif node['node_type'] == ExtendedOperator.DIV:
                        if q_tensor.dtype != np.dtype('uint8'):
                            skip = True
                    elif node['node_type'] == ExtendedOperator.SOFTMAX:
                        if q_tensor.dtype == np.dtype('int8'):
                            if (
                                abs(q_tensor.quantization.scale - 1.0 / 256) > 0.001 * 1.0 / 256
                                or q_tensor.quantization.zero_point != -128
                            ):
                                skip = True
                        elif q_tensor.dtype == np.dtype('int16'):
                            if (
                                abs(q_tensor.quantization.scale - 1.0 / 32768) > 0.001 * 1.0 / 32768
                                or q_tensor.quantization.zero_point != 0
                            ):
                                skip = True
                        elif q_tensor.dtype == np.dtype('uint8'):
                            if (
                                abs(q_tensor.quantization.scale - 1.0 / 256) > 0.001 * 1.0 / 256
                                or q_tensor.quantization.zero_point != 0
                            ):
                                log.warning(
                                    'On some chips, only softmax with scale=1.0/256 and zero_point=0 is supported'
                                )
                        else:
                            skip = True
                    elif node['node_type'] == ExtendedOperator.LOG_SOFTMAX:
                        if q_tensor.dtype == np.dtype('int8'):
                            if q_tensor.quantization.scale != 16.0 / 256 or q_tensor.quantization.zero_point != 127:
                                skip = True
                        elif q_tensor.dtype == np.dtype('uint8'):
                            if q_tensor.quantization.scale != 16.0 / 256 or q_tensor.quantization.zero_point != 255:
                                skip = True
                        else:
                            skip = True

                    if not skip:
                        q_tensors[name] = q_tensor

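            # Despite the names, these count [de]quantize nodes: the ones that
            # currently surround the op vs. the ones the rewrite would need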
            cur_transpose_size = len(q_tensors)
            new_transpose_size = len(prev_nodes) + len(next_nodes) - len(skip_names)

            # Skip if the op has no remaining (non-output) consumers or if the
            # rewrite would increase the number of [de]quantize nodes
            if len(next_nodes) == 0 or new_transpose_size > cur_transpose_size:
                continue

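            # Drop the edges to the rewired consumers and remove graph output
            # nodes along with their tensor-map entries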
            remove_edges.extend([x.index for x in next_edges])
            remove_vertices.extend([x.index for x in out_nodes])

            for n in out_nodes:
                del self.graph.tensor_map[n['outputs'][0]]
                del self.graph.tensor_node_map[n['outputs'][0]]

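            # Rewire every input: constants use their precomputed quantized
            # twin, other inputs get a new QUANTIZE op; when the same tensor
            # feeds several inputs, the inserted node is reused and the skip
            # counter tracks repeated occurrences for replace_operator_input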
            tensor_node_dict = {}
            for prev_node, prev_idx, next_idx in zip(prev_nodes, input_indices, prev_output_indices):
                if prev_node['op'] is None:
                    prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
                else:
                    prev_out = prev_node['op'].outputs[next_idx]
                if prev_out.name in tensor_node_dict:
                    prev_new_out, skip = tensor_node_dict[prev_out.name]
                    actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True, skip)))
                    skip += 1
                    tensor_node_dict[prev_out.name] = (prev_new_out, skip)
                else:
                    if prev_out.name in skip_names:
                        prev_new_out = self.graph.q_mapping[prev_out.name]
                        self.graph.add_nodes([prev_new_out])
                        tensor_node_dict[prev_out.name] = (prev_new_out, 1)
                        actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))
                    else:
                        prev_new_out = self.create_transform_tensor(
                            q_tensors[prev_out.name].tensor, quantization=q_tensors[prev_out.name].quantization
                        )
                        tensor_node_dict[prev_out.name] = (prev_new_out, 1)
                        self.graph.add_operator(tfl.QuantizeOperator([prev_out], [prev_new_out]))
                        actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))

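            # Rewire every output: the op now writes a quantized tensor, and a
            # DEQUANTIZE is appended so the original float tensor keeps a
            # producer for any untouched consumers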
            tensor_node_dict = {}
            for i, op_out in enumerate(op.outputs):
                new_out = self.create_transform_tensor(
                    q_tensors[op_out.name].tensor, quantization=q_tensors[op_out.name].quantization
                )

                # Update relations
                if op_out.name in self.graph.tensor_node_map:
                    del self.graph.tensor_node_map[op_out.name]
                self.graph.tensor_node_map[new_out.name] = node['name']
                self.graph.tensor_map[new_out.name] = new_out
                node['outputs'][i] = new_out.name
                op.outputs[i] = new_out

                self.graph.add_operator(tfl.DequantizeOperator([new_out], [op_out]))

                tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])

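            # Reconnect each remaining consumer to the node that now produces
            # its input tensor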
            for edge in next_edges:
                source = tensor_node_dict[edge['name']]
                self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])

        # Apply the deferred input replacements; replace_operator_input may
        # return indices of edges that must also be removed
        ids = []
        for func, args in actions:
            res = func(*args)
            if res is not None:
                ids.extend(res)

        remove_edges = list(set(remove_edges + ids))

        self.graph.graph.delete_edges(remove_edges)
        self.graph.graph.delete_vertices(remove_vertices)
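
For reference, the clamp-and-round arithmetic used above for constant operands of MINIMUM/MAXIMUM can be exercised in isolation. The following is a minimal sketch, not part of tinynn; quantize_with_params is a hypothetical helper that mirrors the pass's behavior of reusing the quantization parameters of the op's first input:

    import warnings

    import numpy as np


    def quantize_with_params(f_arr, scale, zero_point, dtype=np.int8):
        # Hypothetical helper mirroring the constant-quantization step above
        q_arr = np.rint(f_arr / scale + zero_point)
        info = np.iinfo(dtype)

        if np.any(q_arr > info.max):
            warnings.warn('Overflow while quantizing the tensor')
            q_arr = np.minimum(q_arr, info.max)

        if np.any(q_arr < info.min):
            warnings.warn('Underflow while quantizing the tensor')
            q_arr = np.maximum(q_arr, info.min)

        return q_arr.astype(dtype)


    # 20.0 / 0.1 = 200 overflows int8 and is clamped to 127
    print(quantize_with_params(np.array([0.5, -1.0, 20.0]), 0.1, 0))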