def rewrite_quantize_graph()

in tinynn/graph/quantization/quantizer.py [0:0]

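Context: this rewrite pass runs during graph preparation inside the quantizer. Below is a
minimal usage sketch, assuming the public QATQuantizer entry point (constructor arguments
are illustrative and may differ across versions); `rewrite_quantize_graph` is invoked
internally by `quantize()` unless `config={"rewrite_graph": False}` is passed:

    import torch
    import torch.nn as nn

    from tinynn.graph.quantization.quantizer import QATQuantizer

    # A toy float model whose traced graph will be rewritten for quantization
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    dummy_input = torch.randn(1, 3, 32, 32)

    quantizer = QATQuantizer(model, dummy_input, work_dir='out')
    qat_model = quantizer.quantize()  # traces the model, rewrites the graph, prepares QAT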

    def rewrite_quantize_graph(self, graph: TraceGraph) -> None:
        """Rewrites the computation graph for quantization"""
        if graph.quantized:
            return

        graph_quantized = False
        for n in graph.forward_nodes:
            if n.type() in (torch_q.QuantStub, torch_q.DeQuantStub):
                graph_quantized = True
                break

        for n in graph.other_init_nodes:
            if n.type() is nnq.FloatFunctional:
                graph_quantized = True
                break

        if graph_quantized:
            log.warning(
                'Graph is quantized. No need to rewrite. Please pass in `config={"rewrite_graph": False}` to suppress'
                ' this warning'
            )
            return

        if self.backend == 'tensorrt':
            return self.rewrite_quantize_graph_for_tensorrt(graph)

        if self.layerwise_default is False:
            for n in graph.forward_nodes:
                self.layerwise_config.setdefault(n.unique_name, False)

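        # Tensor-creation calls (presumably ops like torch.zeros / torch.ones) are
        # collected so they can be treated as extra constant nodes and receive
        # QuantStubs in the first insertion pass below.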
        creation_func_names = load_creation_func_names()

        def _is_extra_constant_nodes(node, custom_data):
            return node.full_name() in creation_func_names

        extra_constant_nodes = graph.filter_forward_nodes(_is_extra_constant_nodes)

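        # Nodes that consume an integer tensor and produce a floating-point one also
        # need a QuantStub after them, since their outputs feed the quantized region.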
        def _is_int_to_float_nodes(node, custom_data):
            if node.full_name() in creation_func_names:
                return False

            if len(node.prev_nodes) == 1 and len(node.next_nodes) == 1:
                if len(node.prev_tensors) == 1 and len(node.next_tensors) == 1:
                    if node.prev_nodes[0].kind() == 'shape' and node.prev_nodes[0].module.is_property:
                        return False
                    if (
                        isinstance(node.prev_tensors[0], torch.Tensor)
                        and isinstance(node.next_tensors[0], torch.Tensor)
                        and node.prev_tensors[0].dtype in (torch.int32, torch.int64)
                        and torch.is_floating_point(node.next_tensors[0])
                    ):
                        return True
            else:
                return False

        int_to_float_nodes = graph.filter_forward_nodes(_is_int_to_float_nodes)

        # When converting a float tensor to int16/int32/int64, we need to add a 'fake_dequant' node before the conversion node.
        def _is_float_to_non_float_nodes(node, custom_data):
            if isinstance(node.module, TraceFunction) and node.module.kind in ('shape', 'size'):
                return False
            return (
                len(node.next_tensors) == 1
                and len(node.prev_tensors) > 0
                and isinstance(node.prev_tensors[0], torch.Tensor)
                and isinstance(node.next_tensors[0], torch.Tensor)
                and torch.is_floating_point(node.prev_tensors[0])
                and not torch.is_floating_point(node.next_tensors[0])
            )

        float_to_non_float_nodes = graph.filter_forward_nodes(_is_float_to_non_float_nodes)

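        # Accesses to the `weight`/`bias` attributes of quantizable modules
        # (Conv/Linear/ConvTranspose) are rewritten to go through
        # `tinynn.graph.quantization.utils.get_parameter`, and the resulting parameter
        # nodes are treated like constants when inserting QuantStubs below.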
        def _is_params_in_module(node, custom_data):
            if len(node.prev_nodes) == 1 and len(node.next_nodes) == 1:
                if len(node.prev_tensors) == 1 and len(node.next_tensors) == 1:
                    if isinstance(node.prev_tensors[0], nn.Module) and not isinstance(
                        node.prev_tensors[0], nnq.FloatFunctional
                    ):
                        return True
            return False

        param_nodes = graph.filter_forward_nodes(_is_params_in_module)
        for node in param_nodes:
            prev_node = node.prev_nodes[0]
            is_known_mod = prev_node.kind().__name__ in (
                'Conv1d',
                'Conv2d',
                'Linear',
                'ConvTranspose1d',
                'ConvTranspose2d',
            )
            prop = node.module.full_name

            if is_known_mod and prop in ('weight', 'bias') and self.layerwise_config.get(prev_node.unique_name, True):
                node.module.is_class = False
                node.module.is_property = False

                node.module.full_name = 'tinynn.graph.quantization.utils.get_parameter'
                node.module.args_template = f'{{}}, "{prop}"'
                node.module.args_template_no_self = f'"{prop}"'
                node.module.args_offset = 0
                node.module.update_args_string()

        # First, we insert the QuantStub nodes for every input/constant node (this also covers creation ops, int-to-float nodes and rewritten parameter accesses)
        for idx, node in reversed(
            list(
                enumerate(
                    graph.input_nodes + graph.constant_nodes + extra_constant_nodes + int_to_float_nodes + param_nodes
                )
            )
        ):
            if node.next_tensors[0].dtype in (torch.int32, torch.int64):
                continue

            fake_quant = torch_q.QuantStub()

            graph.module_unique_name_dict[id(fake_quant)] = f'fake_quant_{idx}'
            graph.module_original_name_dict[id(fake_quant)] = f'fake_quant_{idx}'

            fake_quant_cls = type(fake_quant)
            module_constructor_lines[id(fake_quant)] = f'{qualified_name(fake_quant_cls)}()'

            graph.insert_after(node, fake_quant)

        # Second, we insert the DeQuantStub nodes for every output node (and before every float-to-non-float conversion)
        for idx, node in enumerate(graph.output_nodes + float_to_non_float_nodes):
            fake_dequant_cls = torch_q.DeQuantStub
            if node.rev_index:
                modules = []
                for rev_idx in range(len(node.prev_nodes)):
                    if node.prev_tensors[rev_idx].dtype in (torch.int32, torch.int64):
                        fake_dequant = nn.Identity()
                    else:
                        fake_dequant = fake_dequant_cls()

                    graph.module_unique_name_dict[id(fake_dequant)] = f'fake_dequant_{idx}_{rev_idx}'
                    graph.module_original_name_dict[id(fake_dequant)] = f'fake_dequant_{idx}_{rev_idx}'

                    module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()'
                    modules.append(fake_dequant)

                graph.insert_before(node, modules, move_idx=True)
            else:
                if node.prev_tensors[0].dtype in (torch.int32, torch.int64):
                    continue

                fake_dequant = fake_dequant_cls()

                graph.module_unique_name_dict[id(fake_dequant)] = f'fake_dequant_{idx}'
                graph.module_original_name_dict[id(fake_dequant)] = f'fake_dequant_{idx}'

                module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()'

                if len(node.prev_nodes) > 1:
                    # Insert the 'fake_dequant' node before the type-conversion operator
                    graph.insert_between(node.prev_nodes[0], node, fake_dequant, move_idx=True)
                else:
                    graph.insert_before(node, fake_dequant, move_idx=True)

        # Third, we rewrite neg/sub/div using supported functions (e.g. add, mul)
        def _is_neg_node(node: TraceNode, custom_data):
            cur_module = node.module
            cur_class = type(cur_module)
            if cur_class == TraceFunction:
                return (
                    cur_module.kind == 'neg'
                    and torch.is_floating_point(cur_module.prev_tensors[0])
                    and self.layerwise_config.get(node.unique_name, True)
                )

        neg_nodes = graph.filter_forward_nodes(_is_neg_node)
        log.info(f'rewriting neg for {[node.unique_name for node in neg_nodes]}')
        for idx, node in enumerate(neg_nodes):
            node.module.func_type = '__mul__'
            node.module.kind = 'mul'

            full_name_parts = node.module.full_name.split('.')
            full_name_parts[-1] = node.module.func_type

            node.module.full_name = '.'.join(full_name_parts)

            with override_current_trace_graph(graph):
                node.module.parse_args(node.prev_tensors[0], -1.0)

        def _is_div_node(node: TraceNode, custom_data):
            cur_module = node.module
            cur_class = type(cur_module)
            # Currently, only the following condition can be handled.
            #   a / constant => a * (1 / constant)
            if cur_class == TraceFunction:
                return (
                    (cur_module.kind == 'truediv' or cur_module.func_type in ('div', 'div_'))
                    and len(cur_module.prev_tensors) == 1
                    and torch.is_floating_point(cur_module.prev_tensors[0])
                    and cur_module.func_type != '__rtruediv__'
                    and torch.is_floating_point(node.next_tensors[0])
                    and node.prev_nodes[0].kind() not in ('size', 'shape')
                    and self.layerwise_config.get(node.unique_name, True)
                )

        div_nodes = graph.filter_forward_nodes(_is_div_node)
        log.info(f'rewriting div for {[node.unique_name for node in div_nodes]}')
        for idx, node in enumerate(div_nodes):
            op_type = node.module.func_type
            inplace = op_type in ('__itruediv__', 'itruediv', 'div_')

            if inplace:
                node.module.func_type = '__imul__'
            else:
                node.module.func_type = '__mul__'

            node.module.kind = 'mul'

            full_name_parts = node.module.full_name.split('.')
            full_name_parts[-1] = node.module.func_type

            node.module.full_name = '.'.join(full_name_parts)

            # Here we make a simple guess: if there is a dot in the string, then it's a floating-point number.
            # Otherwise, it is an integer.
            if '.' in node.module.args_string_no_self:
                other_arg = float(node.module.args_string_no_self)
            else:
                other_arg = int(node.module.args_string_no_self)

            with override_current_trace_graph(graph):
                node.module.parse_args(node.prev_tensors[0], 1.0 / other_arg)

        def _is_sub_node(node: TraceNode, custom_data):
            cur_module = node.module
            cur_class = type(cur_module)
            if cur_class == TraceFunction:
                return (
                    cur_module.kind == 'sub'
                    and torch.is_floating_point(cur_module.prev_tensors[0])
                    and torch.is_floating_point(node.next_tensors[0])
                    and self.layerwise_config.get(node.unique_name, True)
                )

        sub_nodes = graph.filter_forward_nodes(_is_sub_node)
        log.info(f'rewriting sub for {[node.unique_name for node in sub_nodes]}')
        for idx, node in enumerate(sub_nodes):
            op_type = node.module.func_type

            full_name_parts = node.module.full_name.split('.')
            full_name_parts[-1] = node.module.func_type

            if len(node.module.prev_tensors) == 1 and node.module.func_type != '__rsub__':
                node.module.func_type = '__add__'
                node.module.kind = 'add'
                node.module.full_name = '.'.join(full_name_parts)

                # Here we make a simple guess: if there is a dot (or an exponent) in the string, then it's a
                # floating-point number. Otherwise, it is an integer.
                if '.' in node.module.args_string_no_self or 'e' in node.module.args_string_no_self:
                    other_arg = float(node.module.args_string_no_self)
                else:
                    other_arg = int(node.module.args_string_no_self)

                with override_current_trace_graph(graph):
                    node.module.parse_args(node.prev_tensors[0], -other_arg)
            elif len(node.module.prev_tensors) == 2 and len(node.prev_nodes) == 2:
                new_fullname_parts = copy.deepcopy(full_name_parts)
                new_fullname_parts[-1] = '__mul__'
                new_fullname = '.'.join(new_fullname_parts)
                current_tensor = node.prev_tensors[0]
                input_node = node.prev_nodes[1]
                input_tensor = node.prev_tensors[1]
                output_tensor = input_tensor * -1

                with override_current_trace_graph(graph):
                    trace_func = TraceFunction(new_fullname, True, prefix='fuse_').parse_args(input_tensor, -1)
                    graph.insert_between(input_node, node, trace_func, [output_tensor], True)

                    node.module.func_type = '__add__'
                    node.module.kind = 'add'

                    full_name_parts[-1] = node.module.func_type

                    node.module.full_name = '.'.join(full_name_parts)

                    node.module.parse_args(current_tensor, output_tensor)

            elif node.module.func_type == '__rsub__' and len(node.module.prev_tensors) == 1:
                new_fullname_parts = copy.deepcopy(full_name_parts)
                new_fullname_parts[-1] = '__mul__'
                new_fullname = '.'.join(new_fullname_parts)
                input_node = node.prev_nodes[0]
                input_tensor = node.prev_tensors[0]
                output_tensor = input_tensor * -1

                if '.' in node.module.args_string_no_self:
                    other_arg = float(node.module.args_string_no_self)
                else:
                    other_arg = int(node.module.args_string_no_self)

                with override_current_trace_graph(graph):
                    trace_func = TraceFunction(new_fullname, True, prefix='fuse_').parse_args(input_tensor, -1)
                    graph.insert_between(input_node, node, trace_func, [output_tensor], True)

                    node.module.func_type = '__radd__'
                    node.module.kind = 'add'

                    full_name_parts[-1] = node.module.func_type

                    node.module.full_name = '.'.join(full_name_parts)

        # Then, we rewrite torch.stack nodes to torch.unsqueeze + torch.cat
        def _is_stack_node(node: TraceNode, custom_data):
            cur_module = node.module
            cur_class = type(cur_module)
            return (
                cur_class == TraceFunction
                and cur_module.kind == 'stack'
                and torch.is_floating_point(node.next_tensors[0])
                and self.layerwise_config.get(node.unique_name, True)
            )

        stack_nodes = graph.filter_forward_nodes(_is_stack_node)

        for idx, node in enumerate(stack_nodes):
            args = getattr(node.module, 'args_string_no_self', '')
            if ',' in args:
                log.error('rewrite doesn\'t support multiple args for torch.stack')
                assert False
            if len(args) > 0:
                if '=' in args:
                    k, v = args.split('=')
                    assert k in ('axis', 'dim')
                    dim = int(v)
                    if k == 'axis':
                        node.module.args_template_no_self = node.module.args_template_no_self.replace('axis', 'dim')
                        node.module.args_template = node.module.args_template.replace('axis', 'dim')
                        node.module.update_args_string()
                else:
                    dim = int(args)
            else:
                dim = 0
            unique_prev_nodes = {n.unique_name: n for n in node.prev_nodes}.values()
            for n in unique_prev_nodes:
                shared_tensors = list(set(node.prev_tensors).intersection(set(n.next_tensors)))
                if len(shared_tensors) == 0:
                    log.debug('tensor rewrite already done, skipping')
                    continue
                for t in shared_tensors:
                    with override_current_trace_graph(graph):
                        trace_func = TraceFunction('torch.unsqueeze', prefix='fuse_').parse_args(t, dim)
                    next_tensors = [torch.unsqueeze(t, dim)]
                    graph.insert_between(n, node, trace_func, next_tensors, move_idx=True, tensor_ptrs=set([id(t)]))

            node.module.func_type = 'cat'
            node.module.kind = 'cat'

            full_name_parts = node.module.full_name.split('.')
            full_name_parts[-1] = node.module.func_type

            node.module.full_name = '.'.join(full_name_parts)

        # Next, we rewrite add/mul/cat with one float32 output using torch.nn.quantized.FloatFunctional
        def _is_convertible_node(node: TraceNode, custom_data):
            cur_module = node.module
            cur_class = type(cur_module)
            if cur_class == TraceFunction:
                return (
                    cur_module.kind in ('add', 'mul', 'cat')
                    and torch.is_floating_point(node.next_tensors[0])
                    and self.layerwise_config.get(node.unique_name, True)
                    and all((torch.is_floating_point(pt) for pt in node.prev_tensors))
                )

        convertible_nodes = graph.filter_forward_nodes(_is_convertible_node)
        unfusable_add_nodes = []  # noqa: F841
        log.info(f'rewriting add/mul/cat for {[node.unique_name for node in convertible_nodes]}')
        for idx, node in enumerate(convertible_nodes):
            old_full_name = node.module.full_name  # noqa: F841
            old_is_class = node.module.is_class  # noqa: F841

            op_kind = node.module.kind
            float_functional = nnq.FloatFunctional()

            module_name = f'float_functional_simple_{idx}'

            graph.module_unique_name_dict[id(float_functional)] = module_name
            graph.module_original_name_dict[id(float_functional)] = module_name

            float_functional_cls = type(float_functional)
            module_constructor_lines[id(float_functional)] = f'{qualified_name(float_functional_cls, short=True)}()'

            new_node = TraceNode(float_functional, cur_graph=graph)
            graph.nodes_map[new_node.unique_name] = new_node
            graph.other_init_nodes.append(new_node)

            node.module.is_class = False
            prev_tensor_size = len(node.prev_tensors)
            if op_kind in ('add', 'mul'):
                if prev_tensor_size == 2:
                    if node.prev_nodes[1].prev_nodes[0].kind() not in ('shape', 'size'):
                        op_type = op_kind
                    else:
                        op_type = f'{op_kind}_scalar'
                elif prev_tensor_size == 1:
                    op_type = f'{op_kind}_scalar'
                else:
                    log.error(f'Unknown add/mul type for {node.unique_name}, prev tensor size: {prev_tensor_size}')
                    assert False
            else:
                # Don't check anything for other OPs.
                # It is simply too complex for us.
                op_type = op_kind
            node.module.func_type = op_type
            node.module.full_name = f'self.{module_name}.{op_type}'
            # Inplace operations
            if node.module.func_type in ['__iadd__', '__imul__', 'add_', 'mul_']:
                node.module.add_alias(node.module.tensor_names[0])
                q = queue.Queue()
                q.put(node.prev_nodes[0])
                while not q.empty():
                    n = q.get()
                    if type(n.module) is TraceFunction:
                        prev_aliases = n.module.get_aliases()
                        if prev_aliases is not None:
                            for pa in reversed(prev_aliases):
                                node.module.add_alias(pa, head=True)
                    else:
                        if getattr(n.module, 'inplace', False):
                            q.put(n.prev_nodes[0])
                            node.module.add_alias(n.prev_node_unique_name(0), head=True)

            # We need to convert radd to normal add here
            if node.module.func_type in ['__radd__', '__rmul__']:
                if '=' in node.module.args_string_no_self or ', ' in node.module.args_string_no_self:
                    log.error(f'Don\'t know how to translate {node.module.args_string_no_self} for __radd__/__rmul__')
                    assert False
                if prev_tensor_size == 1:
                    # Here we make a simple guess: if there is a dot (or an exponent) in the string, then it's a
                    # floating-point number. Otherwise, it is an integer.
                    if '.' in node.module.args_string_no_self or 'e' in node.module.args_string_no_self:
                        other_arg = float(node.module.args_string_no_self)
                    else:
                        other_arg = int(node.module.args_string_no_self)

                    with override_current_trace_graph(graph):
                        node.module.parse_args(node.prev_tensors[0], other_arg)
                else:
                    with override_current_trace_graph(graph):
                        # It is even simpler here. We only need to swap the order of the tensors.
                        node.module.parse_args(node.prev_tensors[1], node.prev_tensors[0])

        # Rewrite torch.clamp_{min,max} to torch.clamp
        def _is_clamp_min_max_node(node: TraceNode, custom_data):
            cur_module = node.module
            cur_class = type(cur_module)
            if cur_class == TraceFunction:
                return (
                    cur_module.kind in ('clamp_min', 'clamp_max')
                    and torch.is_floating_point(node.next_tensors[0])
                    and self.layerwise_config.get(node.unique_name, True)
                )

        clamp_min_max_nodes = graph.filter_forward_nodes(_is_clamp_min_max_node)
        log.info(f'rewriting clamp_{{min,max}} for {[node.unique_name for node in clamp_min_max_nodes]}')
        for node in clamp_min_max_nodes:
            kw = node.module.kind[-3:]
            _parse_args = eval(f'lambda {kw}, *args, **kwargs: {kw}')

            val = eval(f'_parse_args({node.module.args_string_no_self})')

            if kw != 'min' or val != 0.0:
                continue

            if kw == 'min':
                arg_str = f'{val}'
            else:
                arg_str = f'None, {val}'

            sub_str = f'_{kw}'
            node.module.kind = node.module.kind.replace(sub_str, '')
            node.module.func_type = node.module.func_type.replace(sub_str, '')
            node.module.full_name = '.'.join(node.module.full_name.split('.')[:-1] + [node.module.func_type])

            node.module.args_template_no_self = arg_str
            node.module.args_template = f'{{}}, {arg_str}'
            node.module.update_args_string()

        # Rewrite torch.clamp to relu, relu6 and clamp_{min, max}
        def _is_clamp_node(node: TraceNode, custom_data):
            cur_module = node.module
            cur_class = type(cur_module)
            if cur_class == TraceFunction:
                return (
                    cur_module.kind == 'clamp'
                    and torch.is_floating_point(node.next_tensors[0])
                    and self.layerwise_config.get(node.unique_name, True)
                )

        clamp_nodes = graph.filter_forward_nodes(_is_clamp_node)
        log.info(f'rewriting clamp for {[node.unique_name for node in clamp_nodes]}')
        for node in clamp_nodes:

            def _parse_args(min=None, max=None, *args, **kwargs):
                return min, max

            min, max = eval(f'_parse_args({node.module.args_string_no_self})')

            kind = None
            if min == 0.0:
                if max is None:
                    kind = 'relu'
                elif max == 6.0:
                    kind = 'relu6'

            if kind is None:
                if max is None:
                    kind = 'clamp_min'
                elif min is None:
                    kind = 'clamp_max'
                else:
                    kind = 'clamp_with_fusion'

            if kind in ('clamp_min', 'clamp_max'):
                node.module.kind = kind
                node.module.func_type = node.module.func_type.replace('clamp', kind)
                node.module.full_name = '.'.join(node.module.full_name.split('.')[:-1] + [node.module.func_type])

                if max is None:
                    arg_str = f'{min}'
                else:
                    arg_str = f'{max}'
            elif kind == 'clamp_with_fusion':
                node.module.is_class = False
                node.module.kind = kind
                node.module.func_type = node.module.func_type.replace('clamp', kind)
                node.module.full_name = f'tinynn.graph.quantization.utils.{node.module.func_type}'

                arg_str = f'{min}, {max}'
            else:
                inplace = node.module.func_type == f'{node.module.kind}_'
                node.module.kind = kind
                node.module.func_type = kind
                node.module.full_name = f'torch.nn.functional.{kind}'

                if inplace:
                    arg_str = 'inplace=True'
                else:
                    arg_str = ''

            node.module.args_template_no_self = arg_str
            node.module.args_template = f'{{}}, {arg_str}'
            node.module.update_args_string()

        # Rewrite other fusable functions
        # e.g. add_relu(x, y) =>
        #            r = torch.add(x, y)
        #            r = torch.nn.functional.relu(r)
        def _is_add_relu_fusable_node(node: TraceNode, custom_data) -> bool:
            cur_module = node.module
            cur_class = type(cur_module)
            visited_nodes = [node]

            if self.layerwise_config.get(node.unique_name, True) is False:
                return False

            if cur_class == TraceFunction:
                # The intermediate result cannot be used if fused.
                # So we don't fuse the nodes under such circumstances.
                if cur_module.kind != 'add' or len(node.next_nodes) != 1:
                    return False
                # The input for the add operations should be two tensors.
                if len(node.prev_tensors) != 2:
                    return False
                # We accept inplace operations for both add and relu.
                # The inplace property could be eliminated because we track the tensors
                # instead of their names.
                next_node = node.next_nodes[0]
                next_module = next_node.module
                next_class = type(next_module)
                if next_class == TraceFunction:
                    fuse = next_module.kind == 'relu'
                else:
                    while next_class == nn.Identity:
                        cur_node = next_node
                        visited_nodes.append(cur_node)
                        if len(cur_node.next_nodes) != 1:
                            return False
                        next_node = cur_node.next_nodes[0]
                        next_module = next_node.module
                        next_class = type(next_module)

                    fuse = next_class.__name__ == 'ReLU'

                if not fuse:
                    return False

                if type(next_node.module) is TraceFunction:
                    inplace = next_node.module.func_type == 'relu_' or 'True' in next_node.module.args_string
                else:
                    inplace = getattr(next_node.module, 'inplace', False)

                # Inplace check
                # If add is inplace and relu is not inplace, we need to ensure that all the aliases of
                # the first operand of add are not used when relu is called.
                if not inplace:
                    aliases = cur_module.get_aliases()
                    if aliases:
                        q = queue.Queue()
                        q.put(node.prev_nodes[0])
                        while not q.empty():
                            n = q.get()
                            last_order = max((x.forward_order for x in n.next_nodes))
                            if last_order > node.forward_order:
                                fuse = False
                                break
                            if type(n.module) is TraceFunction and n.module.get_aliases():
                                q.put(n.prev_nodes[0])
                            elif getattr(n.module, 'inplace', False):
                                q.put(n.prev_nodes[0])

                return fuse

        add_relu_fusable_nodes = graph.filter_forward_nodes(_is_add_relu_fusable_node)
        for node in add_relu_fusable_nodes:
            full_name = node.module.full_name.replace('add', 'add_relu')
            next_node = node.next_nodes[0]
            kind = 'add_relu'
            func_type = kind
            is_class = False
            nodes_to_fuse = [node, next_node]
            while next_node.type() is nn.Identity:
                next_node = next_node.next_nodes[0]
                nodes_to_fuse.append(next_node)
            if type(next_node.module) is TraceFunction:
                inplace = next_node.module.func_type == 'relu_' or 'True' in next_node.module.args_string
            else:
                inplace = next_node.module.inplace
            graph.fuse_nodes_to_func(nodes_to_fuse, full_name, kind, func_type, is_class)
            # Propagate aliases for inplace nodes
            if inplace:
                aliases = node.module.get_aliases()
                if aliases:
                    node.module.add_alias(node.module.tensor_names[0])
            else:
                node.module.aliases = None

        # Rewrite functional relu, relu6, etc. as nn.ReLU() and nn.ReLU6() so that the module-based fusion rules apply
        def _is_functional_rewrite_node(node: TraceNode, custom_data):
            cur_module = node.module
            cur_class = type(cur_module)
            if cur_class == TraceFunction:
                return cur_module.kind in FUNCTIONAL_MODULE_MAPPING and self.layerwise_config.get(
                    node.unique_name, True
                )

        func_nodes_to_rewrite = graph.filter_forward_nodes(_is_functional_rewrite_node)
        log.info(f'rewriting functional to module for {[node.unique_name for node in func_nodes_to_rewrite]}')
        for idx, node in enumerate(func_nodes_to_rewrite):
            kind = node.module.kind
            inplace = node.module.func_type == f'{kind}_' or 'True' in node.module.args_string
            klass = FUNCTIONAL_MODULE_MAPPING[kind]
            arguments = getattr(klass, '__constants__', None)
            if arguments is None:
                new_func = klass()
            elif node.module.kind in ('relu', 'relu6', 'silu', 'hardswish'):
                new_func = klass(inplace=inplace)
            elif node.module.kind in ('elu', 'leaky_relu'):
                if hasattr(node.module, 'args_string_no_self'):

                    def _parse_args(alpha=1.0, *args, **kwargs):  # noqa: F811
                        return alpha

                    alpha = eval(f'_parse_args({node.module.args_string_no_self})')
                    # nn.LeakyReLU and nn.ELU both take (coefficient, inplace), so one call covers both kinds
                    new_func = klass(alpha, inplace=inplace)
                else:
                    alpha = None
                    if node.module.kind == 'leaky_relu':
                        new_func = klass(inplace=inplace)
                    else:
                        new_func = klass()
            elif node.module.kind == 'prelu':
                weight_t = node.prev_tensors[1]
                weight_node = node.prev_nodes[1]
                if weight_node.type() is torch_q.QuantStub:
                    weight_node = weight_node.prev_nodes[0]

                if weight_node.type() is not ConstantNode or not weight_node.module.is_parameter:
                    log.warning('Rewrite for F.prelu(x, buffer) to nn.PReLU is skipped as it changes the semantics')
                    continue

                num_parameters = weight_t.nelement()
                new_func = klass(num_parameters)
                new_func.weight.data.copy_(node.prev_tensors[1])

                # Drop last input
                node.prev_tensors.pop()
                last_node = node.prev_nodes.pop()
                last_node.next_nodes.remove(node)

                # Remove unused forward functions iteratively
                new_last_node = last_node
                while new_last_node is not None:
                    last_node = new_last_node
                    if (
                        len(last_node.next_nodes) == 0
                        and len(last_node.prev_nodes) < 2
                        and last_node in graph.forward_nodes
                    ):
                        if len(last_node.prev_nodes) == 1:
                            new_last_node = last_node.prev_nodes[0]
                        else:
                            new_last_node = None
                        graph.remove_node(last_node)
                    else:
                        new_last_node = None
            elif node.module.kind == 'glu':
                if hasattr(node.module, 'args_string_no_self'):

                    def _parse_args_dim(dim=-1, *args, **kwargs):  # noqa: F811
                        return dim

                    dim = eval(f'_parse_args_dim({node.module.args_string_no_self})')
                    new_func = klass(dim)
                else:
                    dim = None
                    new_func = klass()
            else:
                raise NotImplementedError(f"Don't know how to parse {klass.__name__} with argument {arguments}")

            graph.module_unique_name_dict[id(new_func)] = f'rewritten_{kind}_{idx}'
            graph.module_original_name_dict[id(new_func)] = f'rewritten_{kind}_{idx}'

            relu_cls = type(new_func)
            if inplace:
                arg_str = 'inplace=True'
            else:
                arg_str = ''

            if node.module.kind in ('elu', 'leaky_relu') and alpha is not None:
                if arg_str:
                    arg_str = f'{alpha}, {arg_str}'
                else:
                    arg_str = f'{alpha}'

            if node.module.kind == 'prelu':
                if num_parameters != 1:
                    arg_str = f'{num_parameters}'
                else:
                    arg_str = ''

            if node.module.kind == 'glu':
                if dim is not None and dim != -1:
                    arg_str = f'{dim}'
                else:
                    arg_str = ''

            module_constructor_lines[id(new_func)] = f'{qualified_name(relu_cls)}({arg_str})'
            graph.replace_node_module(node, new_func)

        # Rewrite dropout as nn.Dropout() for models in training mode
        def _is_dropout_functional_node(node: TraceNode, custom_data):
            cur_module = node.module
            cur_class = type(cur_module)
            if cur_class == TraceFunction:
                return cur_module.kind == 'dropout' and self.layerwise_config.get(node.unique_name, True)

        def _dropout_args(p=0.5, training=True, inplace=False):
            return p, training, inplace

        dropout_nodes_to_rewrite = graph.filter_forward_nodes(_is_dropout_functional_node)
        log.info(f'rewriting dropout for {[node.unique_name for node in dropout_nodes_to_rewrite]}')
        for idx, node in enumerate(dropout_nodes_to_rewrite):
            args = getattr(node.module, 'args_string_no_self', '')
            p, _, inplace = eval(f'_dropout_args({args})')
            kind = node.module.kind
            inplace = node.module.func_type == f'{kind}_' or inplace
            dropout_cls = nn.Dropout
            new_dropout = dropout_cls(p, inplace=inplace)

            graph.module_unique_name_dict[id(new_dropout)] = f'rewritten_{kind}_{idx}'
            graph.module_original_name_dict[id(new_dropout)] = f'rewritten_{kind}_{idx}'

            if inplace:
                arg_str = f'{p}, inplace={inplace}'
            else:
                arg_str = f'{p}'

            module_constructor_lines[id(new_dropout)] = f'{qualified_name(dropout_cls)}({arg_str})'
            graph.replace_node_module(node, new_dropout)

        # Add contiguous nodes for partially-supported OPs
        # Some of the operations support quantization, but they only accept contiguous input tensors.
        def _is_partially_quantizable(node, custom_data):
            if self.layerwise_config.get(node.unique_name, True) is False:
                return False
            cur_module = node.module
            cur_class = type(cur_module)
            if cur_class == ConstantNode:
                return False
            elif cur_class == TraceFunction:
                return cur_module.kind in ('pad',)
            else:
                return cur_class in (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d, nn.ZeroPad2d)

        if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
            partially_supported_nodes = graph.filter_forward_nodes(_is_partially_quantizable)
            for idx, node in enumerate(partially_supported_nodes):
                for n in node.prev_nodes[:1]:
                    shared_tensors = list(set(node.prev_tensors).intersection(set(n.next_tensors)))
                    if len(shared_tensors) > 1:
                        log.error('rewrite for partially-supported ops only supports nodes with exactly one input')
                        assert False
                    with override_current_trace_graph(graph):
                        trace_func = TraceFunction('torch.Tensor.contiguous', True, prefix='fuse_').parse_args(
                            shared_tensors[0]
                        )
                    next_tensors = [x.contiguous() for x in shared_tensors]
                    graph.insert_between(n, node, trace_func, next_tensors, True)

        # Remove non-leaf `.data` nodes
        def _is_non_leaf_data_nodes(node, custom_data):
            if self.layerwise_config.get(node.unique_name, True) is False:
                return False
            cur_module = node.module
            cur_class = type(cur_module)
            if cur_class == TraceFunction:
                return cur_module.kind == 'data' and cur_module.is_property and len(node.next_nodes) > 0
            return False

        non_leaf_data_nodes = graph.filter_forward_nodes(_is_non_leaf_data_nodes)
        for idx, node in enumerate(non_leaf_data_nodes):
            graph.remove_node(node)

        # Handle PoolNd with kernel_size=1
        def _is_pool_nd_with_one_kernel_size(node, custom_data):
            if self.layerwise_config.get(node.unique_name, True) is False:
                return False
            cur_module = node.module
            cur_class = type(cur_module)
            kernel_size, stride = None, None
            if cur_class == TraceFunction:
                if cur_module.kind in ('avg_pool1d', 'avg_pool2d', 'max_pool1d', 'max_pool2d'):

                    def _avgpool_kernel_size_and_stride(kernel_size, stride=None, *args, **kwargs):
                        return kernel_size, stride

                    kernel_size, stride = eval(f'_avgpool_kernel_size_and_stride({cur_module.args_string_no_self})')
            else:
                if cur_class in (nn.AvgPool1d, nn.AvgPool2d, nn.MaxPool1d, nn.MaxPool2d):
                    kernel_size = cur_module.kernel_size
                    stride = cur_module.stride

            if kernel_size is not None:
                if isinstance(kernel_size, (tuple, list)):
                    is_match = all((ks == 1 for ks in kernel_size))
                else:
                    is_match = kernel_size == 1
            else:
                is_match = False

            if is_match:
                custom_data.append((node, kernel_size, stride))
                return True
            else:
                return False

        pool_one_kernel_size_nodes = []
        graph.filter_forward_nodes(_is_pool_nd_with_one_kernel_size, pool_one_kernel_size_nodes)
        for idx, (node, kernel_size, stride) in enumerate(pool_one_kernel_size_nodes):
            slices = [slice(None)] * 2
            t = node.prev_tensors[0]
            dim = len(t.shape)
            if not isinstance(stride, (list, tuple)):
                stride = [stride] * (dim - 2)
            for s in stride:
                if s == 1 or s is None:
                    slices.append(slice(None))
                else:
                    slices.append(slice(None, None, s))

            with override_current_trace_graph(graph):
                new_func = TraceFunction('torch.Tensor.__getitem__', True).parse_args(t, slices)

                graph.module_unique_name_dict[id(new_func)] = f'rewritten_pool_{idx}'
                graph.module_original_name_dict[id(new_func)] = f'rewritten_pool_{idx}'

                graph.replace_node_module(node, new_func)

        # Rewrite Linear-BatchNorm1d structure to Conv2d-BatchNorm2d
        is_rewrite_to_fuse = functools.partial(
            self.is_fusable,
            current_rules=load_processed_rewrite_to_fuse_rules(),
            check_node_quantized=False,
            graph=graph,
            layerwise_config_default=True,
            use_original_name=False,
        )
        custom_data = ([], set())
        graph.filter_forward_nodes(is_rewrite_to_fuse, custom_data, reverse=True)
        rewrite_fuse_names_list = custom_data[0]
        log.debug(f'found names_list that need to rewrite for fusing: {rewrite_fuse_names_list}')

        for idx, names in enumerate(reversed(rewrite_fuse_names_list)):
            # case fc-bn1d
            assert len(names) == 2, 'the rewrite nodes list length != 2'

            node_fc = graph.nodes_map[names[0]]
            node_bn1d = graph.nodes_map[names[1]]
            mod_fc = node_fc.module
            mod_bn = node_bn1d.module

            assert type(mod_fc) is nn.Linear and type(mod_bn) is nn.BatchNorm1d, "the rewritten structure isn't [fc-bn1d]"

            if len(node_fc.prev_tensors[0].shape) != 2:
                log.debug('the [fc-bn]\'s input dimension != 2')
                continue
            # for fc-bn1d, rewrite [fc-bn1d] to [conv2d-bn2d]
            new_conv2d = torch.nn.Conv2d(
                in_channels=mod_fc.in_features,
                out_channels=mod_fc.out_features,
                kernel_size=[1, 1],
                bias=mod_fc.bias is not None,
            )
            fc_weight = mod_fc.weight
            new_conv2d.weight = nn.Parameter(torch.reshape(fc_weight, [fc_weight.shape[0], fc_weight.shape[1], 1, 1]))
            if mod_fc.bias is not None:
                new_conv2d.bias = mod_fc.bias
            graph.module_unique_name_dict[id(new_conv2d)] = f'rewritten_conv2d_bn2d_conv2d_{idx}'
            graph.module_original_name_dict[id(new_conv2d)] = f'rewritten_conv2d_bn2d_conv2d_{idx}'

            new_bn2d = torch.nn.BatchNorm2d(
                mod_bn.num_features,
                mod_bn.eps,
                mod_bn.momentum,
                affine=mod_bn.affine,
                track_running_stats=mod_bn.track_running_stats,
            )
            new_bn2d.load_state_dict(mod_bn.state_dict())
            graph.module_unique_name_dict[id(new_bn2d)] = f'rewritten_conv2d_bn2d_bn2d_{idx}'
            graph.module_original_name_dict[id(new_bn2d)] = f'rewritten_conv2d_bn2d_bn2d_{idx}'

            # Replace the modules with the new ones, then insert reshape ops before new_conv2d and after new_bn2d
            with override_current_trace_graph(graph):
                graph.replace_node_module(node_fc, new_conv2d)
                graph.replace_node_module(node_bn1d, new_bn2d)

                prev_func = TraceFunction('torch.Tensor.__getitem__', prefix='rewritten_conv2d_bn2d_').parse_args(
                    node_fc.prev_tensors[0], (Ellipsis, None, None)
                )
                next_func = TraceFunction('torch.flatten', prefix='rewritten_conv2d_bn2d_').parse_args(
                    node_bn1d.next_tensors[0], 1
                )
            # expand the traced tensor shapes between new_conv2d and new_bn2d to 4D
            node_fc.next_tensors[0].unsqueeze_(2).unsqueeze_(2)
            node_bn1d.prev_tensors[0].unsqueeze_(2).unsqueeze_(2)
            node_bn1d.next_tensors[0].unsqueeze_(2).unsqueeze_(2)

            prev_out = node_fc.prev_tensors[0][..., None, None]
            graph.insert_between(node_fc.prev_nodes[0], node_fc, prev_func, [prev_out], True)
            next_out = torch.flatten(node_bn1d.next_tensors[0], 1)
            graph.insert_after(node_bn1d, next_func, [next_out])

        # Rewrite BatchNorm1d to BatchNorm2d
        def _is_batch_norm_1d(node, custom_data):
            if self.layerwise_config.get(node.unique_name, True) is False:
                return False
            cur_module = node.module
            cur_class = type(cur_module)
            if len(node.prev_nodes) != 1:
                return False

            if node.prev_nodes[0].kind() in ('conv1d', nn.Conv1d):
                return False

            if cur_class == TraceFunction:
                return cur_module.kind == 'batch_norm' and node.prev_tensors[0].ndim == 3
            else:
                return cur_class == nn.BatchNorm1d

        batch_norm_1d_nodes = graph.filter_forward_nodes(_is_batch_norm_1d)
        for idx, node in enumerate(batch_norm_1d_nodes):
            mod = node.module
            if type(mod) is nn.BatchNorm1d:
                new_bn = torch.nn.BatchNorm2d(
                    mod.num_features,
                    mod.eps,
                    mod.momentum,
                    affine=mod.affine,
                    track_running_stats=mod.track_running_stats,
                )
                new_bn.load_state_dict(mod.state_dict())

                graph.module_unique_name_dict[id(new_bn)] = f'rewritten_bn2d_{idx}'
                graph.module_original_name_dict[id(new_bn)] = f'rewritten_bn2d_{idx}'

                with override_current_trace_graph(graph):
                    graph.replace_node_module(node, new_bn)

                    prev_func = TraceFunction('torch.unsqueeze', prefix='rewritten_bn2d_').parse_args(
                        node.prev_tensors[0], 2
                    )
                    next_func = TraceFunction('torch.squeeze', prefix='rewritten_bn2d_').parse_args(
                        node.next_tensors[0], 2
                    )

                node.next_tensors[0].unsqueeze_(2)

                prev_out = torch.unsqueeze(node.prev_tensors[0], 2)
                graph.insert_between(node.prev_nodes[0], node, prev_func, [prev_out], True)
                next_out = torch.squeeze(node.next_tensors[0], 2)
                graph.insert_after(node, next_func, [next_out])

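        # Record the op type of every forward node so that each layerwise_config entry
        # can be annotated with the kind of layer it controls.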
        type_dict = {}
        for n in graph.forward_nodes:
            if n.type() not in (torch_q.QuantStub, torch_q.DeQuantStub):
                self.layerwise_config.setdefault(n.unique_name, True)
                kind = n.kind()
                type_str = kind if isinstance(kind, str) else kind.__name__
                type_dict[n.unique_name] = type_str

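        # Layers that will be fused together share a single layerwise_config entry:
        # every follower layer in a fusable group is mapped back to the group's first layer.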
        fuse_mapping = {}
        if self.fused_layerwise_config:
            # Find all fusable nodes
            if type(self).__name__ == 'QATQuantizer':
                processed_all_rules = load_processed_all_qat_rules()
            else:
                processed_all_rules = load_processed_all_ptq_rules()

            is_fusable = functools.partial(
                self.is_fusable,
                current_rules=processed_all_rules,
                check_node_quantized=False,
                use_original_name=False,
            )

            custom_data = ([], set())
            graph.filter_forward_nodes(is_fusable, custom_data, reverse=True)
            activ_names = custom_data[0]

            for idx, names in enumerate(reversed(activ_names)):
                types = [graph.nodes_map[n].kind() for n in names]
                types = [x.__name__ if not isinstance(x, str) else x for x in types]
                types_str = ', '.join(types)
                types_str = f'({types_str})'
                type_dict[names[0]] = types_str
                for name in names[1:]:
                    fuse_mapping[name] = names[0]
                    self.layerwise_config[name] = self.layerwise_config.get(names[0], True)

        for n, t in type_dict.items():
            self.layerwise_config.yaml_add_eol_comment(f'type: {t}', n)

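        # Collect the op kinds that get special treatment below: ops covered by
        # single-op REWRITE_QUANTIZABLE rules, ops whose configured action is
        # 'rewrite', and (optionally) ops with known quantization stats. These sets
        # are consulted when pruning layerwise_config entries and when removing
        # consecutive dequant/quant pairs.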
        skip_types = set(k[0] for k in REWRITE_QUANTIZABLE_RULE_LIST if len(k) == 1)
        for module_cls, action in self.quantize_op_action.items():
            if action == 'rewrite':
                skip_types.add(module_cls)
        if self.set_quantizable_op_stats:
            skip_types |= set(KNOWN_QSTATS.keys())
        skip_types_prev = skip_types | set(k[-1] for k in REWRITE_QUANTIZABLE_RULE_LIST if len(k) > 1)
        skip_types_next = skip_types | set(k[0] for k in REWRITE_QUANTIZABLE_RULE_LIST if len(k) > 1)

        # Add quant/dequant nodes for non-quantizable OPs
        disable_quantize_op_list = UNSUPPORTED_PYTORCH_QUANTIZATION_OP_LIST.copy()
        for module_cls, action in self.quantize_op_action.items():
            if action in ('disable', 'rewrite'):
                disable_quantize_op_list[module_cls] = None

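        # Rewrite handling for nn.LSTM when its configured action is 'rewrite'
        # (single-layer, unidirectional only): a DeQuantStub is attached to the cell
        # state, and size/expand helper nodes are inserted around the main output.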
        def _is_rewritable_lstm_node(node, custom_data):
            cur_module = node.module
            cur_class = type(cur_module)
            return cur_class == nn.LSTM

        if self.quantize_op_action.get(nn.LSTM, 'enable') == 'rewrite':
            rewritable_lstm_nodes = graph.filter_forward_nodes(_is_rewritable_lstm_node)
            fake_dequant_cls = torch_q.DeQuantStub
            for idx, node in enumerate(rewritable_lstm_nodes):
                assert node.module.num_layers == 1, "Quantization rewrite for multi-layer LSTM is not yet supported"
                assert not node.module.bidirectional, "Quantization rewrite for bidirectional LSTM is not yet supported"

                cell_state = node.next_tensors[1][1]

                fake_dequant = fake_dequant_cls()

                fake_dequant_name = f'fake_dequant_rewrite_{idx}'

                graph.module_unique_name_dict[id(fake_dequant)] = fake_dequant_name
                graph.module_original_name_dict[id(fake_dequant)] = fake_dequant_name

                module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()'

                new_node = graph.insert_new_after(
                    node, fake_dequant, [cell_state], [[1, 1]], before_node=node.next_nodes[0]
                )

                with override_current_trace_graph(graph):
                    size_func = TraceFunction(
                        'torch.Tensor.size', is_class=True, prefix='fake_dequant_rewrite_'
                    ).parse_args(new_node.next_tensors[0], -1)

                size_node = graph.insert_new_after(
                    new_node,
                    size_func,
                    [new_node.next_tensors[0]],
                    [None],
                    next_tensors=[torch.tensor(new_node.next_tensors[0].size(-1))],
                    before_node=node.next_nodes[0],
                )
                size_len = len(node.next_tensors[0].shape)

                if node.module.bidirectional:
                    with override_current_trace_graph(graph):
                        size_func = TraceFunction(
                            'torch.Tensor.__mul__', is_class=True, prefix='fake_dequant_rewrite_'
                        ).parse_args(size_node.next_tensors[0], 2)

                        size_node = graph.insert_new_after(
                            size_node,
                            size_func,
                            [size_node.next_tensors[0]],
                            [None],
                            next_tensors=[size_node.next_tensors[0] * 2],
                            before_node=node.next_nodes[0],
                        )

                with override_current_trace_graph(graph):
                    expand_func = TraceFunction(
                        'torch.Tensor.expand', is_class=True, prefix='fake_dequant_rewrite_'
                    ).parse_args(node.next_tensors[0], *((-1,) * (size_len - 1)), size_node.next_tensors[0])

                graph.insert_between(
                    node, node.next_nodes[0], expand_func, tensor_ptrs=[id(node.next_tensors[0])], move_idx=True
                )
                expand_node = graph.nodes_map[expand_func.unique_name]
                size_node.next_nodes.append(expand_node)
                expand_node.prev_nodes.append(node)
                expand_node.prev_tensors.append(size_node.next_tensors[0])
                expand_node.prev_indices.append(None)

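        # Identify ops that cannot be quantized, either because PyTorch provides no
        # quantized kernel for them (see disable_quantize_op_list) or because the user
        # disabled them through layerwise_config.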
        def _is_not_quantizable(node, custom_data):
            cur_module = node.module
            cur_class = type(cur_module)
            if cur_class == ConstantNode:
                return False
            elif cur_class == TraceFunction:
                if node.type() in ('__truediv__', '__itruediv__', 'div', 'div_'):
                    if node.prev_nodes[0].kind() in ('shape', 'size'):
                        return False
                if node.type() in ('shape', 'device', 'size', 'dtype'):
                    if node.unique_name in self.layerwise_config:
                        del self.layerwise_config[node.unique_name]
                    return False
                if self.layerwise_config.get(node.unique_name, True) is False:
                    return True
                supported_version = disable_quantize_op_list.get(cur_module.kind, torch.__version__)
                return supported_version is None or LooseVersion(torch.__version__) < supported_version
            else:
                if isinstance(cur_module, (torch_q.QuantStub, torch_q.DeQuantStub)):
                    return False
                if self.layerwise_config.get(node.unique_name, True) is False:
                    return True
                unsupported_types = tuple(
                    k
                    for k, v in disable_quantize_op_list.items()
                    if type(k) is not str
                    and k not in Q_MODULES_MAPPING
                    and (v is None or LooseVersion(torch.__version__) < v)
                )
                return isinstance(cur_module, unsupported_types)

        unsupported_nodes = graph.filter_forward_nodes(_is_not_quantizable)
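        # Insert QuantStub nodes between the unsupported ops and their consumers so
        # that the downstream quantized region still receives quantized tensors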
        for idx, node in enumerate(reversed(unsupported_nodes)):
            if node.unique_name in self.layerwise_config:
                if node.kind() not in skip_types and self.layerwise_config[node.unique_name]:
                    del self.layerwise_config[node.unique_name]

            node_map = dict()
            next_nodes = {n.unique_name: n for n in node.next_nodes}.values()
            for inner_idx, next_node in enumerate(next_nodes):
                prev_tensor_ptrs = []
                if type(next_node.module) is TraceFunction and next_node.module.is_property:
                    continue

                for pt in next_node.prev_tensors:
                    for nt in node.next_tensors:
                        if isinstance(nt, (list, tuple)):
                            for k, ntt in enumerate(nt):
                                if id(pt) == id(ntt):
                                    if id(pt) not in prev_tensor_ptrs:
                                        prev_tensor_ptrs.append(id(pt))
                                    break
                        elif id(pt) == id(nt):
                            if id(pt) not in prev_tensor_ptrs:
                                prev_tensor_ptrs.append(id(pt))
                            break

                for ptr_idx, ptr in enumerate(prev_tensor_ptrs):
                    if ptr in node_map:
                        fake_quant = node_map[ptr]

                        graph.insert_between(node, next_node, fake_quant, move_idx=True, tensor_ptrs=[ptr])
                    else:
                        fake_quant = torch_q.QuantStub()

                        fake_quant_name = f'fake_quant_inner_{idx}_{inner_idx}_{ptr_idx}'

                        graph.module_unique_name_dict[id(fake_quant)] = fake_quant_name
                        graph.module_original_name_dict[id(fake_quant)] = fake_quant_name

                        fake_quant_cls = type(fake_quant)
                        module_constructor_lines[id(fake_quant)] = f'{qualified_name(fake_quant_cls)}()'

                        graph.insert_between(node, next_node, fake_quant, move_idx=True, tensor_ptrs=[ptr])
                        node_map[ptr] = graph.nodes_map[fake_quant_name]

        # Insert the DeQuantStub nodes before every input of the unsupported ops
        # (shape/device/size/dtype producers are skipped)
        for idx, node in enumerate(unsupported_nodes):
            fake_dequant_cls = torch_q.DeQuantStub
            assert node.rev_index is False
            node_map = dict()
            prev_nodes = {n.unique_name: n for n in node.prev_nodes}.values()
            for inner_idx, prev_node in enumerate(prev_nodes):
                if prev_node.kind() in ('shape', 'device', 'size', 'dtype'):
                    continue

                prev_tensor_ptrs = []
                for pt in node.prev_tensors:
                    for nt in prev_node.next_tensors:
                        if isinstance(nt, (list, tuple)):
                            for k, ntt in enumerate(nt):
                                if id(pt) == id(ntt):
                                    if id(pt) not in prev_tensor_ptrs:
                                        prev_tensor_ptrs.append(id(pt))
                                    break
                        elif id(pt) == id(nt):
                            if id(pt) not in prev_tensor_ptrs:
                                prev_tensor_ptrs.append(id(pt))
                            break

                for ptr_idx, ptr in enumerate(prev_tensor_ptrs):
                    if ptr in node_map:
                        fake_dequant = node_map[ptr]

                        graph.insert_between(prev_node, node, fake_dequant, move_idx=True, tensor_ptrs={ptr})
                    else:
                        fake_dequant = fake_dequant_cls()

                        fake_dequant_name = f'fake_dequant_inner_{idx}_{inner_idx}_{ptr_idx}'

                        graph.module_unique_name_dict[id(fake_dequant)] = fake_dequant_name
                        graph.module_original_name_dict[id(fake_dequant)] = fake_dequant_name

                        module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()'

                        graph.insert_between(prev_node, node, fake_dequant, move_idx=True, tensor_ptrs={ptr})
                        node_map[ptr] = graph.nodes_map[fake_dequant_name]
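
        # Taken together with the QuantStub pass above, each unsupported op is now
        # bracketed by stubs, e.g.
        #   prev -> fake_dequant_inner_* -> <unsupported op> -> fake_quant_inner_* -> next
        # so the float-only op receives dequantized tensors while the surrounding
        # quantized region stays intact.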

        # Remove consecutive dequant/quant node pairs
        def _is_consecutive_dequant_quant_nodes(node, custom_data):
            cur_type = node.type()
            if cur_type in (torch_q.QuantStub, torch_q.DeQuantStub):
                for next_node in node.next_nodes:
                    next_type = next_node.type()
                    if next_type in (torch_q.QuantStub, torch_q.DeQuantStub):
                        if cur_type != next_type:
                            if cur_type == torch_q.QuantStub:
                                target_node = None
                                if len(node.prev_nodes) == 1 and node.prev_nodes[0].kind() in skip_types_prev:
                                    target_node = node.prev_nodes[0]
                                elif (
                                    len(next_node.next_nodes) == 1 and next_node.next_nodes[0].kind() in skip_types_next
                                ):
                                    target_node = next_node.next_nodes[0]
                                if target_node is not None:
                                    self.layerwise_config.setdefault(target_node.unique_name, True)
                                    if self.layerwise_config[target_node.unique_name]:
                                        return False
                            custom_data.append((node, next_node))
                            return True
            return False

        consecutive_dequant_quant_nodes = []
        consecutive_branch_dequant_quant_nodes = []
        graph.filter_forward_nodes(_is_consecutive_dequant_quant_nodes, consecutive_dequant_quant_nodes)
        for node, next_node in consecutive_dequant_quant_nodes:
            if len(node.next_nodes) == 1 and len(next_node.prev_nodes) == 1:
                graph.remove_node(next_node)
                graph.remove_node(node)
            else:
                consecutive_branch_dequant_quant_nodes.append((node, next_node))

        # TODO: Support complex cases
        for node, next_node in consecutive_branch_dequant_quant_nodes:
            is_removable = all(
                n.type() in (torch_q.QuantStub, torch_q.DeQuantStub, 'shape', 'device', 'size', 'dtype')
                and n.type() != node.type()
                for n in node.next_nodes
            )
            if is_removable:
                graph.remove_node(node)
                for n in node.next_nodes:
                    if n.type() in (torch_q.QuantStub, torch_q.DeQuantStub):
                        graph.remove_node(n)
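
        # Net effect of the two blocks above: a directly-chained DeQuantStub/QuantStub
        # pair (in either order) is dropped entirely,
        #   ... -> DeQuantStub -> QuantStub -> ...   becomes   ... -> ...
        # and the branched variant is only collapsed when every consumer of the first
        # stub is a stub of the opposite type or a shape/device/size/dtype access.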

        # Process additional fusable nodes
        processed_extra_q_rules = load_processed_extra_q_rules()
        is_extra_fusable = functools.partial(
            self.is_fusable,
            current_rules=processed_extra_q_rules,
            check_node_quantized=False,
            use_original_name=False,
        )
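        # Reuse the generic fusability check with the extra quantization rules; matched
        # name sequences end up in custom_data[0] (unique names, since
        # use_original_name=False) and the nodes are not required to be quantized yet.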

        custom_data = ([], set())
        graph.filter_forward_nodes(is_extra_fusable, custom_data, reverse=True)
        activ_names = custom_data[0]
        log.debug(f'found extra fusable sequences (not fusable natively by PyTorch): {activ_names}')

        for idx, names in enumerate(reversed(activ_names)):
            name = names[-1]
            node = graph.nodes_map[name]
            fake_quant = torch_q.QuantStub()

            graph.module_unique_name_dict[id(fake_quant)] = f'fake_activ_quant_{idx}'
            graph.module_original_name_dict[id(fake_quant)] = f'fake_activ_quant_{idx}'

            fake_quant_cls = type(fake_quant)
            module_constructor_lines[id(fake_quant)] = f'{qualified_name(fake_quant_cls)}()'

            graph.insert_after(node, fake_quant)

            fake_dequant = torch_q.DeQuantStub()

            graph.module_unique_name_dict[id(fake_dequant)] = f'fake_activ_dequant_{idx}'
            graph.module_original_name_dict[id(fake_dequant)] = f'fake_activ_dequant_{idx}'

            fake_dequant_cls = type(fake_dequant)
            module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()'

            graph.insert_after(node, fake_dequant)
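            # Note: this second insert_after lands between `node` and the stub added just
            # above, so (assuming insert_after splices the new module in front of the
            # node's current consumers) the resulting order is
            #   node -> fake_activ_dequant_{idx} -> fake_activ_quant_{idx} -> original consumers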

        fused_clamps = []
        for names in activ_names:
            for name in names:
                node = graph.nodes_map[name]
                if node.kind() == 'clamp_with_fusion':
                    fused_clamps.append(name)

        # Rewrite unfused clamp_with_fusion to torch.clamp
        def _is_unfused_clamp_node(node: TraceNode, custom_data):
            cur_module = node.module
            cur_class = type(cur_module)
            if cur_class == TraceFunction:
                return (
                    cur_module.kind == 'clamp_with_fusion'
                    and torch.is_floating_point(node.next_tensors[0])
                    and node.unique_name not in fused_clamps
                )
            return False

        unfused_clamp_nodes = graph.filter_forward_nodes(_is_unfused_clamp_node)
        log.info(f'rewriting unfused_clamp for {[node.unique_name for node in unfused_clamp_nodes]}')
        for node in unfused_clamp_nodes:
            node.module.kind = 'clamp'
            node.module.func_type = node.module.func_type.replace('clamp_with_fusion', 'clamp')
            node.module.full_name = f'torch.{node.module.func_type}'
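            # e.g. a node whose func_type was 'clamp_with_fusion' now calls torch.clamp
            # directly; only kind/func_type/full_name change, the arguments are untouched.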

        # Optional tensor shape broadcasting for quantized binary ops
        def _is_broadcastable_binary_quantized_op_node(node: TraceNode, custom_data) -> bool:
            cur_module = node.module
            cur_class = type(cur_module)
            if cur_class != TraceFunction:
                return False
            return (
                cur_module.full_name.startswith('self.float_functional_simple_')
                and cur_module.func_type in ('add', 'mul', 'add_relu')
                and node.prev_tensors[0].shape != node.prev_tensors[1].shape
            )

        broadcastable_binary_quantized_op_nodes = graph.filter_forward_nodes(_is_broadcastable_binary_quantized_op_node)
        for node in broadcastable_binary_quantized_op_nodes:
            assert len(node.prev_nodes) == 2
            assert len(node.prev_tensors) == 2

            l_shape = list(node.prev_tensors[0].shape)
            r_shape = list(node.prev_tensors[1].shape)

            ref_index = None
            if len(l_shape) > len(r_shape):
                ref_index = 0
                r_shape = [1] * (len(l_shape) - len(r_shape)) + r_shape
            elif len(l_shape) < len(r_shape):
                ref_index = 1
                l_shape = [1] * (len(r_shape) - len(l_shape)) + l_shape

            for l_dim, r_dim in zip(l_shape, r_shape):
                if l_dim > r_dim:
                    if ref_index in (None, 0) and r_dim == 1:
                        ref_index = 0
                    else:
                        ref_index = -1
                        break
                elif l_dim < r_dim:
                    if ref_index in (None, 1) and l_dim == 1:
                        ref_index = 1
                    else:
                        ref_index = -1
                        break

            if ref_index >= 0:
                src_index = 1 - ref_index

                with override_current_trace_graph(graph):
                    trace_func = TraceFunction('torch.Tensor.expand_as', True, prefix='fuse_').parse_args(
                        node.prev_tensors[src_index], node.prev_tensors[ref_index]
                    )
                next_tensors = [node.prev_tensors[src_index].expand_as(node.prev_tensors[ref_index])]
                graph.insert_between(node.prev_nodes[src_index], node, trace_func, next_tensors, True)

                new_node = graph.nodes_map[trace_func.unique_name]
                new_node.prev_nodes.append(node.prev_nodes[ref_index])
                new_node.prev_tensors.append(node.prev_tensors[ref_index])
                new_node.prev_indices.append(node.prev_indices[ref_index])
                node.prev_nodes[ref_index].next_nodes.append(new_node)
            else:
                new_indices = []
                for i in range(2):
                    if node.prev_indices[i] is None:
                        new_indices.append(i)
                    elif isinstance(node.prev_indices[i], (tuple, list)):
                        new_indices.append(list(node.prev_indices[i]) + [i])
                    else:
                        new_indices.append([node.prev_indices[i], i])
                with override_current_trace_graph(graph):
                    trace_func = TraceFunction('torch.broadcast_tensors', False, prefix='fuse_').parse_args(
                        node.prev_tensors[0], node.prev_tensors[1]
                    )
                next_tensors = torch.broadcast_tensors(node.prev_tensors[0], node.prev_tensors[1])
                graph.insert_before(node, trace_func, next_tensors, False, new_indices)
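
        # Illustration: for shapes such as (1, 64, 1, 1) vs. (1, 64, 32, 32) the smaller
        # operand is routed through an inserted torch.Tensor.expand_as node; when neither
        # operand alone broadcasts to the other's shape (e.g. (1, 64, 1, 32) vs.
        # (1, 1, 32, 1)), a torch.broadcast_tensors node is inserted before the op instead.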

        for name in fuse_mapping:
            if name in self.layerwise_config:
                del self.layerwise_config[name]

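        # The remaining layerwise_config entries that map to real compute are recorded as
        # effective layers: nn.Module nodes other than pure pass-throughs (Dropout,
        # pooling, ReLU, Upsample, ConstantPad2d), plus add/mul/cat and the functional
        # kinds listed in FUNCTIONAL_MODULE_MAPPING.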
        for name in self.layerwise_config:
            node = graph.nodes_map[name]
            if isinstance(node.module, nn.Module) and not isinstance(
                node.module,
                (nn.Dropout, nn.AdaptiveAvgPool2d, nn.AvgPool2d, nn.MaxPool2d, nn.ReLU, nn.Upsample, nn.ConstantPad2d),
            ):
                self.effective_layers.append(name)
            elif node.kind() in ('add', 'mul', 'cat') or node.kind() in FUNCTIONAL_MODULE_MAPPING:
                self.effective_layers.append(name)

        graph.quantized = True
        graph.recompute_forward_order()