def fuse_requantize()

in tinynn/converter/operators/optimize.py [0:0]

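Removes redundant quantize ("requantize") nodes from the converted graph. Edges whose target is judged fusable by is_requantize_fusable_edge are collected; for each one, the consumers of the quantize node are rewired to the producing node and the quantize vertex is deleted. When the producer feeds more than one consumer, the fold is only applied if every quantize reachable through single-consumer reshape/transpose chains carries identical quantization parameters.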

    def fuse_requantize(self):
        # Find edges whose target is a fusable (redundant) quantize op
        edges = self.graph.graph.es.select(
            functools.partial(is_requantize_fusable_edge, graph_converter=self.graph.graph)
        )
        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)

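        # Indices of the quantize nodes to delete once their consumers have been rewired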
        remove_ids = []
        for pre_activ, activ, tensor in filtered_pairs:
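            # If the previous node feeds several consumers, the quantize node may only be removed
            # when every quantize reachable from those consumers uses identical quantization
            # parameters; otherwise the rewrite would change the numerics of some branch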
            if pre_activ.outdegree() > 1:
                skip = False
                pre_quantize = None
                for out_edge in pre_activ.out_edges():
                    next_node = self.graph.graph.vs[out_edge.target]
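                    # Walk through single-consumer reshape/transpose chains until a quantize,
                    # a dequantize or an unsupported op is reached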
                    while True:
                        if next_node['node_type'] == ExtendedOperator.QUANTIZE:
                            if pre_quantize is None:
                                pre_quantize = next_node['op'].outputs[0].quantization
                            else:
                                cur_quantize = next_node['op'].outputs[0].quantization
                                if (
                                    pre_quantize.scale != cur_quantize.scale
                                    or pre_quantize.zero_point != cur_quantize.zero_point
                                    or pre_quantize.dim != cur_quantize.dim
                                ):
                                    skip = True
                            break
                        elif next_node['node_type'] == ExtendedOperator.DEQUANTIZE:
                            break
                        elif next_node['node_type'] in (ExtendedOperator.RESHAPE, ExtendedOperator.TRANSPOSE):
                            if next_node.outdegree() > 1:
                                skip = True
                                break
                            else:
                                next_node = self.graph.graph.vs[next_node.out_edges()[0].target]
                        else:
                            skip = True
                            break

                    if skip:
                        break

                if skip:
                    continue

                # Find out which output of the previous node feeds this quantize node
                output_name = activ['op'].inputs[0].name
                output_idx = pre_activ['outputs'].index(output_name)
                new_output = pre_activ['outputs'][output_idx]
                assert new_output in self.graph.tensor_map

                # Connect every consumer of the quantize node with the previous node,
                # replacing the downstream tensors where needed
                self.graph.replace_next_tensors(activ, pre_activ, new_output)

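                # Propagate the quantization parameters of the removed quantize node onto the
                # output tensor that is kept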
                new_tensor = pre_activ['op'].outputs[0]
                old_tensor = activ['op'].outputs[0]
                new_tensor.quantization = old_tensor.quantization
            else:
                # Find out the output of the quantize node
                new_output = activ['outputs'][0]
                assert new_output in self.graph.tensor_map

                # Connect every node that follows the quantize node with the previous node
                self.graph.connect_next_tensors(activ, pre_activ, new_output)

                # Update graph, prepare to drop the output tensor of the previous node and use the
                # output tensor of the quantize node instead
                pre_activ['outputs'][0] = new_output
                pre_activ['op'].outputs[0] = self.graph.tensor_map[new_output]
                self.graph.tensor_node_map[new_output] = pre_activ['name']
                tensor['name'] = activ['outputs'][0]
                tensor['label'] = activ['outputs'][0]

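            # The quantize node is now bypassed and can be removed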
            remove_ids.append(activ.index)

        # Delete the redundant quantize nodes
        self.graph.graph.delete_vertices(remove_ids)
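
The multi-consumer guard above boils down to a quantization-parameter equality check. A minimal standalone sketch of that check, assuming quantization objects that expose scale, zero_point and dim as used in the loop above (the helper name is illustrative and not part of tinynn):

    def same_quant_params(a, b):
        # Two quantize steps are interchangeable only when they map values to the
        # same quantized representation: identical scale, zero point and axis.
        return a.scale == b.scale and a.zero_point == b.zero_point and a.dim == b.dim

If this would return False for any pair of downstream quantize nodes, removing the requantize would change the numerics of at least one branch, so the fusion is skipped.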