in tinynn/graph/quantization/quantizer.py [0:0]
def rewrite_quantize_graph(self, graph: TraceGraph) -> None:
    """Rewrites the computation graph for quantization"""
    if graph.quantized:
        return
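
    # Detect quantization-related constructs (QuantStub/DeQuantStub modules or
    # nnq.FloatFunctional wrappers); if none are present, there is nothing to rewrite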
    graph_quantized = False
    for n in graph.forward_nodes:
        if n.type() in (torch_q.QuantStub, torch_q.DeQuantStub):
            graph_quantized = True
            break

    for n in graph.other_init_nodes:
        if n.type() is nnq.FloatFunctional:
            graph_quantized = True
            break

    if not graph_quantized:
        log.warning(
            'Graph is not quantized. No need to rewrite. Please pass in `config={"rewrite_graph": False}` to'
            ' suppress this warning'
        )
        return

    # Remove QuantStubs and DeQuantStubs
    names = [n.unique_name for n in graph.forward_nodes]
    for name in names:
        n = graph.nodes_map[name]
        if n.type() in (torch_q.QuantStub, torch_q.DeQuantStub):
            graph.remove_node(n)
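
    # Matches TraceFunction nodes that call FloatFunctional.add_relu so that they
    # can be split into separate add and relu operations below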
    def _is_add_relu_node(node: TraceNode, custom_data):
        cur_module = node.module
        cur_class = type(cur_module)
        if cur_class == TraceFunction:
            return (
                cur_module.kind == 'add_relu'
                and len(node.prev_nodes) > 1
                and node.prev_nodes[0].type() is nnq.FloatFunctional
            )

    # Split FloatFunctional.add_relu into FloatFunctional.add and torch.relu
    add_relu_nodes = graph.filter_forward_nodes(_is_add_relu_node)
    for n in add_relu_nodes:
        n.module.kind = 'add'
        n.module.func_type = 'add'
        parts = n.module.full_name.split('.')[:-1] + [n.module.func_type]
        n.module.full_name = '.'.join(parts)
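
        # Append a standalone torch.relu node right after the rewritten add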
        with override_current_trace_graph(graph):
            next_func = TraceFunction('torch.relu').parse_args(n.next_tensors[0])
        next_out = torch.relu(n.next_tensors[0])
        graph.insert_after(n, next_func, [next_out])

    # Replace FloatFunctional.{add_scalar, mul_scalar} with torch.{add, mul}
    def _is_add_mul_scalar_node(node: TraceNode, custom_data):
        cur_module = node.module
        cur_class = type(cur_module)
        if cur_class == TraceFunction:
            return (
                cur_module.kind in ('add_scalar', 'mul_scalar')
                and len(node.prev_nodes) > 1
                and node.prev_nodes[0].type() is nnq.FloatFunctional
            )

    add_mul_scalar_nodes = graph.filter_forward_nodes(_is_add_mul_scalar_node)
    for n in add_mul_scalar_nodes:
        n.module.kind = n.module.kind.split('_')[0]
        n.module.func_type = n.module.kind
        parts = n.module.full_name.split('.')[:-1] + [n.module.func_type]
        n.module.full_name = '.'.join(parts)
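
    # The scalar variants rewritten above now have kind 'add'/'mul', so they are also
    # picked up by the pass below, which detaches them from their FloatFunctional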
    # Replace FloatFunctional.{add, mul, cat} with torch.{add, mul, cat}
    def _is_add_mul_cat_node(node: TraceNode, custom_data):
        cur_module = node.module
        cur_class = type(cur_module)
        if cur_class == TraceFunction:
            return (
                cur_module.kind in ('add', 'mul', 'cat')
                and len(node.prev_nodes) > 1
                and node.prev_nodes[0].type() is nnq.FloatFunctional
            )

    add_mul_cat_nodes = graph.filter_forward_nodes(_is_add_mul_cat_node)
    for n in add_mul_cat_nodes:
        parts = ['torch'] + [n.module.func_type]
        n.module.full_name = '.'.join(parts)
        n.module.is_class = False
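        # Detach the node from the producing FloatFunctional and drop its reference
        # from the call arguments so that only tensor operands remain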
        n.prev_nodes[0].next_nodes.remove(n)
        n.prev_nodes.pop(0)
        n.prev_indices.pop(0)
        n.module.tensor_names.pop(0)
        n.module.args_parsed.pop(0)
        n.module.args_to_string(n.module.args_parsed)

    # Remove FloatFunctional nodes
    names = [n.unique_name for n in graph.other_init_nodes]
    for name in names:
        n = graph.nodes_map[name]
        if n.type() is nnq.FloatFunctional:
            graph.other_init_nodes.remove(n)
            del graph.nodes_map[n.unique_name]
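
    # Clear the quantized flag and rebuild the forward ordering now that
    # nodes have been removed and inserted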
    graph.quantized = False
    graph.recompute_forward_order()