def rewrite_quantize_graph_for_tensorrt()

in tinynn/graph/quantization/quantizer.py [0:0]


    def rewrite_quantize_graph_for_tensorrt(self, graph: TraceGraph) -> None:
        """Rewrites the computation graph in place for TensorRT quantization.

        Inserts a ``QuantStub`` (fake-quantize) module between every observed
        node and each producer tensor feeding it. A single stub is reused per
        source tensor (keyed by ``id(tensor)``), so a tensor consumed by
        several observed nodes is quantized exactly once. Finally marks the
        graph as quantized and recomputes the forward order.

        Args:
            graph: the traced computation graph to rewrite.
        """

        # Resolve TENSORRT_OBSERVED_NODES entries into a set of concrete node
        # kinds. Tuple entries are (namespace, attr_name) pairs resolved via
        # hasattr/getattr so kinds absent from the installed torch are skipped.
        processed_types = set()
        for types in TENSORRT_OBSERVED_NODES:
            if isinstance(types, tuple):
                if hasattr(types[0], types[1]):
                    processed_types.add(getattr(types[0], types[1]))
            else:
                processed_types.add(types)

        def _is_observed_nodes_for_tensorrt(node, custom_data):
            # True when the node's kind is one we must observe for TensorRT.
            return node.kind() in processed_types

        observed_nodes = graph.filter_forward_nodes(_is_observed_nodes_for_tensorrt)

        fake_quant_cls = torch_q.QuantStub  # loop-invariant, hoisted out of the loop
        # id(tensor) -> already-inserted fake-quant graph node, for reuse.
        node_map = dict()
        for idx, node in enumerate(observed_nodes):
            assert node.rev_index is False
            # Deduplicate producer nodes by unique_name, keeping first-seen order.
            prev_nodes = {n.unique_name: n for n in node.prev_nodes}.values()
            for inner_idx, prev_node in enumerate(prev_nodes):

                # Metadata-only producers yield nothing quantizable.
                if prev_node.kind() in ('shape', 'device', 'size', 'dtype'):
                    continue

                # Collect ids of tensors flowing from prev_node into node,
                # unwrapping list/tuple outputs; guarded against duplicates.
                prev_tensor_ptrs = []
                for pt in node.prev_tensors:
                    for nt in prev_node.next_tensors:
                        if isinstance(nt, (list, tuple)):
                            if any(pt is ntt for ntt in nt):
                                if id(pt) not in prev_tensor_ptrs:
                                    prev_tensor_ptrs.append(id(pt))
                        elif pt is nt:
                            if id(pt) not in prev_tensor_ptrs:
                                prev_tensor_ptrs.append(id(pt))
                            break

                for ptr_idx, ptr in enumerate(prev_tensor_ptrs):
                    if ptr in node_map:
                        # Reuse the stub already inserted for this tensor.
                        fake_quant = node_map[ptr]
                        graph.insert_between(prev_node, node, fake_quant, move_idx=True, tensor_ptrs={ptr})
                    else:
                        fake_quant = fake_quant_cls()

                        fake_quant_name = f'fake_quant_inner_{idx}_{inner_idx}_{ptr_idx}'

                        graph.module_unique_name_dict[id(fake_quant)] = fake_quant_name
                        graph.module_original_name_dict[id(fake_quant)] = fake_quant_name

                        # Record how to reconstruct this module in generated code.
                        module_constructor_lines[id(fake_quant)] = f'{qualified_name(fake_quant_cls)}()'

                        graph.insert_between(prev_node, node, fake_quant, move_idx=True, tensor_ptrs={ptr})
                        node_map[ptr] = graph.nodes_map[fake_quant_name]

        graph.quantized = True
        graph.recompute_forward_order()