in tinynn/graph/quantization/quantizer.py
def rewrite_quantize_graph_for_tensorrt(self, graph: TraceGraph) -> None:
    """Rewrites the computation graph for TensorRT quantization."""
    # Resolve the set of observed node kinds. Entries in
    # TENSORRT_OBSERVED_NODES may be (module, attribute-name) tuples so that
    # optional attributes can be looked up lazily; everything else is taken
    # as a node kind directly.
    processed_types = set()
    for types in TENSORRT_OBSERVED_NODES:
        if isinstance(types, tuple):
            if hasattr(types[0], types[1]):
                processed_types.add(getattr(types[0], types[1]))
        else:
            processed_types.add(types)

    def _is_observed_nodes_for_tensorrt(node, custom_data):
        return node.kind() in processed_types

    observed_nodes = graph.filter_forward_nodes(_is_observed_nodes_for_tensorrt)
    # Maps id(tensor) -> inserted QuantStub graph node, so a tensor feeding
    # several observed nodes is quantized once and the stub is reused.
    node_map = dict()
    for idx, node in enumerate(observed_nodes):
        fake_quant_cls = torch_q.QuantStub
        assert node.rev_index is False

        # Deduplicate predecessor nodes by their unique names.
        prev_nodes = {n.unique_name: n for n in node.prev_nodes}.values()
        for inner_idx, prev_node in enumerate(prev_nodes):
            # Metadata-only producers carry no tensor data to quantize.
            if prev_node.kind() in ('shape', 'device', 'size', 'dtype'):
                continue
            # Collect the ids of the tensors that actually flow from
            # `prev_node` into `node`, unpacking list/tuple outputs.
            prev_tensor_ptrs = []
            for pt in node.prev_tensors:
                for nt in prev_node.next_tensors:
                    if isinstance(nt, (list, tuple)):
                        for ntt in nt:
                            if id(pt) == id(ntt):
                                if id(pt) not in prev_tensor_ptrs:
                                    prev_tensor_ptrs.append(id(pt))
                                break
                    elif id(pt) == id(nt):
                        if id(pt) not in prev_tensor_ptrs:
                            prev_tensor_ptrs.append(id(pt))
                        break
            for ptr_idx, ptr in enumerate(prev_tensor_ptrs):
                if ptr in node_map:
                    # A stub already covers this tensor; rewire through it.
                    fake_quant = node_map[ptr]
                    graph.insert_between(prev_node, node, fake_quant, move_idx=True, tensor_ptrs={ptr})
                else:
                    # Create a fresh QuantStub, register its names and the
                    # constructor line used during code generation, then
                    # remember it for later reuse.
                    fake_quant = fake_quant_cls()
                    fake_quant_name = f'fake_quant_inner_{idx}_{inner_idx}_{ptr_idx}'
                    graph.module_unique_name_dict[id(fake_quant)] = fake_quant_name
                    graph.module_original_name_dict[id(fake_quant)] = fake_quant_name
                    module_constructor_lines[id(fake_quant)] = f'{qualified_name(fake_quant_cls)}()'
                    graph.insert_between(prev_node, node, fake_quant, move_idx=True, tensor_ptrs={ptr})
                    node_map[ptr] = graph.nodes_map[fake_quant_name]
    graph.quantized = True
    graph.recompute_forward_order()
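
The core trick above is keying inserted stubs by tensor identity (`id()`), so an input shared by several observed nodes is quantized exactly once. Below is a minimal, self-contained sketch of that dedup pattern outside the graph machinery; the names `consumers` and `stub_for_tensor` are illustrative only and not part of the TinyNeuralNetwork API.

import torch
import torch.quantization as torch_q

# One tensor feeding two downstream consumers.
x = torch.randn(1, 3)
consumers = [('conv', x), ('relu', x)]

stub_for_tensor = {}  # id(tensor) -> QuantStub, mirroring `node_map` above
for name, t in consumers:
    stub = stub_for_tensor.get(id(t))
    if stub is None:
        stub = torch_q.QuantStub()
        stub_for_tensor[id(t)] = stub
    print(f'{name} reads its input through stub {id(stub):#x}')

# The shared tensor got exactly one stub.
assert len(stub_for_tensor) == 1

Keying on `id()` rather than structural equality is deliberate: two distinct tensors can be value-equal yet must still receive separate stubs, while the same tensor object reached through different graph edges must not.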