def rewrite_quantize_graph()

in tinynn/graph/quantization/quantizer.py [0:0]


    def rewrite_quantize_graph(self, graph: TraceGraph) -> None:
        """Strip quantization scaffolding (stubs and FloatFunctional wrappers) from a traced graph.

        The graph is modified in place, in five ordered phases:

        1. every ``QuantStub`` / ``DeQuantStub`` forward node is removed;
        2. ``FloatFunctional.add_relu`` calls are turned into ``add`` calls followed by a new
           ``torch.relu`` node;
        3. ``FloatFunctional.{add_scalar, mul_scalar}`` calls are renamed to ``add`` / ``mul``;
        4. ``FloatFunctional.{add, mul, cat}`` calls become ``torch.{add, mul, cat}`` calls by
           detaching their leading FloatFunctional input;
        5. the FloatFunctional entries are dropped from ``graph.other_init_nodes`` and
           ``graph.nodes_map``.

        Phases 2 and 3 only rename the call and keep the FloatFunctional as the first input, so
        phase 4 must run after them to finish the conversion.

        Returns immediately if ``graph.quantized`` is already set. If the graph holds neither stub
        forward nodes nor FloatFunctional init nodes, a warning is logged and nothing is changed.
        Otherwise ``graph.quantized`` is set to ``False`` and the forward order is recomputed.

        Args:
            graph: the traced graph to rewrite (mutated in place)
        """
        if graph.quantized:
            return

        stub_types = (torch_q.QuantStub, torch_q.DeQuantStub)

        # Both scans are always evaluated, mirroring two independent early-exit loops.
        has_stubs = any(node.type() in stub_types for node in graph.forward_nodes)
        has_float_functionals = any(node.type() is nnq.FloatFunctional for node in graph.other_init_nodes)

        if not (has_stubs or has_float_functionals):
            log.warning(
                'Graph is not quantized. No need to rewrite. Please pass in `config={"rewrite_graph": False}` to'
                ' suppress this warning'
            )
            return

        # Phase 1: drop QuantStub / DeQuantStub nodes.
        # Names are snapshotted up front because removing nodes presumably alters graph.forward_nodes.
        for unique_name in [node.unique_name for node in graph.forward_nodes]:
            node = graph.nodes_map[unique_name]
            if node.type() in stub_types:
                graph.remove_node(node)

        def _float_functional_call(*kinds):
            """Build a predicate matching TraceFunction calls of ``kinds`` made through a FloatFunctional."""

            def _predicate(node: TraceNode, custom_data):
                func = node.module
                if type(func) is not TraceFunction:
                    return None
                # The FloatFunctional module shows up as the first input of its method calls.
                return (
                    func.kind in kinds
                    and len(node.prev_nodes) > 1
                    and node.prev_nodes[0].type() is nnq.FloatFunctional
                )

            return _predicate

        # Phase 2: FloatFunctional.add_relu -> FloatFunctional.add followed by torch.relu.
        for node in graph.filter_forward_nodes(_float_functional_call('add_relu')):
            func = node.module
            func.kind = func.func_type = 'add'
            func.full_name = '.'.join(func.full_name.split('.')[:-1] + [func.func_type])

            add_output = node.next_tensors[0]
            with override_current_trace_graph(graph):
                relu_func = TraceFunction('torch.relu').parse_args(add_output)
            graph.insert_after(node, relu_func, [torch.relu(add_output)])

        # Phase 3: FloatFunctional.{add_scalar, mul_scalar} -> FloatFunctional.{add, mul}.
        for node in graph.filter_forward_nodes(_float_functional_call('add_scalar', 'mul_scalar')):
            func = node.module
            func.kind = func.kind.partition('_')[0]
            func.func_type = func.kind
            func.full_name = '.'.join(func.full_name.split('.')[:-1] + [func.func_type])

        # Phase 4: FloatFunctional.{add, mul, cat} -> torch.{add, mul, cat}.
        for node in graph.filter_forward_nodes(_float_functional_call('add', 'mul', 'cat')):
            func = node.module
            func.full_name = 'torch.' + func.func_type
            func.is_class = False

            # Detach the FloatFunctional input from both ends of the edge.
            owner = node.prev_nodes[0]
            owner.next_nodes.remove(node)
            del node.prev_nodes[0]
            del node.prev_indices[0]

            # Drop the matching leading argument and refresh the call's textual form.
            del func.tensor_names[0]
            del func.args_parsed[0]
            func.args_to_string(func.args_parsed)

        # Phase 5: drop the FloatFunctional module nodes themselves.
        for unique_name in [node.unique_name for node in graph.other_init_nodes]:
            node = graph.nodes_map[unique_name]
            if node.type() is nnq.FloatFunctional:
                graph.other_init_nodes.remove(node)
                del graph.nodes_map[node.unique_name]

        graph.quantized = False
        graph.recompute_forward_order()