import copy
import functools
import itertools
import re
import typing
import warnings

import igraph as ig
import numpy as np

from tinynn.util.util import class_conditional, get_logger

from ..schemas.tflite.schema_generated import ActivationFunctionType, Padding
from . import tflite as tfl
from .base import FUSE_ACTIVATION_MAP, ExtendedOperator
from .graph import CommonGraph

log = get_logger(__name__, 'INFO')


class GraphOptimizer(object):
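    """Graph-level optimizer for the TFLite conversion graph.

    Runs a series of fusion, constant-folding and cleanup passes over a
    ``CommonGraph``. Most passes are gated on ``level`` (see the optimization
    level constants below) or on feature flags via ``class_conditional``.
    """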
    graph: CommonGraph
    fuse_tensor_count: int
    fuse_attr_count: int
    fuse_quant: bool
    group_conv_rewrite: bool
    tflite_micro_rewrite: bool
    quantize_input_output_type: typing.Optional[str]

    # Optimization levels
    NO_OPTIMIZE: int = 0
    FOLD_BUFFER: int = 1
    FUSE_BN: int = 2
    COMMON_OPTIMIZE: int = 3
    BRANCH_OPTIMIZE: int = 4
    BRANCH_OPTIMIZE_EXTENDED: int = 5
    ALL_OPTIMIZE: int = 5

    def __init__(
        self,
        graph: CommonGraph,
        level: int,
        fuse_quant: bool,
        group_conv_rewrite: bool,
        rewrite_quantizable: bool,
        tflite_micro_rewrite: bool,
        quantize_input_output_type: typing.Optional[str],
        fuse_input_indices: typing.Optional[typing.List[int]] = None,
        fuse_output_indices: typing.Optional[typing.List[int]] = None,
        max_transpose_dims: int = -1,
        bypass_elementwise_passthrough_constraint: bool = False,
        group_tensors: bool = False,
        conv_transpose_with_bias: bool = True,
        hybrid_int16_lstm: bool = False,
    ) -> None:
        self.graph = graph
        self.fuse_tensor_count = 0
        self.fuse_attr_count = 0
        self.level = level
        self.fuse_quant = fuse_quant
        self.group_conv_rewrite = group_conv_rewrite
        self.rewrite_quantizable = rewrite_quantizable
        self.tflite_micro_rewrite = tflite_micro_rewrite
        self.quantize_input_output_type = quantize_input_output_type
        self.fuse_input_indices = fuse_input_indices
        self.fuse_output_indices = fuse_output_indices
        self.max_transpose_dims = max_transpose_dims
        self.bypass_elementwise_passthrough_constraint = bypass_elementwise_passthrough_constraint
        self.group_tensors = group_tensors
        self.conv_transpose_with_bias = conv_transpose_with_bias
        self.hybrid_int16_lstm = hybrid_int16_lstm

    def create_attr_tensor(
        self, tensor: tfl.Tensor, name: str = None, quantization: typing.Optional[tfl.QuantizationParameters] = None
    ):
        if name is None:
            if self.fuse_attr_count == 0:
                name = 'fuse_attr'
            else:
                name = f'fuse_attr_{self.fuse_attr_count}'
            self.fuse_attr_count += 1
        return tfl.Tensor(tensor, name, has_buffer=True, quantization=quantization)

    def create_transform_tensor(
        self, tensor: tfl.Tensor, name: str = None, quantization: typing.Optional[tfl.QuantizationParameters] = None
    ):
        if name is None:
            if self.fuse_tensor_count == 0:
                name = 'fuse_transform'
            else:
                name = f'fuse_transform_{self.fuse_tensor_count}'
            self.fuse_tensor_count += 1
        return tfl.Tensor(tensor, name, has_buffer=False, quantization=quantization)

    @class_conditional(lambda self: self.level >= GraphOptimizer.FUSE_BN)
    def fuse_conv_fc_bn(self):
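        """Fold a following batch-norm into the preceding conv/fc node.

        With y = gamma * (x - mean) / sqrt(var + eps) + beta applied to
        x = W @ i + b, the fused node uses
            W' = W * gamma / sqrt(var + eps)
            b' = (b - mean) * gamma / sqrt(var + eps) + beta
        with the scale broadcast along the output-channel axis.
        """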
        # Find fusable ops
        edges = self.graph.graph.es.select(functools.partial(is_bn_fusable_edge, graph_converter=self.graph.graph))
        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)

        remove_ids = []
        actions = []
        for conv, bn, tensor in filtered_pairs:
            bn_activ = bn['op'].fusedActivationFunction
            conv_activ = getattr(conv['op'], 'fusedActivationFunction', None)
            if conv_activ is None and bn_activ != ActivationFunctionType.NONE:
                continue

            # Find out the output of the batch-norm nodes
            new_output = bn['outputs'][0]
            assert new_output in self.graph.tensor_map

            # For each node that is next of a batch-norm node, we connect it with the conv node
            self.graph.connect_next_tensors(bn, conv, new_output)

            # Update graph, prepare to drop the output tensor of the conv node and use the output tensor of the
            # batch-norm instead
            conv['outputs'][0] = new_output
            conv['op'].outputs[0] = self.graph.tensor_map[new_output]
            self.graph.tensor_node_map[new_output] = conv['name']
            tensor['name'] = bn['outputs'][0]
            tensor['label'] = bn['outputs'][0]

            if bn_activ != ActivationFunctionType.NONE and conv_activ == ActivationFunctionType.NONE:
                conv['op'].fusedActivationFunction = bn_activ

            # Collect the arguments of the conv and batch-norm nodes
            weight = conv['op'].inputs[1]
            bias = conv['op'].inputs[2] if len(conv['op'].inputs) > 2 else None
            bn_w, bn_b, bn_mean, bn_var = bn['op'].inputs[1:]
            bn_w, bn_b, bn_mean, bn_var = (
                bn_w.tensor.copy(),
                bn_b.tensor.copy(),
                bn_mean.tensor.copy(),
                bn_var.tensor.copy(),
            )
            activ_w = weight.tensor.copy()
            activ_b = bias.tensor.copy() if bias is not None else None
            eps = bn['op'].eps

            # Fuse conv/fc and batch-norm
            new_weight = fuse_bn_weight(
                eps, bn_w, bn_var, activ_w, conv['node_type'] == ExtendedOperator.GENERIC_DECONV
            )
            new_bias = fuse_bn_bias(eps, bn_w, bn_var, bn_mean, bn_b, activ_b)

            # New attribute tensors
            new_w = self.create_attr_tensor(new_weight)
            new_b = self.create_attr_tensor(new_bias)

            # Collect the actions we should take here
            # We don't perform them right away because we are still iterating over the vertices;
            # the iterator would be invalidated once `replace_operator_input` is called
            actions.append((self.graph.replace_operator_input, (conv, 1, new_w)))
            if bias is not None:
                actions.append((self.graph.replace_operator_input, (conv, 2, new_b)))
            else:
                actions.append((self.graph.append_operator_input, (conv, new_b)))

            remove_ids.append(bn.index)

        # Process actions
        for func, args in actions:
            func(*args)

        # Delete batch-norm nodes
        for id in remove_ids:
            vertex = self.graph.graph.vs[id]
            assert vertex['node_type'] == ExtendedOperator.BATCH_NORM
        self.graph.graph.delete_vertices(remove_ids)

    @class_conditional(lambda self: self.level >= GraphOptimizer.FUSE_BN)
    def fuse_bn_conv(self):
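        """Fold a preceding batch-norm into the conv node that consumes it.

        The batch-norm rewrites its input as a * x + c with
        a = gamma / sqrt(var + eps) and c = beta - mean * a, so the conv
        weight is scaled along the input-channel axis and the contribution of
        the constant offset c is folded into the conv bias.
        """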
        edges = self.graph.graph.es.select(functools.partial(is_rev_bn_fusable_edge, graph_converter=self.graph.graph))
        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]) for x in edges)

        def _remove_last_pred(seq):
            bn = seq[0]
            conv = seq[1]

            # Collect the arguments of the conv and batch-norm nodes
            weight = conv['op'].inputs[1]
            bias = conv['op'].inputs[2] if len(conv['op'].inputs) > 2 else None
            bn_w, bn_b, bn_mean, bn_var = bn['op'].inputs[1:]
            bn_w, bn_b, bn_mean, bn_var = (
                bn_w.tensor.copy(),
                bn_b.tensor.copy(),
                bn_mean.tensor.copy(),
                bn_var.tensor.copy(),
            )
            activ_w = weight.tensor.copy()
            activ_b = bias.tensor.copy() if bias is not None else None
            eps = bn['op'].eps

            new_weight = fuse_rev_bn_weight(eps, bn_w, bn_var, activ_w)
            new_bias = fuse_rev_bn_bias(eps, bn_w, bn_var, bn_mean, bn_b, activ_b, activ_w)

            return False, (conv, bias, new_weight, new_bias)

        def _remove_last_action(first_node, last_node, custom_data):
            conv, bias, new_weight, new_bias = custom_data

            new_w = self.create_attr_tensor(new_weight)
            new_b = self.create_attr_tensor(new_bias)

            actions = []
            actions.append((self.graph.replace_operator_input, (conv, 1, new_w)))
            if bias is not None:
                actions.append((self.graph.replace_operator_input, (conv, 2, new_b)))
            else:
                actions.append((self.graph.append_operator_input, (conv, new_b)))
            return actions

        def _skip_pred(seq):
            bn = seq[0]['op']
            conv = seq[1]['op']

            skip = bn.inputs[0].quantization is not None or (
                conv.inputs[1].shape[1] == 1 and conv.inputs[1].shape[0] == conv.groups and conv.groups > 1
            )
            return skip

        elinimate_sequences(
            self.graph,
            filtered_pairs,
            True,
            None,
            _remove_last_pred,
            _remove_last_action,
            _skip_pred,
            force_forward_input=True,
        )

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_activation(self):
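        """Fuse standalone activation nodes (e.g. RELU/RELU6) into the
        `fusedActivationFunction` attribute of the preceding op."""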
        # Find fusable ops
        edges = self.graph.graph.es.select(functools.partial(is_activ_fusable_edge, graph_converter=self.graph.graph))
        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)

        remove_ids = []
        for pre_activ, activ, tensor in filtered_pairs:
            if not self.conv_transpose_with_bias and pre_activ['node_type'] == ExtendedOperator.GENERIC_DECONV:
                continue

            # Find out the output of the activation nodes
            new_output = activ['outputs'][0]
            assert new_output in self.graph.tensor_map

            # For each node that is next of the activation node, we connect it with the previous node
            self.graph.connect_next_tensors(activ, pre_activ, new_output)

            # Update graph, prepare to drop the output tensor of the previous node and use the output tensor of the
            # activation instead
            pre_activ['outputs'][0] = new_output
            pre_activ['op'].outputs[0] = self.graph.tensor_map[new_output]
            self.graph.tensor_node_map[new_output] = pre_activ['name']
            tensor['name'] = activ['outputs'][0]
            tensor['label'] = activ['outputs'][0]

            # Fuse activation
            pre_activ['op'].fusedActivationFunction = FUSE_ACTIVATION_MAP[activ['node_type']]

            remove_ids.append(activ.index)

        # Delete activation nodes
        self.graph.graph.delete_vertices(remove_ids)

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_same_padding(self):
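        """Remove an explicit PAD op in front of conv/pool ops when the padded
        amount matches TFLite's SAME padding exactly, switching the consumer
        to `Padding.SAME` instead."""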
        edges = self.graph.graph.es.select(functools.partial(is_padding_fusable_edge, graph_converter=self.graph.graph))
        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]) for x in edges)

        def _remove_last_pred(seq):
            op = seq[1]['op']
            return False, op

        def _remove_last_action(first_node, last_node, custom_data):
            op = custom_data
            op.padding = Padding.SAME
            return []

        def _skip_pred(seq):
            pad_op = seq[0]['op']
            next_op = seq[1]['op']

            input_shape = pad_op.inputs[0].shape[1:-1]
            if seq[1]['node_type'] == ExtendedOperator.MAX_POOL_2D:
                kernel_shape = (next_op.filterHeight, next_op.filterWidth)
                strides = (next_op.strideH, next_op.strideW)
                dilation = (1, 1)
            elif seq[1]['node_type'] in (
                ExtendedOperator.CONV_2D,
                ExtendedOperator.DEPTHWISE_CONV_2D,
            ):
                kernel_shape = next_op.inputs[1].shape[1:-1]
                strides = (next_op.strideH, next_op.strideW)
                dilation = (next_op.dilationHFactor, next_op.dilationWFactor)
            elif seq[1]['node_type'] == ExtendedOperator.CONV_3D:
                kernel_shape = next_op.inputs[1].shape[:3]
                strides = (next_op.strideD, next_op.strideH, next_op.strideW)
                dilation = (next_op.dilationDFactor, next_op.dilationHFactor, next_op.dilationWFactor)

            pad_args = get_same_padding_args(input_shape, kernel_shape, strides, dilation)
            pad_arr = np.array(pad_args, dtype='int32')

            old_pad_arr = pad_op.inputs[1].tensor
            skip = not np.array_equal(pad_arr, old_pad_arr)

            return skip

        elinimate_sequences(
            self.graph,
            filtered_pairs,
            True,
            None,
            _remove_last_pred,
            _remove_last_action,
            _skip_pred,
            force_forward_input=True,
        )

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_same_padding_slicing(self):
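        """Fuse the cropping SLICE after a transposed conv when it crops
        exactly the SAME-padding margin, switching the conv to `Padding.SAME`
        and updating its output-shape input accordingly."""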
        edges = self.graph.graph.es.select(functools.partial(is_slicing_fusable_edge, graph_converter=self.graph.graph))
        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)

        remove_ids = []
        actions = []
        for prev_node, slice_node, tensor in filtered_pairs:
            prev_op = prev_node['op']
            slice_op = slice_node['op']

            input_shape = slice_op.outputs[0].shape[1:-1]
            if prev_node['node_type'] == ExtendedOperator.TRANSPOSE_CONV:
                kernel_shape = prev_op.inputs[1].shape[1:-1]
                strides = (prev_op.strideH, prev_op.strideW)
                dilation = (1, 1)
            elif prev_node['node_type'] == ExtendedOperator.CONV_3D_TRANSPOSE:
                kernel_shape = prev_op.inputs[1].shape[:3]
                strides = (prev_op.strideD, prev_op.strideH, prev_op.strideW)
                dilation = (prev_op.dilationDFactor, prev_op.dilationHFactor, prev_op.dilationWFactor)

            pad_args = get_same_padding_args(input_shape, kernel_shape, strides, dilation)
            pad_arr = np.array(pad_args, dtype='int32')

            start_arr = [x for x in slice_op.inputs[1].tensor]
            end_arr = [slice_op.inputs[0].shape[i] - x - slice_op.outputs[0].shape[i] for i, x in enumerate(start_arr)]

            old_pad_args = [[x, y] for x, y in zip(start_arr, end_arr)]
            skip = not np.array_equal(pad_arr, old_pad_args)

            if skip:
                continue

            # Find out the output of the slice nodes
            new_output = slice_node['outputs'][0]
            assert new_output in self.graph.tensor_map

            # For each node that is next of the slice node, we connect it with the previous node
            self.graph.connect_next_tensors(slice_node, prev_node, new_output)

            # Update graph, prepare to drop the output tensor of the conv node and use the output tensor of the
            # slice op instead
            prev_node['outputs'][0] = new_output
            prev_node['op'].outputs[0] = self.graph.tensor_map[new_output]
            self.graph.tensor_node_map[new_output] = prev_node['name']
            tensor['name'] = slice_node['outputs'][0]
            tensor['label'] = slice_node['outputs'][0]

            # Fuse padding
            prev_node['op'].padding = Padding.SAME

            new_shape = np.array(prev_node['op'].outputs[0].shape, dtype='int32')
            new_shape_tensor = self.create_attr_tensor(new_shape)
            actions.append((self.graph.replace_operator_input, (prev_node, 0, new_shape_tensor)))

            remove_ids.append(slice_node.index)

        for func, args in actions:
            func(*args)

        # Delete slice nodes
        self.graph.graph.delete_vertices(remove_ids)

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_requantize(self):
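        """Fuse a following requantize (QUANTIZE) node into the producing op
        by adopting the requantized output tensor (or its quantization
        parameters) on the producer and removing the QUANTIZE node."""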
        # Find fusable ops
        edges = self.graph.graph.es.select(
            functools.partial(is_requantize_fusable_edge, graph_converter=self.graph.graph)
        )
        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)

        remove_ids = []
        for pre_activ, activ, tensor in filtered_pairs:
            if pre_activ.outdegree() > 1:
                skip = False
                pre_quantize = None
                for out_edge in pre_activ.out_edges():
                    next_node = self.graph.graph.vs[out_edge.target]
                    while True:
                        if next_node['node_type'] == ExtendedOperator.QUANTIZE:
                            if pre_quantize is None:
                                pre_quantize = next_node['op'].outputs[0].quantization
                            else:
                                cur_quantize = next_node['op'].outputs[0].quantization
                                if (
                                    pre_quantize.scale != cur_quantize.scale
                                    or pre_quantize.zero_point != cur_quantize.zero_point
                                    or pre_quantize.dim != cur_quantize.dim
                                ):
                                    skip = True
                            break
                        elif next_node['node_type'] == ExtendedOperator.DEQUANTIZE:
                            break
                        elif next_node['node_type'] in (ExtendedOperator.RESHAPE, ExtendedOperator.TRANSPOSE):
                            if next_node.outdegree() > 1:
                                skip = True
                                break
                            else:
                                next_node = self.graph.graph.vs[next_node.out_edges()[0].target]
                        else:
                            skip = True
                            break

                    if skip:
                        break

                if skip:
                    continue

                # Find out the output of the first node in the sequence
                output_name = activ['op'].inputs[0].name
                output_idx = pre_activ['outputs'].index(output_name)
                new_output = pre_activ['outputs'][output_idx]
                assert new_output in self.graph.tensor_map

                # For each node that is next of the last node, we connect it with the first node
                # Also, replace the tensors when needed
                self.graph.replace_next_tensors(activ, pre_activ, new_output)

                new_tensor = pre_activ['op'].outputs[0]
                old_tensor = activ['op'].outputs[0]
                new_tensor.quantization = old_tensor.quantization
            else:
                # Find out the output of the quantize nodes
                new_output = activ['outputs'][0]
                assert new_output in self.graph.tensor_map

                # For each node that is next of the quantize node, we connect it with the previous node
                self.graph.connect_next_tensors(activ, pre_activ, new_output)

                # Update graph, prepare to drop the output tensor of the previous node and use the output tensor of
                # the quantize op instead
                pre_activ['outputs'][0] = new_output
                pre_activ['op'].outputs[0] = self.graph.tensor_map[new_output]
                self.graph.tensor_node_map[new_output] = pre_activ['name']
                tensor['name'] = activ['outputs'][0]
                tensor['label'] = activ['outputs'][0]

            remove_ids.append(activ.index)

        # Delete quantize nodes
        self.graph.graph.delete_vertices(remove_ids)

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_reciprocal_sqrt(self):
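        """Rewrite SQRT followed by DIV into RSQRT.

        When the DIV computes 1.0 / sqrt(x) with no fused activation, the DIV
        is removed; otherwise it is rewritten as a MUL by rsqrt(x).
        """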
        # Find fusable ops
        edges = self.graph.graph.es.select(functools.partial(is_reciprocal_sqrt_edge, graph_converter=self.graph.graph))
        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)

        remove_ids = []
        for sqrt, div, tensor in filtered_pairs:
            sqrt['node_type'] = ExtendedOperator.RSQRT
            sqrt['op'] = tfl.RsqrtOperator(sqrt['op'].inputs, sqrt['op'].outputs)

            div_op = div['op']
            if (
                div_op.inputs[0].buffer is not None
                and np.all(div_op.inputs[0].tensor == 1.0)
                and div['op'].fusedActivationFunction == ActivationFunctionType.NONE
            ):
                new_output = div['outputs'][0]
                assert new_output in self.graph.tensor_map

                # For each node that is next of the div node, we connect it with the previous node
                self.graph.connect_next_tensors(div, sqrt, new_output)

                # Update graph, prepare to drop the output tensor of the sqrt node and use the output tensor of the
                # div instead
                sqrt['outputs'][0] = new_output
                sqrt['op'].outputs[0] = self.graph.tensor_map[new_output]
                self.graph.tensor_node_map[new_output] = sqrt['name']
                tensor['name'] = div['outputs'][0]
                tensor['label'] = div['outputs'][0]

                # remove div op
                remove_ids.append(div.index)
            else:
                div['node_type'] = ExtendedOperator.MUL
                div['op'] = tfl.MulOperator(
                    div['op'].inputs, div['op'].outputs, fusedActivationFunction=div['op'].fusedActivationFunction
                )

        # Delete div nodes
        self.graph.graph.delete_vertices(remove_ids)

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def remove_tile_before_binary_elementwise_ops(self):
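        """Drop a TILE op feeding a binary elementwise op when implicit
        broadcasting of the untiled operand already yields the same output
        shape."""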
        # Find fusable ops
        edges = self.graph.graph.es.select(functools.partial(is_tile_binary_op_edge, graph_converter=self.graph.graph))
        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)

        remove_ids = []
        actions = []
        binary_op_ids = set()
        for tile, op_node, tensor in filtered_pairs:
            tile_op = tile['op']
            binary_op = op_node['op']

            input_idx = None
            for i in range(2):
                try:
                    _ = tile['outputs'].index(binary_op.inputs[i].name)
                    input_idx = i
                    break
                except ValueError:
                    pass

            if input_idx is None:
                continue

            alter_input_idx = 1 - input_idx
            try:
                out_shape = np.broadcast_shapes(binary_op.inputs[alter_input_idx].shape, tile_op.inputs[0].shape)
                if out_shape != binary_op.outputs[0].shape:
                    continue
            except ValueError:
                continue

            if op_node.index not in binary_op_ids:
                binary_op_ids.add(op_node.index)
            else:
                continue

            new_tensor = tile_op.inputs[0]

            # Replace input tensors
            actions.append((self.graph.replace_operator_input, (op_node, input_idx, new_tensor)))

            # remove tile op
            remove_ids.append(tile.index)

        # Process actions
        for func, args in actions:
            func(*args)
        # Delete tile nodes
        self.graph.graph.delete_vertices(remove_ids)

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_conv2d_gather(self):
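        """Fold a GATHER along the output-channel axis after a conv into the
        conv itself by permuting its weights, bias and per-channel
        quantization parameters along axis 0."""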
        # Find fusable ops
        edges = self.graph.graph.es.select(functools.partial(is_conv2d_gather_edge, graph_converter=self.graph.graph))
        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)

        remove_ids = []
        actions = []
        for conv, gather, tensor in filtered_pairs:
            # Find out the output of the gather nodes
            new_output = gather['outputs'][0]
            assert new_output in self.graph.tensor_map

            # For each node that is next of the gather node, we connect it with the conv node
            self.graph.connect_next_tensors(gather, conv, new_output)

            # Update graph, prepare to drop the output tensor of the conv node and use the output tensor of the
            # gather instead
            conv['outputs'][0] = new_output
            conv_out_quant_param = conv['op'].outputs[0].quantization
            conv['op'].outputs[0] = self.graph.tensor_map[new_output]
            conv['op'].outputs[0].quantization = conv_out_quant_param
            self.graph.tensor_node_map[new_output] = conv['name']
            tensor['name'] = gather['outputs'][0]
            tensor['label'] = gather['outputs'][0]
            # permute weight of conv-op
            indx = gather['op'].inputs[1].tensor.copy()
            w = conv['op'].inputs[1].tensor.copy()
            w_quant_param = conv['op'].inputs[1].quantization
            new_w = np.take(w, indx, axis=0)
            # permute bias of conv-op
            b = conv['op'].inputs[2].tensor.copy() if len(conv['op'].inputs) > 2 else None
            b_quant_param = conv['op'].inputs[2].quantization
            new_b = np.take(b, indx, axis=0) if b is not None else None
            if w_quant_param is not None and isinstance(w_quant_param.scale, list) and w_quant_param.dim == 0:
                new_w_scale = np.take(w_quant_param.scale, indx, axis=0)
                new_w_zeros = np.take(w_quant_param.zero_point, indx, axis=0)
                w_quant_param.scale = new_w_scale
                w_quant_param.zero_point = new_w_zeros
                if new_b is not None:
                    new_b_scale = np.take(b_quant_param.scale, indx, axis=0)
                    new_b_zeros = np.take(b_quant_param.zero_point, indx, axis=0)
                    b_quant_param.scale = new_b_scale
                    b_quant_param.zero_point = new_b_zeros
            new_w = self.create_attr_tensor(new_w, quantization=w_quant_param)
            actions.append((self.graph.replace_operator_input, (conv, 1, new_w)))
            new_b = self.create_attr_tensor(new_b, quantization=b_quant_param)
            actions.append((self.graph.replace_operator_input, (conv, 2, new_b)))

            # remove gather op
            remove_ids.append(gather.index)

        # Process actions
        for func, args in actions:
            func(*args)
        # Delete gather nodes
        self.graph.graph.delete_vertices(remove_ids)

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_gather_conv2d(self):
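        """Fold a channel GATHER before a conv into the conv by reordering the
        conv weights (and per-channel quantization parameters) along the
        input-channel axis with the inverse of the gather indices."""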
        # Find fusable ops
        edges = self.graph.graph.es.select(functools.partial(is_gather_conv2d_edge, graph_converter=self.graph.graph))
        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]) for x in edges)

        def _remove_last_pred(seq):
            gather = seq[0]
            conv = seq[1]
            return False, (gather, conv)

        def _remove_last_action(first_node, last_node, custom_data):
            gather, conv = custom_data

            actions = []

            indx = np.argsort(gather['op'].inputs[1].tensor)
            w = conv['op'].inputs[1].tensor.copy()
            w_quant_param = conv['op'].inputs[1].quantization
            new_w = np.take(w, indx, axis=3)
            if w_quant_param is not None and isinstance(w_quant_param.scale, list) and w_quant_param.dim == 3:
                new_w_scale = np.take(w_quant_param.scale, indx, axis=0)
                new_w_zeros = np.take(w_quant_param.zero_point, indx, axis=0)
                w_quant_param.scale = new_w_scale
                w_quant_param.zero_point = new_w_zeros
            new_w = self.create_attr_tensor(new_w, quantization=w_quant_param)
            actions.append((self.graph.replace_operator_input, (conv, 1, new_w)))
            return actions

        elinimate_sequences(
            self.graph,
            filtered_pairs,
            True,
            None,
            _remove_last_pred,
            _remove_last_action,
            False,
            force_forward_input=True,
        )

    @class_conditional(lambda self: self.tflite_micro_rewrite)
    def split_requantize(self):
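        """Split requantize nodes (selected by `is_requantize_node`) into a
        DEQUANTIZE + QUANTIZE pair, which is needed when targeting TFLite
        Micro (this pass is gated on `tflite_micro_rewrite`)."""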
        vertices = self.graph.graph.vs.select(functools.partial(is_requantize_node, graph_converter=self.graph.graph))

        remove_ids = []
        restore_mapping = []
        for quantize in vertices:
            restore_nodes = []
            # For each node that is next of a requantize node,
            #  a. if it is an output node, remove it anyway since it will always be reconstructed
            #  b. otherwise, record the info of the edge so that we may restore it after reconstruction
            for out_edge in quantize.out_edges():
                next_node = self.graph.graph.vs[out_edge.target]
                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    remove_ids.append(next_node.index)
                    del self.graph.tensor_map[next_node['outputs'][0]]
                    del self.graph.tensor_node_map[next_node['outputs'][0]]
                else:
                    restore_nodes.append((out_edge['name'], next_node['name']))

            # Remove the mapping since they are going to be removed
            for output_name in quantize['outputs']:
                del self.graph.tensor_map[output_name]
                del self.graph.tensor_node_map[output_name]

            restore_mapping.append(restore_nodes)
            remove_ids.append(quantize.index)

        # Make sure the nodes are topologically sorted
        sorted_ops = [node['op'] for node in sorted(vertices, key=lambda x: int(re.search(r'\d+', x['name'])[0]))]

        # Delete nodes before transformation in the graph
        self.graph.graph.delete_vertices(remove_ids)

        for quantize, mapping in zip(sorted_ops, restore_mapping):
            input_tensor = quantize.inputs[0]
            output_tensor = quantize.outputs[0]

            intermediate = self.create_transform_tensor(input_tensor.tensor.astype('float32'))

            # Build a fresh dequantize/quantize pair for this node; accumulating ops across
            # iterations would re-add previously inserted operators to the graph
            ops = [
                tfl.DequantizeOperator([input_tensor], [intermediate]),
                tfl.QuantizeOperator([intermediate], [output_tensor]),
            ]

            for op in ops:
                self.graph.add_operator(op, transform=True)

            self.graph.try_restore_edges(mapping)

    def transform_graph(self):
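        """Lower transformable ops: record their outgoing edges, delete the
        nodes from the graph, and call each op's `transform` hook to rebuild
        it from primitive ops while restoring the recorded edges."""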
        # Find transformable ops
        filtered_nodes = self.graph.graph.vs.select(
            functools.partial(is_transformable_node, graph_converter=self.graph.graph)
        )
        remove_ids = []
        ops = []
        restore_mapping = []
        for node in filtered_nodes:
            restore_nodes = []
            # For each node that is next of a transformable node,
            #  a. if it is an output node, remove it anyway since it will always be reconstructed
            #  b. otherwise, record the info of the edge so that we may restore it after reconstruction
            for out_edge in node.out_edges():
                next_node = self.graph.graph.vs[out_edge.target]
                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    remove_ids.append(next_node.index)
                    del self.graph.tensor_map[next_node['outputs'][0]]
                    del self.graph.tensor_node_map[next_node['outputs'][0]]
                else:
                    restore_nodes.append((out_edge['name'], next_node['name']))

            # Remove the mapping since they are going to be removed
            for output_name in node['outputs']:
                del self.graph.tensor_map[output_name]
                del self.graph.tensor_node_map[output_name]

            restore_mapping.append(restore_nodes)
            ops.append(node)
            remove_ids.append(node.index)

        # Make sure the nodes are topologically sorted
        sorted_ops = [node['op'] for node in sorted(ops, key=lambda x: int(re.search(r'\d+', x['name'])[0]))]

        # Delete nodes before transformation in the graph
        self.graph.graph.delete_vertices(remove_ids)

        # Do transformation
        for op, mapping in zip(sorted_ops, restore_mapping):
            op.transform(self.graph, mapping)

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_simple_transpose_pass(self):
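        """Merge chains of consecutive TRANSPOSE ops by composing their
        permutations; if the composition is the identity, the whole chain is
        removed."""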
        edges = self.graph.graph.es.select(
            functools.partial(is_transpose_fusable_edge, graph_converter=self.graph.graph)
        )
        filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]

        # Try to fuse the edges
        filtered_pairs = fuse_connected_edges(filtered_pairs)

        def _remove_first_pred(seq):
            new_perm = fuse_transpose_perms(seq)

            hints = set()
            for node in seq:
                if 'direction' in node['op'].extra_hints:
                    hints.add(node['op'].extra_hints['direction'])

            if len(hints) == 1:
                hint = next(iter(hints))
            else:
                hint = None

            remove_first = np.array_equal(new_perm, np.sort(new_perm))
            return remove_first, (new_perm, hint)

        def _remove_first_action(first_node, last_node, custom_data):
            # Set fused perm to the first transpose node
            new_perm, hint = custom_data
            if hint is None:
                if 'direction' in first_node['op'].extra_hints:
                    del first_node['op'].extra_hints['direction']
            else:
                first_node['op'].extra_hints['direction'] = hint
            new_perm_tensor = self.create_attr_tensor(new_perm)
            action = (self.graph.replace_operator_input, (first_node, 1, new_perm_tensor))
            return [action]

        elinimate_sequences(self.graph, filtered_pairs, _remove_first_pred, _remove_first_action)

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_simple_gather_pass(self):
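        """Merge chains of consecutive GATHER ops with constant 1-D indices by
        composing the index tensors; an identity composition removes the
        chain."""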
        edges = self.graph.graph.es.select(functools.partial(is_gather_fusable_edge, graph_converter=self.graph.graph))
        filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]

        # Try to fuse the edges
        filtered_pairs = fuse_connected_edges(filtered_pairs)

        def _remove_first_pred(seq):
            new_perm = fuse_transpose_perms(seq)

            hints = set()
            for node in seq:
                if 'direction' in node['op'].extra_hints:
                    hints.add(node['op'].extra_hints['direction'])

            if len(hints) == 1:
                hint = next(iter(hints))
            else:
                hint = None

            remove_first = np.array_equal(new_perm, np.sort(new_perm))
            return remove_first, (new_perm, hint)

        def _remove_first_action(first_node, last_node, custom_data):
            # Set the composed indices to the first gather node
            new_perm, hint = custom_data
            if hint is None:
                if 'direction' in first_node['op'].extra_hints:
                    del first_node['op'].extra_hints['direction']
            else:
                first_node['op'].extra_hints['direction'] = hint
            new_perm_tensor = self.create_attr_tensor(new_perm)
            action = (self.graph.replace_operator_input, (first_node, 1, new_perm_tensor))
            return [action]

        def _skip_pred(seq):
            for node in seq:
                op = node['op']
                idx_tensor = op.inputs[1]
                if idx_tensor.buffer is None:
                    return True
                if len(idx_tensor.shape) > 1:
                    return True
            return False

        elinimate_sequences(self.graph, filtered_pairs, _remove_first_pred, _remove_first_action, skip_pred=_skip_pred)

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_dequant_quant_pass(self, q_first):
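        """Collapse chains of alternating QUANTIZE/DEQUANTIZE ops. A chain
        that is a no-op overall (identical quantization parameters, or float
        in / float out) is removed entirely; otherwise it is replaced with a
        single QUANTIZE (requantize) node."""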
        edges = self.graph.graph.es.select(
            functools.partial(is_dequant_quant_fusable_edge, graph_converter=self.graph.graph, q_first=q_first)
        )
        filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]

        r_edges = self.graph.graph.es.select(
            functools.partial(is_dequant_quant_fusable_edge, graph_converter=self.graph.graph, q_first=not q_first)
        )
        r_filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in r_edges]

        filtered_pairs = fuse_connected_edges(filtered_pairs + r_filtered_pairs)

        new_pairs = []
        for seq in filtered_pairs:
            start_idx = 0
            end_idx = len(seq)

            if q_first:
                if seq[0]['node_type'] != ExtendedOperator.QUANTIZE:
                    start_idx += 1
                if seq[-1]['node_type'] != ExtendedOperator.DEQUANTIZE:
                    end_idx -= 1
            else:
                if seq[0]['node_type'] != ExtendedOperator.DEQUANTIZE:
                    start_idx += 1
                if seq[-1]['node_type'] != ExtendedOperator.QUANTIZE:
                    end_idx -= 1

            new_seq = seq[start_idx:end_idx]
            if len(new_seq) >= 2:
                new_pairs.append(new_seq)

        filtered_pairs = new_pairs

        def _remove_first_pred(seq):
            first_node, last_node = seq[0], seq[-1]
            new_qparams = last_node['op'].outputs[0].quantization
            orig_qparams = first_node['op'].inputs[0].quantization

            if (
                first_node['node_type'] == ExtendedOperator.DEQUANTIZE
                and last_node['node_type'] == ExtendedOperator.QUANTIZE
            ):
                assert new_qparams is not None
                assert orig_qparams is not None

                remove_first = (
                    new_qparams.scale == orig_qparams.scale
                    and new_qparams.zero_point == orig_qparams.zero_point
                    and new_qparams.dim == orig_qparams.dim
                )
            else:
                assert new_qparams is None
                assert orig_qparams is None

                remove_first = True

            return remove_first, None

        def _remove_first_action(first_node, last_node, custom_data):
            # Set new node type to first node
            first_node['node_type'] = ExtendedOperator.QUANTIZE
            old_op = first_node['op']
            first_node['op'] = tfl.QuantizeOperator(old_op.inputs, old_op.outputs)
            return []

        elinimate_sequences(self.graph, filtered_pairs, _remove_first_pred, _remove_first_action)

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_simple_reshape_pass(self):
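        """Merge chains of consecutive RESHAPE ops into a single reshape with
        the final target shape; if that equals the original input shape, the
        whole chain is removed."""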
        edges = self.graph.graph.es.select(functools.partial(is_reshape_fusable_edge, graph_converter=self.graph.graph))
        filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]

        # Try to fuse the edge
        filtered_pairs = fuse_connected_edges(filtered_pairs)

        def _remove_first_pred(seq):
            first_node, last_node = seq[0], seq[-1]
            new_shape = last_node['op'].inputs[1].tensor
            orig_shape = np.array(first_node['op'].inputs[0].shape, dtype='int32')

            hints = set()
            for node in seq:
                if 'direction' in node['op'].extra_hints:
                    hints.add(node['op'].extra_hints['direction'])

            if len(hints) == 1:
                hint = next(iter(hints))
            else:
                hint = None

            remove_first = np.array_equal(new_shape, orig_shape)
            return remove_first, (new_shape, hint)

        def _remove_first_action(first_node, last_node, custom_data):
            # Set final shape to the first reshape node
            new_shape, hint = custom_data
            if hint is None:
                if 'direction' in first_node['op'].extra_hints:
                    del first_node['op'].extra_hints['direction']
            else:
                first_node['op'].extra_hints['direction'] = hint
            new_shape_tensor = self.create_attr_tensor(np.array(new_shape, dtype='int32'))
            first_node['op'].newShape = new_shape_tensor.tensor
            action = (self.graph.replace_operator_input, (first_node, 1, new_shape_tensor))
            return [action]

        elinimate_sequences(self.graph, filtered_pairs, _remove_first_pred, _remove_first_action)

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_simple_slice_pass(self):
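        """Merge chains of consecutive SLICE/STRIDED_SLICE ops into a single
        op, emitting SLICE when the fused stride is all ones and
        STRIDED_SLICE otherwise."""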
        edges = self.graph.graph.es.select(functools.partial(is_slice_fusable_edge, graph_converter=self.graph.graph))
        filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]

        # Try to fuse the edge
        filtered_pairs = fuse_connected_edges(filtered_pairs)

        def _remove_first_pred(seq):
            fused_info = fuse_slices(seq)

            return False, fused_info

        def _remove_first_action(first_node, last_node, custom_data):
            # Apply the fused slice parameters to the first slice node
            start, end, stride = custom_data
            if all((x == 1 for x in stride)):
                target_class = tfl.SliceOperator
                target_type = ExtendedOperator.SLICE
            else:
                target_class = tfl.StridedSliceOperator
                target_type = ExtendedOperator.STRIDED_SLICE

            if target_type == ExtendedOperator.SLICE:
                size = end - start
                start_tensor = self.create_attr_tensor(np.array(start, dtype='int32'))
                size_tensor = self.create_attr_tensor(np.array(size, dtype='int32'))
                actions = [
                    (self.graph.replace_operator_input, (first_node, 1, start_tensor)),
                    (self.graph.replace_operator_input, (first_node, 2, size_tensor)),
                ]
                if first_node['node_type'] != ExtendedOperator.SLICE:
                    old_slice_op = first_node['op']
                    first_node['node_type'] = ExtendedOperator.SLICE
                    first_node['op'] = target_class(old_slice_op.inputs, old_slice_op.outputs)
                    actions.append((self.graph.remove_operator_input, (first_node, 3)))
            else:
                start_tensor = self.create_attr_tensor(np.array(start, dtype='int32'))
                end_tensor = self.create_attr_tensor(np.array(end, dtype='int32'))
                stride_tensor = self.create_attr_tensor(np.array(stride, dtype='int32'))
                if first_node['node_type'] == ExtendedOperator.STRIDED_SLICE:
                    actions = [
                        (self.graph.replace_operator_input, (first_node, 1, start_tensor)),
                        (self.graph.replace_operator_input, (first_node, 2, end_tensor)),
                        (self.graph.replace_operator_input, (first_node, 3, stride_tensor)),
                    ]
                else:
                    old_slice_op = first_node['op']
                    first_node['node_type'] = ExtendedOperator.STRIDED_SLICE
                    first_node['op'] = target_class(old_slice_op.inputs, old_slice_op.outputs)
                    actions = [
                        (self.graph.replace_operator_input, (first_node, 1, start_tensor)),
                        (self.graph.replace_operator_input, (first_node, 2, end_tensor)),
                        (self.graph.append_operator_input, (first_node, stride_tensor)),
                    ]

            return actions

        elinimate_sequences(self.graph, filtered_pairs, _remove_first_pred, _remove_first_action)

    @class_conditional(lambda self: self.group_tensors)
    def group_tensors_pass(self):
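        """Deduplicate constant tensors that share the same buffer data,
        dtype, shape and quantization parameters, rewiring all consumers to a
        single shared tensor."""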
        tensor_map = {}
        actions = []
        bytes_saved = 0
        tensors_saved = 0
        for v in self.graph.graph.vs:
            if v['node_type'] == ExtendedOperator.CONSTANT_NODE:
                tensor = self.graph.tensor_map[v['outputs'][0]]
                if tensor.quantization is None:
                    t_idx = (tensor.buffer.data, tensor.dtype, tensor.shape)
                else:
                    scale = tensor.quantization.scale
                    zero_point = tensor.quantization.zero_point
                    if isinstance(scale, list):
                        scale = tuple(scale)
                    if isinstance(zero_point, list):
                        zero_point = tuple(zero_point)
                    t_idx = (
                        tensor.buffer.data,
                        tensor.dtype,
                        tensor.shape,
                        scale,
                        zero_point,
                        tensor.quantization.dim,
                    )

                if t_idx in tensor_map:
                    new_tensor = tensor_map[t_idx]
                    for e in v.out_edges():
                        target = e.target_vertex
                        if target['op'] is not None:
                            for i, inp in enumerate(target['op'].inputs):
                                if inp.name == tensor.name:
                                    log.debug(f'{inp.name} used in {target["outputs"][0]}:{i} -> {new_tensor.name}')
                                    tensors_saved += 1
                                    bytes_saved += len(inp.buffer.data)
                                    actions.append((self.graph.replace_operator_input, (target, i, new_tensor)))
                else:
                    tensor_map[t_idx] = tensor

        # Process actions
        for func, args in actions:
            func(*args)

        log.info(f'{tensors_saved} duplicated tensors found, {bytes_saved / 1024 / 1024:.2f} MB saved')

    def cleanup_dead_nodes(self):
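        """Iteratively remove dangling nodes (zero out-degree and not an
        input/output/unused marker) until none remain, warning when a
        non-constant node has to be dropped."""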
        cleanup_nodes = []
        if not self.graph.graph.is_connected('weak'):
            while True:
                for vertex in self.graph.graph.vs:
                    if (
                        vertex['node_type'] not in (ExtendedOperator.OUTPUT_NODE, ExtendedOperator.UNUSED_NODE)
                        and vertex.outdegree() == 0
                    ):
                        if vertex['node_type'] == ExtendedOperator.INPUT_NODE:
                            continue
                        if vertex['node_type'] != ExtendedOperator.CONSTANT_NODE:
                            if vertex['op'] is None or vertex['op'].extra_hints.get('warn_on_unused', True):
                                warnings.warn('Non constant node removed, something must be wrong there')
                                log.warning('-' * 30)
                                log.warning('Info of the deleted node:')
                                log.warning(f'vertex: {vertex}')
                                # edge = self.graph.graph.es.select(name=vertex['outputs'][0])
                                # assert edge is None, (
                                #     f'The edge {vertex["outputs"][0]} exists but the connection to the vertex'
                                #     f' {vertex["name"]} is broken, probably there have some conflicts in the names'
                                #     ' of the nodes'
                                # )
                        cleanup_nodes.append(vertex.index)

                if len(cleanup_nodes) == 0:
                    break

                self.graph.graph.delete_vertices(cleanup_nodes)
                cleanup_nodes.clear()

    @class_conditional(lambda self: self.level >= GraphOptimizer.FOLD_BUFFER)
    def fold_transpose_buffer(self):
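        """Constant-fold TRANSPOSE ops with constant inputs by pre-transposing
        the buffer and rewiring consumers to the new constant node."""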
        edges = self.graph.graph.es.select(
            functools.partial(is_constant_transpose_fusable_edge, graph_converter=self.graph.graph)
        )
        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)

        remove_ids = []
        for constant, transpose, tensor in filtered_pairs:
            # Calculate the output of the transposed constant nodes
            constant_tensor = transpose['op'].inputs[0].tensor
            perm_tensor = transpose['op'].inputs[1].tensor
            new_constant = np.transpose(constant_tensor, perm_tensor)
            new_tensor = self.create_attr_tensor(new_constant, quantization=transpose['op'].outputs[0].quantization)
            new_node = self.graph.add_nodes([new_tensor])[0]

            # For each node that is next of a constant transpose node, we connect it with the new constant node
            for out_edge in transpose.out_edges():
                next_node = self.graph.graph.vs[out_edge.target]
                self.graph.graph.add_edge(new_node, next_node, name=new_tensor.name, label=new_tensor.name)
                log.debug(
                    f'NEW EDGE: {new_node["label"]} -> {next_node["label"]} {self.graph.tensor_map[out_edge["name"]]}'
                )
                op = next_node['op']
                for idx in range(len(op.inputs)):
                    if op.inputs[idx].name == transpose['op'].outputs[0].name:
                        op.inputs[idx] = new_tensor

            remove_ids.append(transpose.index)

        # Delete constant transpose nodes
        self.graph.graph.delete_vertices(remove_ids)

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def transpose_to_reshape_pass(self):
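        """Rewrite TRANSPOSE ops that do not change the underlying element
        order (selected by `is_transformable_transpose_node`) into equivalent
        RESHAPE ops."""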
        filtered_nodes = self.graph.graph.vs.select(
            functools.partial(is_transformable_transpose_node, graph_converter=self.graph.graph)
        )

        # Collect actions for the transformable transpose nodes
        actions = []
        for node in filtered_nodes:
            original_op = node['op']
            output_shape = np.array(original_op.outputs[0].shape, dtype='int32')
            shape_tensor = self.create_attr_tensor(output_shape)
            new_op = tfl.ReshapeOperator(original_op.inputs, original_op.outputs, output_shape)
            node['op'] = new_op
            node['node_type'] = ExtendedOperator.RESHAPE
            node['label'] = new_op.type_name()
            actions.append((self.graph.replace_operator_input, (node, 1, shape_tensor)))

        # Process actions
        for func, args in actions:
            func(*args)

    @class_conditional(lambda self: self.level >= GraphOptimizer.FOLD_BUFFER)
    def fold_reshape_buffer(self):
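        """Constant-fold RESHAPE ops with constant inputs by reshaping the
        buffer and rewiring consumers to the new constant node."""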
        edges = self.graph.graph.es.select(
            functools.partial(is_constant_reshape_fusable_edge, graph_converter=self.graph.graph)
        )
        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)

        remove_ids = []
        for constant, reshape, tensor in filtered_pairs:
            # Calculate the output of the reshaped constant nodes
            constant_tensor = reshape['op'].inputs[0].tensor
            shape_tensor = reshape['op'].inputs[1].tensor
            new_constant = np.reshape(constant_tensor, shape_tensor)
            new_tensor = self.create_attr_tensor(new_constant, quantization=reshape['op'].inputs[0].quantization)
            new_node = self.graph.add_nodes([new_tensor])[0]

            # For each node that is next of a constant reshape node, we connect it with the new constant node
            for out_edge in reshape.out_edges():
                next_node = self.graph.graph.vs[out_edge.target]
                self.graph.graph.add_edge(new_node, next_node, name=new_tensor.name, label=new_tensor.name)
                log.debug(
                    f'NEW EDGE: {new_node["label"]} -> {next_node["label"]} {self.graph.tensor_map[out_edge["name"]]}'
                )
                op = next_node['op']
                for idx in range(len(op.inputs)):
                    if op.inputs[idx].name == reshape['op'].outputs[0].name:
                        op.inputs[idx] = new_tensor

            remove_ids.append(reshape.index)

        # Delete constant reshape nodes
        self.graph.graph.delete_vertices(remove_ids)

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def remove_noop_pass(self, branch: bool = False):
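        """Remove no-op nodes (edges selected by `is_ending_with_noop_edge`),
        optionally handling branching outputs; non-branching chains are fused
        before elimination."""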
        edges = self.graph.graph.es.select(
            functools.partial(is_ending_with_noop_edge, graph_converter=self.graph.graph, branch=branch)
        )
        filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]

        # Try to fuse the edges
        if not branch:
            filtered_pairs = fuse_connected_edges(filtered_pairs)

        elinimate_sequences(self.graph, filtered_pairs)

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_wrapped_reshape_within_transpose_pass(self):
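        """Collapse TRANSPOSE -> RESHAPE -> TRANSPOSE sequences in which the
        reshape only adds or removes unit axes and the composed permutation is
        the identity, leaving a single RESHAPE with the final shape."""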
        edges = self.graph.graph.es.select(
            functools.partial(is_wrapped_reshape_within_transpose_edge, graph_converter=self.graph.graph)
        )
        filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]

        # Try to fuse the edges
        fused_pairs = fuse_connected_edges(filtered_pairs)

        # Only TRANSPOSE->RESHAPE->TRANSPOSE is supported here
        filtered_pairs = []
        for seq in fused_pairs:
            seq_len = len(seq)
            transpose_first = seq[0]['node_type'] == ExtendedOperator.TRANSPOSE
            if seq_len >= 3 and transpose_first:
                filtered_pairs.append(seq[:3])
            elif seq_len >= 4:
                filtered_pairs.append(seq[1:4])

        def _skip_pred(seq):
            mid_node = seq[1]
            orig_shape = mid_node['op'].inputs[0].shape
            new_shape = mid_node['op'].outputs[0].shape

            if not is_simple_reshape(orig_shape, new_shape):
                return True

            new_perm = fuse_transpose_perms_extended(seq)
            return (new_perm != np.sort(new_perm)).any()

        def _remove_last_pred(seq):
            orig_tensor = seq[0]['op'].inputs[0].tensor
            return False, (seq[2], orig_tensor)

        def _remove_last_action(first_node, last_node, custom_data):
            # Convert the last transpose into a reshape with the final output shape
            last_trans, orig_tensor = custom_data
            actions = []
            original_op = last_trans['op']
            output_shape = np.array(original_op.outputs[0].shape, dtype='int32')
            shape_tensor = self.create_attr_tensor(output_shape)
            new_op = tfl.ReshapeOperator(original_op.inputs, original_op.outputs, output_shape)
            last_trans['op'] = new_op
            last_trans['node_type'] = ExtendedOperator.RESHAPE
            last_trans['label'] = new_op.type_name()
            new_op.inputs[0].tensor = orig_tensor
            new_op.inputs[0].shape = new_op.inputs[0].tensor.shape
            actions.append((self.graph.replace_operator_input, (last_trans, 1, shape_tensor)))
            return actions

        elinimate_sequences(self.graph, filtered_pairs, True, None, _remove_last_pred, _remove_last_action, _skip_pred)

    @class_conditional(lambda self: self.level >= GraphOptimizer.BRANCH_OPTIMIZE)
    def branch_reshape_expand_pass(self):
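        """Duplicate a RESHAPE whose output feeds multiple consumers so that
        each branch gets its own copy, enabling later per-branch
        optimizations."""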
        edges = self.graph.graph.es.select(functools.partial(is_reshape_branch_edge, graph_converter=self.graph.graph))
        branch_reshape_nodes = list(set(self.graph.graph.vs[edge.source] for edge in edges))

        def _new_reshape(node: ig.Vertex, prev_node: ig.Vertex, next_node: ig.Vertex):
            actions = []

            op = node['op']
            op_out = op.outputs[0]
            op_shape = op.inputs[1]

            prev_idx = prev_node['outputs'].index(op.inputs[0].name)
            if prev_node['node_type'] == ExtendedOperator.INPUT_NODE:
                prev_out = self.graph.tensor_map[op.inputs[0].name]
            else:
                prev_op = prev_node['op']
                prev_out = prev_op.outputs[prev_idx]

            new_tensor = self.create_transform_tensor(op_out.tensor.copy(), quantization=op_out.quantization)
            new_shape = self.create_attr_tensor(op_shape.tensor.copy())
            new_op = tfl.ReshapeOperator([prev_out, new_shape], [new_tensor], new_shape.tensor)
            new_op.extra_hints.update(op.extra_hints)
            self.graph.add_operator(new_op)

            next_indices = []
            for i, t in enumerate(next_node['op'].inputs):
                if t.name == op_out.name:
                    actions.append((self.graph.replace_operator_input, (next_node, i, new_tensor)))
                    next_indices.append(i)

            assert len(next_indices) > 0, f'{op_out.name} not in {[t.name for t in next_node["op"].inputs]}'

            return actions

        expand_op_outputs_in_branches(branch_reshape_nodes, _new_reshape, self.graph)

    @class_conditional(lambda self: self.level >= GraphOptimizer.BRANCH_OPTIMIZE)
    def branch_transpose_expand_pass(self):
        edges = self.graph.graph.es.select(
            functools.partial(is_transpose_branch_edge, graph_converter=self.graph.graph)
        )
        branch_transpose_nodes = list(set(self.graph.graph.vs[edge.source] for edge in edges))
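        # Same branch-expansion strategy as branch_reshape_expand_pass, applied to TRANSPOSE
        # nodes whose output fans out to several consumers.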

        def _new_transpose(node: ig.Vertex, prev_node: ig.Vertex, next_node: ig.Vertex):
            actions = []

            op = node['op']
            op_out = op.outputs[0]
            op_perm = op.inputs[1]

            prev_idx = prev_node['outputs'].index(op.inputs[0].name)
            if prev_node['node_type'] in (ExtendedOperator.INPUT_NODE, ExtendedOperator.CONSTANT_NODE):
                prev_out = self.graph.tensor_map[op.inputs[0].name]
            else:
                prev_op = prev_node['op']
                prev_out = prev_op.outputs[prev_idx]

            new_tensor = self.create_transform_tensor(op_out.tensor.copy(), quantization=op_out.quantization)
            new_perm = self.create_attr_tensor(op_perm.tensor.copy())
            new_op = tfl.TransposeOperator([prev_out, new_perm], [new_tensor])
            new_op.extra_hints.update(op.extra_hints)
            self.graph.add_operator(new_op)

            next_indices = []
            for i, t in enumerate(next_node['op'].inputs):
                if t.name == op_out.name:
                    actions.append((self.graph.replace_operator_input, (next_node, i, new_tensor)))
                    next_indices.append(i)

            assert len(next_indices) > 0, f'{op_out.name} not in {[t.name for t in next_node["op"].inputs]}'

            return actions

        expand_op_outputs_in_branches(branch_transpose_nodes, _new_transpose, self.graph)

    @class_conditional(lambda self: self.level >= GraphOptimizer.BRANCH_OPTIMIZE, 0)
    def elementwise_reshape_transpose_passthrough_pass(self) -> int:
        edges = self.graph.graph.es.select(
            functools.partial(is_transpose_reshape_op_edge, graph_converter=self.graph.graph)
        )
        pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
        filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.TRANSPOSE else k[1] for k in pairs)
        unique_nodes = list(set(filtered_nodes))

        actions = []
        remove_edges = []
        remove_vertices = []
        processed_nodes = set()
        num_actions = 0
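        # Push TRANSPOSE nodes across adjacent RESHAPE ops when the reshape's axis mapping allows
        # it, as long as the total number of transposes does not grow (ties need a direction hint
        # at BRANCH_OPTIMIZE_EXTENDED and above).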
        for node in unique_nodes:
            pending_processed_nodes = set()

            op = node['op']
            input_indices = op_input_indices(op)
            l_shape = op.inputs[0].shape
            r_shape = op.outputs[0].shape
            if len(l_shape) == 0 or len(r_shape) == 0:
                continue
            l_map, r_map, _, _ = reshape_mapping(l_shape, r_shape)
            mode = None
            need_chain = False
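            # Classify the reshape from its axis mapping:
            #   'up'   - several input axes merge into one output axis, e.g. (N, H, W, C) -> (N, H*W, C)
            #   'down' - one input axis splits into several output axes (the reverse)
            #   '?'    - mixed/unsupported mappings; such nodes are skipped below
            # Identical multi-to-multi groups are tolerated but require the consecutiveness check
            # guarded by `need_chain`. The example shapes above are illustrative only.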
            for l_val, r_val in zip(l_map, r_map):
                if len(l_val) > 1 and len(r_val) == 1:
                    if mode in (None, 'up'):
                        mode = 'up'
                    else:
                        mode = '?'
                        break
                elif len(r_val) > 1 and len(l_val) == 1:
                    if mode in (None, 'down'):
                        mode = 'down'
                    else:
                        mode = '?'
                        break
                elif len(r_val) > 1 and len(l_val) > 1:
                    if len(r_val) != len(l_val) or r_val != l_val:
                        # TODO: Support this case
                        mode = '?'
                        break
                    else:
                        need_chain = True

            if mode is None:
                mode = 'down'

            # TODO: Support multi-multi mappings
            if mode == '?':
                # Passthrough is not possible; clear any direction hints on the neighbouring nodes
                for i in input_indices:
                    prev_node_name = op.inputs[i].name
                    prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
                    if prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
                        if 'direction' in prev_node['op'].extra_hints:
                            prev_node['op'].extra_hints.pop('direction')
                for edge in node.out_edges():
                    if edge.index in remove_edges:
                        continue
                    next_node = self.graph.graph.vs[edge.target]

                    if 'direction' in next_node['op'].extra_hints:
                        next_node['op'].extra_hints.pop('direction')
                continue

            check_consecutive_indices = []
            if need_chain:
                new_l_map = []
                new_r_map = []
                for l_val, r_val in zip(l_map, r_map):
                    if len(l_val) > 1 and len(r_val) > 1:
                        if mode == 'down':
                            check_consecutive_indices.append(l_val)
                        else:
                            check_consecutive_indices.append(r_val)
                        for l_item in l_val:
                            new_l_map.append([l_item])
                        for r_item in r_val:
                            new_r_map.append([r_item])
                    else:
                        new_l_map.append(l_val)
                        new_r_map.append(r_val)

                l_map = new_l_map
                r_map = new_r_map

            prev_nodes = []
            cand_perms = dict()
            cand_rev_perms = dict()
            prev_output_indices = []
            num_constant_nodes = 0
            prev_hints = set()
            skip = False
            for i in input_indices:
                prev_node_name = op.inputs[i].name
                prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
                prev_nodes.append(prev_node)
                prev_output_indices.append(prev_node['outputs'].index(prev_node_name))

                if prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
                    if prev_node['name'] in processed_nodes:
                        skip = True
                        break
                    pending_processed_nodes.add(prev_node['name'])
                    if mode == 'down':
                        perm = tuple(prev_node['op'].inputs[1].tensor.tolist())
                        cand_perms.setdefault(perm, 0)
                        cand_perms[perm] += 1
                    elif mode == 'up':
                        perm = tuple(np.argsort(prev_node['op'].inputs[1].tensor).tolist())
                        cand_rev_perms.setdefault(perm, 0)
                        cand_rev_perms[perm] += 1
                    if 'direction' in prev_node['op'].extra_hints:
                        prev_hints.add(prev_node['op'].extra_hints['direction'])

                if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
                    num_constant_nodes += 1

            if skip or (self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints):
                continue

            next_nodes = []
            next_edges = []
            out_nodes = []
            next_hints = set()
            for edge in node.out_edges():
                if edge.index in remove_edges:
                    continue
                next_node = self.graph.graph.vs[edge.target]

                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    out_nodes.append(next_node)
                else:
                    if next_node['name'] in processed_nodes:
                        skip = True
                        break
                    pending_processed_nodes.add(next_node['name'])
                    next_nodes.append(next_node)
                    next_edges.append(edge)

                if next_node['node_type'] == ExtendedOperator.TRANSPOSE:
                    if mode == 'down':
                        perm = tuple(np.argsort(next_node['op'].inputs[1].tensor).tolist())
                        cand_rev_perms.setdefault(perm, 0)
                        cand_rev_perms[perm] += 1
                    elif mode == 'up':
                        perm = tuple(next_node['op'].inputs[1].tensor.tolist())
                        cand_perms.setdefault(perm, 0)
                        cand_perms[perm] += 1
                    if 'direction' in next_node['op'].extra_hints:
                        next_hints.add(next_node['op'].extra_hints['direction'])

            if skip or (self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints):
                continue

            cur_transpose_size = sum(cand_perms.values()) + sum(cand_rev_perms.values())
            new_transpose_size = len(prev_nodes) + len(next_nodes) - sum(cand_perms.values()) - num_constant_nodes

            # Skip when there is nothing to push through or the rewrite does not strictly reduce
            # the number of transposes; a tie is only taken when a direction hint justifies it
            # (BRANCH_OPTIMIZE_EXTENDED and above)
            if len(cand_perms) == 0 or len(next_nodes) == 0 or new_transpose_size > cur_transpose_size:
                continue
            elif new_transpose_size == cur_transpose_size:
                skip = True
                if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED:
                    if 'down' in prev_hints or 'up' in next_hints:
                        skip = False

            if skip:
                continue

            perm = max(cand_perms.items(), key=lambda x: x[1])[0]
            perm_arr = np.array(perm, dtype='int32')

            skip = False
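            # When the reshape keeps whole multi-axis groups intact (`need_chain`), the chosen
            # permutation must map each group onto a contiguous range in the original order;
            # otherwise the passthrough is abandoned for this node.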
            for check_idx in check_consecutive_indices:
                if mode == 'down':
                    target_idx = perm_arr[check_idx]
                elif mode == 'up':
                    perm_sorter = perm_arr.argsort()
                    target_idx = perm_sorter[np.searchsorted(perm_arr, check_idx, sorter=perm_sorter)]
                normalized_src = [x - check_idx[0] for x in check_idx]
                normalized_tgt = [x - target_idx[0] for x in target_idx]
                if normalized_src != normalized_tgt:
                    skip = True
                    break

            if skip:
                continue

            num_actions += 1

            remove_edges.extend([x.index for x in next_edges])
            remove_vertices.extend([x.index for x in out_nodes])

            for pending_processed_node in pending_processed_nodes:
                processed_nodes.add(pending_processed_node)

            for n in out_nodes:
                del self.graph.tensor_map[n['outputs'][0]]
                del self.graph.tensor_node_map[n['outputs'][0]]

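            # Translate the chosen permutation to the other side of the reshape via the axis
            # mapping: `inv_perm_arr` is applied in front of the op's inputs and `post_perm_arr`
            # is appended after its outputs, so the observable layout of the graph is unchanged.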
            if mode == 'down':
                inv_perm_arr = np.argsort(perm_arr).astype('int32')
                l_dict = dict(zip([x[0] for x in l_map], r_map))
                indices = map(lambda x: l_dict[x], inv_perm_arr.tolist())
                inv_post_perm = list(itertools.chain.from_iterable(indices))
                inv_post_perm_arr = np.array(inv_post_perm, dtype='int32')
                post_perm_arr = np.argsort(inv_post_perm_arr).astype('int32')
            elif mode == 'up':
                r_dict = dict(zip([x[0] for x in r_map], l_map))
                indices = map(lambda x: r_dict[x], perm)
                inv_perm = list(itertools.chain.from_iterable(indices))
                inv_perm_arr = np.array(inv_perm, dtype='int32')
                post_perm_arr = np.argsort(perm_arr).astype('int32')
                inv_post_perm_arr = np.argsort(post_perm_arr).astype('int32')

            for prev_node, prev_idx, next_idx in zip(prev_nodes, input_indices, prev_output_indices):
                if prev_node['op'] is None:
                    prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
                else:
                    prev_out = prev_node['op'].outputs[next_idx]
                perm_tensor = self.create_attr_tensor(inv_perm_arr)
                prev_new_out = self.create_transform_tensor(
                    np.transpose(prev_out.tensor, inv_perm_arr), quantization=prev_out.quantization
                )
                transpose_op = tfl.TransposeOperator([prev_out, perm_tensor], [prev_new_out])
                transpose_op.extra_hints['direction'] = 'up'
                self.graph.add_operator(transpose_op)
                actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))

            tensor_node_dict = {}
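            # The op now produces permuted outputs; a trailing 'down' TRANSPOSE recreates each
            # original output tensor so downstream consumers keep their expected layout.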
            for i, op_out in enumerate(op.outputs):
                perm_tensor = self.create_attr_tensor(post_perm_arr)
                new_out = self.create_transform_tensor(
                    np.transpose(op_out.tensor, inv_post_perm_arr), quantization=op_out.quantization
                )

                # Update relations
                if op_out.name in self.graph.tensor_node_map:
                    del self.graph.tensor_node_map[op_out.name]
                self.graph.tensor_node_map[new_out.name] = node['name']
                self.graph.tensor_map[new_out.name] = new_out
                node['outputs'][i] = new_out.name
                op.outputs[i] = new_out

                transpose_op = tfl.TransposeOperator([new_out, perm_tensor], [op_out])
                transpose_op.extra_hints['direction'] = 'down'
                self.graph.add_operator(transpose_op)

                tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])

            # Permute the reshape's target-shape operand to match the new layout
            old_shape = op.inputs[1].tensor
            new_shape = self.create_attr_tensor(old_shape[inv_post_perm_arr])
            actions.append((self.graph.replace_operator_input, (node, 1, new_shape, True)))
            op.newShape = new_shape.tensor

            for edge in next_edges:
                source = tensor_node_dict[edge['name']]
                self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])

        # Process actions
        ids = []
        for func, args in actions:
            node = args[0]
            res = func(*args)
            if res is not None:
                ids.extend(res)

        remove_edges = list(set(remove_edges + ids))

        self.graph.graph.delete_edges(remove_edges)
        self.graph.graph.delete_vertices(remove_vertices)

        return num_actions

    @class_conditional(lambda self: self.rewrite_quantizable)
    def elementwise_op_quantize_passthrough_pass(self):
        edges = self.graph.graph.es.select(
            functools.partial(
                is_quantize_elementwise_op_edge, graph_converter=self.graph.graph, with_lstm=self.hybrid_int16_lstm
            )
        )
        pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
        filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.DEQUANTIZE else k[1] for k in pairs)
        unique_nodes = list(set(filtered_nodes))

        actions = []
        remove_edges = []
        remove_vertices = []
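        # Rewrite each matched elementwise op to run on quantized tensors: a QUANTIZE is inserted
        # in front of its inputs (or a pre-quantized constant is substituted) and a DEQUANTIZE is
        # appended after each output, using the scales/zero points collected from the surrounding
        # (de)quantize nodes. Redundant QUANTIZE/DEQUANTIZE pairs are presumably cleaned up by
        # other passes.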
        for node in unique_nodes:
            op = node['op']
            input_indices = op_input_indices(op)

            prev_nodes = []
            q_tensors = dict()
            prev_output_indices = []
            skip_names = []
            for i in input_indices:
                prev_node_name = op.inputs[i].name
                prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
                prev_nodes.append(prev_node)
                prev_output_indices.append(prev_node['outputs'].index(prev_node_name))

                if prev_node['node_type'] == ExtendedOperator.DEQUANTIZE:
                    q_tensors[prev_node_name] = prev_node['op'].inputs[0]

                if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
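                    # A float constant feeding MINIMUM/MAXIMUM is quantized eagerly with the
                    # other operand's parameters, q = round(f / scale + zero_point) clamped to
                    # the integer range, and cached in q_mapping for reuse.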
                    if (
                        node['node_type'] in (ExtendedOperator.MINIMUM, ExtendedOperator.MAXIMUM)
                        and i != 0
                        and prev_node_name not in self.graph.q_mapping
                    ):
                        f_tensor = self.graph.tensor_map[prev_node_name]
                        r_tensor = q_tensors[op.inputs[0].name]
                        q_arr = np.rint(
                            f_tensor.tensor / r_tensor.quantization.scale + r_tensor.quantization.zero_point
                        )
                        i_type = np.iinfo(r_tensor.tensor.dtype)

                        if np.any(q_arr > i_type.max):
                            warnings.warn('Overflow while quantizing the tensor')
                            q_arr = np.minimum(q_arr, i_type.max)

                        if np.any(q_arr < i_type.min):
                            warnings.warn('Underflow while quantizing the tensor')
                            q_arr = np.maximum(q_arr, i_type.min)

                        q_arr = q_arr.astype(r_tensor.dtype)
                        q_tensor = self.create_attr_tensor(q_arr, quantization=r_tensor.quantization)
                        self.graph.q_mapping[prev_node_name] = q_tensor

                    if prev_node_name in self.graph.q_mapping:
                        skip_names.append(prev_node_name)

            next_nodes = []
            next_edges = []
            out_nodes = []
            for edge in node.out_edges():
                if edge.index in remove_edges:
                    continue
                next_node = self.graph.graph.vs[edge.target]

                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    out_nodes.append(next_node)
                else:
                    next_nodes.append(next_node)
                    next_edges.append(edge)

                if next_node['node_type'] == ExtendedOperator.QUANTIZE:
                    skip = False
                    name = next_node['op'].inputs[0].name
                    q_tensor = next_node['op'].outputs[0]
                    assert q_tensor.quantization is not None
                    if node['node_type'] in (
                        ExtendedOperator.BATCH_MATMUL,
                        ExtendedOperator.ABS,
                        ExtendedOperator.RSQRT,
                    ):
                        if q_tensor.dtype not in (np.dtype('int8'), np.dtype('int16')):
                            skip = True
                    elif node['node_type'] == ExtendedOperator.DIV:
                        if q_tensor.dtype != np.dtype('uint8'):
                            skip = True
                    elif node['node_type'] == ExtendedOperator.SOFTMAX:
                        if q_tensor.dtype == np.dtype('int8'):
                            if (
                                abs(q_tensor.quantization.scale - 1.0 / 256) > 0.001 * 1.0 / 256
                                or q_tensor.quantization.zero_point != -128
                            ):
                                skip = True
                        elif q_tensor.dtype == np.dtype('int16'):
                            if (
                                abs(q_tensor.quantization.scale - 1.0 / 32768) > 0.001 * 1.0 / 32768
                                or q_tensor.quantization.zero_point != 0
                            ):
                                skip = True
                        elif q_tensor.dtype == np.dtype('uint8'):
                            if (
                                abs(q_tensor.quantization.scale - 1.0 / 256) > 0.001 * 1.0 / 256
                                or q_tensor.quantization.zero_point != 0
                            ):
                                log.warning(
                                    'On some chips, only softmax with scale=1.0/256 and zero_point=0 is supported'
                                )
                        else:
                            skip = True
                    elif node['node_type'] == ExtendedOperator.LOG_SOFTMAX:
                        if q_tensor.dtype == np.dtype('int8'):
                            if q_tensor.quantization.scale != 16.0 / 256 or q_tensor.quantization.zero_point != 127:
                                skip = True
                        elif q_tensor.dtype == np.dtype('uint8'):
                            if q_tensor.quantization.scale != 16.0 / 256 or q_tensor.quantization.zero_point != 255:
                                skip = True
                        else:
                            skip = True

                    if not skip:
                        q_tensors[name] = q_tensor

            cur_quantize_size = len(q_tensors)
            new_quantize_size = len(prev_nodes) + len(next_nodes) - len(skip_names)

            # Skip if the rewrite would increase the number of [de]quantize nodes
            if len(next_nodes) == 0 or new_quantize_size > cur_quantize_size:
                continue

            remove_edges.extend([x.index for x in next_edges])
            remove_vertices.extend([x.index for x in out_nodes])

            for n in out_nodes:
                del self.graph.tensor_map[n['outputs'][0]]
                del self.graph.tensor_node_map[n['outputs'][0]]

            tensor_node_dict = {}
            for prev_node, prev_idx, next_idx in zip(prev_nodes, input_indices, prev_output_indices):
                if prev_node['op'] is None:
                    prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
                else:
                    prev_out = prev_node['op'].outputs[next_idx]
                if prev_out.name in tensor_node_dict:
                    prev_new_out, skip = tensor_node_dict[prev_out.name]
                    actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True, skip)))
                    skip += 1
                    tensor_node_dict[prev_out.name] = (prev_new_out, skip)
                else:
                    if prev_out.name in skip_names:
                        prev_new_out = self.graph.q_mapping[prev_out.name]
                        self.graph.add_nodes([prev_new_out])
                        tensor_node_dict[prev_out.name] = (prev_new_out, 1)
                        actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))
                    else:
                        prev_new_out = self.create_transform_tensor(
                            q_tensors[prev_out.name].tensor, quantization=q_tensors[prev_out.name].quantization
                        )
                        tensor_node_dict[prev_out.name] = (prev_new_out, 1)
                        self.graph.add_operator(tfl.QuantizeOperator([prev_out], [prev_new_out]))
                        actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))

            tensor_node_dict = {}
            for i, op_out in enumerate(op.outputs):
                new_out = self.create_transform_tensor(
                    q_tensors[op_out.name].tensor, quantization=q_tensors[op_out.name].quantization
                )

                # Update relations
                if op_out.name in self.graph.tensor_node_map:
                    del self.graph.tensor_node_map[op_out.name]
                self.graph.tensor_node_map[new_out.name] = node['name']
                self.graph.tensor_map[new_out.name] = new_out
                node['outputs'][i] = new_out.name
                op.outputs[i] = new_out

                self.graph.add_operator(tfl.DequantizeOperator([new_out], [op_out]))

                tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])

            for edge in next_edges:
                source = tensor_node_dict[edge['name']]
                self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])

        # Process actions
        ids = []
        for func, args in actions:
            node = args[0]
            res = func(*args)
            if res is not None:
                ids.extend(res)

        remove_edges = list(set(remove_edges + ids))

        self.graph.graph.delete_edges(remove_edges)
        self.graph.graph.delete_vertices(remove_vertices)

    @class_conditional(lambda self: self.level >= GraphOptimizer.BRANCH_OPTIMIZE, 0)
    def elementwise_op_transpose_passthrough_pass(self, quantizable_ops_only: bool = False) -> int:
        edges = self.graph.graph.es.select(
            functools.partial(
                is_transpose_elementwise_op_edge,
                graph_converter=self.graph.graph,
                quantizable_ops_only=quantizable_ops_only,
            )
        )

        pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
        if quantizable_ops_only:
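            # Map each matched edge to the elementwise node on the far side of the adjacent
            # TRANSPOSE by looking it up in the full (unfiltered) edge set; when no such node
            # exists, the transpose cannot be passed through around the requantizable op.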
            all_edges = self.graph.graph.es.select(
                functools.partial(
                    is_transpose_elementwise_op_edge,
                    graph_converter=self.graph.graph,
                    quantizable_ops_only=False,
                )
            )

            all_pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in all_edges)

            forward_d = dict(all_pairs)
            backward_d = {v: k for k, v in forward_d.items()}

            filtered_nodes = []
            for s, e in pairs:
                if s['node_type'] == ExtendedOperator.TRANSPOSE:
                    pn = backward_d.get(s, None)
                    if pn is not None:
                        filtered_nodes.append(pn)
                    else:
                        log.warning('Cannot passthrough transpose upward around requantizable ops')
                else:
                    pn = forward_d.get(e, None)
                    if pn is not None:
                        filtered_nodes.append(pn)
                    else:
                        log.warning('Cannot passthrough transpose downward around requantizable ops')
        else:
            filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.TRANSPOSE else k[1] for k in pairs)
        unique_nodes = list(set(filtered_nodes))

        actions = []
        remove_edges = []
        remove_vertices = []
        num_actions = 0
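        # Hoist TRANSPOSE nodes across elementwise ops: put an 'up' transpose with the dominant
        # permutation in front of every input and a 'down' transpose after every output, then fix
        # up the axis/shape operands of layout-sensitive ops (PACK/UNPACK, SPLIT, PAD, SLICE,
        # reductions, ...) further below.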
        for node in unique_nodes:
            op = node['op']
            input_indices = op_input_indices(op)

            prev_nodes = []
            cand_perms = dict()
            prev_output_indices = []
            num_constant_nodes = 0
            num_reshape_transpose = 0
            prev_hints = set()
            for i in input_indices:
                prev_node_name = op.inputs[i].name
                prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
                prev_nodes.append(prev_node)
                prev_output_indices.append(prev_node['outputs'].index(prev_node_name))

                if prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
                    perm = tuple(prev_node['op'].inputs[1].tensor.tolist())

                    if node['node_type'] == ExtendedOperator.PACK:
                        perm = [i if i < op.axis else i + 1 for i in perm]
                        perm.insert(op.axis, op.axis)
                        perm = tuple(perm)

                    cand_perms.setdefault(perm, 0)
                    cand_perms[perm] += 1
                    if 'direction' in prev_node['op'].extra_hints:
                        prev_hints.add(prev_node['op'].extra_hints['direction'])

                if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
                    num_constant_nodes += 1

                if prev_node['node_type'] == ExtendedOperator.RESHAPE:
                    prev_prev_node_name = self.graph.tensor_node_map[prev_node['op'].inputs[0].name]
                    prev_prev_node = self.graph.graph.vs.find(name=prev_prev_node_name)
                    if prev_prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
                        num_reshape_transpose += 1
                        if 'direction' in prev_prev_node['op'].extra_hints:
                            prev_hints.add(prev_prev_node['op'].extra_hints['direction'])

            if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints:
                continue

            next_nodes = []
            next_edges = []
            out_nodes = []
            skip_names = []
            next_hints = set()
            for edge in node.out_edges():
                if edge.index in remove_edges:
                    continue
                next_node = self.graph.graph.vs[edge.target]

                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    out_nodes.append(next_node)
                elif next_node['node_type'] == ExtendedOperator.UNUSED_NODE:
                    skip_names.append(edge['label'])
                else:
                    next_nodes.append(next_node)
                    next_edges.append(edge)

                if next_node['node_type'] == ExtendedOperator.TRANSPOSE:
                    perm = tuple(np.argsort(next_node['op'].inputs[1].tensor).tolist())

                    if node['node_type'] == ExtendedOperator.UNPACK:
                        perm = [i if i < op.axis else i + 1 for i in perm]
                        perm.insert(op.axis, op.axis)
                        perm = tuple(perm)

                    cand_perms.setdefault(perm, 0)
                    cand_perms[perm] += 1
                    if 'direction' in next_node['op'].extra_hints:
                        next_hints.add(next_node['op'].extra_hints['direction'])

                if next_node['node_type'] == ExtendedOperator.RESHAPE:
                    o_nodes = [e.target_vertex for e in next_node.out_edges()]
                    if len(o_nodes) == 1 and o_nodes[0]['node_type'] == ExtendedOperator.TRANSPOSE:
                        num_reshape_transpose += 1
                        if 'direction' in o_nodes[0]['op'].extra_hints:
                            next_hints.add(o_nodes[0]['op'].extra_hints['direction'])

            if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints:
                continue

            cur_transpose_size = sum(cand_perms.values()) + num_reshape_transpose
            new_transpose_size = (
                len(prev_nodes) + len(next_nodes) - num_constant_nodes - cur_transpose_size + num_reshape_transpose
            )

            # Decide whether the rewrite is worthwhile:
            # a. it must not increase the number of transpose nodes, unless
            #    `bypass_elementwise_passthrough_constraint` is set, in which case a direction
            #    hint may still justify the increase;
            # b. when the count stays the same, proceed only if a direction hint is present and
            #    the optimize level is at least BRANCH_OPTIMIZE_EXTENDED.
            is_increasing = new_transpose_size > cur_transpose_size
            is_not_decreasing = new_transpose_size >= cur_transpose_size
            is_same = new_transpose_size == cur_transpose_size
            if len(next_nodes) == 0:
                continue
            else:
                if self.bypass_elementwise_passthrough_constraint:
                    condition = is_not_decreasing
                else:
                    if is_increasing:
                        continue
                    condition = is_same

                if condition:
                    skip = True
                    if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED:
                        if 'down' in prev_hints or 'up' in next_hints:
                            skip = False
                    if skip:
                        continue

            num_actions += 1

            remove_edges.extend([x.index for x in next_edges])
            remove_vertices.extend([x.index for x in out_nodes])

            for n in out_nodes:
                del self.graph.tensor_map[n['outputs'][0]]
                del self.graph.tensor_node_map[n['outputs'][0]]

            perm = max(cand_perms.items(), key=lambda x: x[1])[0]
            perm_arr = np.array(perm, dtype='int32')
            inv_perm_arr = np.argsort(perm_arr).astype('int32')

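            # PACK/UNPACK change the tensor rank, so the permutation used on one side of the op
            # has to drop or re-insert the pack axis before it can be applied on the other side.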
            if node['node_type'] == ExtendedOperator.UNPACK:
                inv_perm_arr_post = inv_perm_arr[inv_perm_arr != op.axis]
                inv_perm_arr_post[inv_perm_arr_post > op.axis] -= 1

                perm_arr_post = np.argsort(inv_perm_arr_post).astype('int32')
            elif node['node_type'] == ExtendedOperator.PACK:
                perm_arr_post = perm_arr
                inv_perm_arr_post = inv_perm_arr

                perm_arr = perm_arr_post[perm_arr_post != op.axis]
                perm_arr[perm_arr > op.axis] -= 1

                inv_perm_arr = np.argsort(perm_arr).astype('int32')
            else:
                perm_arr_post = perm_arr
                inv_perm_arr_post = inv_perm_arr

            tensor_node_dict = {}
            for prev_node, prev_idx, next_idx in zip(prev_nodes, input_indices, prev_output_indices):
                if prev_node['op'] is None:
                    prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
                else:
                    prev_out = prev_node['op'].outputs[next_idx]
                if prev_out.name in tensor_node_dict:
                    prev_new_out, skip = tensor_node_dict[prev_out.name]
                    actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True, skip)))
                    skip += 1
                    tensor_node_dict[prev_out.name] = (prev_new_out, skip)
                else:
                    perm_tensor = self.create_attr_tensor(inv_perm_arr)
                    if len(prev_out.shape) != perm_tensor.tensor.size:
                        new_shape = [1] * (perm_tensor.tensor.size - len(prev_out.shape)) + list(prev_out.shape)
                        prev_out_reshaped = self.create_transform_tensor(
                            np.reshape(prev_out.tensor, new_shape), quantization=prev_out.quantization
                        )
                        new_shape_tensor = self.create_attr_tensor(np.array(new_shape, dtype='int32'))
                        self.graph.add_operator(
                            tfl.ReshapeOperator([prev_out, new_shape_tensor], [prev_out_reshaped], new_shape)
                        )
                        prev_out = prev_out_reshaped
                    prev_new_out = self.create_transform_tensor(
                        np.transpose(prev_out.tensor, inv_perm_arr), quantization=prev_out.quantization
                    )
                    tensor_node_dict[prev_out.name] = (prev_new_out, 1)
                    transpose_op = tfl.TransposeOperator([prev_out, perm_tensor], [prev_new_out])
                    transpose_op.extra_hints['direction'] = 'up'
                    self.graph.add_operator(transpose_op)
                    actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))

            tensor_node_dict = {}
            for i, op_out in enumerate(op.outputs):
                # For unused tensors, we perform inplace shape updates
                if op_out.name in skip_names:
                    orig_shape = np.array(op_out.shape, dtype='int32')
                    new_shape = orig_shape[inv_perm_arr]
                    op_out.shape = tuple(new_shape.tolist())
                    continue

                perm_tensor = self.create_attr_tensor(perm_arr_post)
                new_out = self.create_transform_tensor(
                    np.transpose(op_out.tensor, inv_perm_arr_post), quantization=op_out.quantization
                )

                # Update relations
                if op_out.name in self.graph.tensor_node_map:
                    del self.graph.tensor_node_map[op_out.name]
                self.graph.tensor_node_map[new_out.name] = node['name']
                self.graph.tensor_map[new_out.name] = new_out
                node['outputs'][i] = new_out.name
                op.outputs[i] = new_out

                transpose_op = tfl.TransposeOperator([new_out, perm_tensor], [op_out])
                transpose_op.extra_hints['direction'] = 'down'
                self.graph.add_operator(transpose_op)

                tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])

            # OP specific dim handling logic
            if node['node_type'] in (ExtendedOperator.CONCATENATION, ExtendedOperator.GATHER, ExtendedOperator.UNPACK):
                old_axis = op.axis
                new_axis = np.where(inv_perm_arr == old_axis)[0][0]
                op.axis = new_axis
            elif node['node_type'] == ExtendedOperator.PACK:
                old_axis = op.axis
                new_axis = np.where(inv_perm_arr_post == old_axis)[0][0]
                op.axis = new_axis
            elif node['node_type'] == ExtendedOperator.SPLIT_V:
                old_dim = op.inputs[2].tensor
                new_dim = np.where(inv_perm_arr == old_dim)[0][0]
                new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
                actions.append((self.graph.replace_operator_input, (node, 2, new_dim_tensor, True)))
            elif node['node_type'] == ExtendedOperator.SPLIT:
                old_dim = op.inputs[0].tensor
                new_dim = np.where(inv_perm_arr == old_dim)[0][0]
                new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
                actions.append((self.graph.replace_operator_input, (node, 0, new_dim_tensor, True)))
            elif node['node_type'] in (
                ExtendedOperator.PAD,
                ExtendedOperator.PADV2,
                ExtendedOperator.MIRROR_PAD,
                ExtendedOperator.TILE,
            ):
                old_pad = op.inputs[1].tensor
                new_pad = self.create_attr_tensor(old_pad[inv_perm_arr])
                actions.append((self.graph.replace_operator_input, (node, 1, new_pad, True)))
            elif node['node_type'] == ExtendedOperator.PRELU:
                old_weight = op.inputs[1].tensor
                if old_weight.ndim != 1:
                    assert old_weight.ndim + 1 == len(inv_perm_arr)
                    new_perm = np.argsort(np.argsort(inv_perm_arr[1:]))
                    new_perm_t = self.create_attr_tensor(np.array(new_perm, dtype='int32'))
                    new_weight = self.create_transform_tensor(np.transpose(old_weight, new_perm))
                    self.graph.add_operator(tfl.TransposeOperator([op.inputs[1], new_perm_t], [new_weight]))
                    actions.append((self.graph.replace_operator_input, (node, 1, new_weight, True)))
            elif node['node_type'] in (ExtendedOperator.SLICE, ExtendedOperator.STRIDED_SLICE):
                for i, t in enumerate(op.inputs[1:]):
                    if t.buffer is None:
                        new_perm_t = self.create_attr_tensor(np.array(inv_perm_arr, dtype='int32'))
                        new_t = self.create_transform_tensor(t.tensor[inv_perm_arr])
                        self.graph.add_operator(tfl.TransposeOperator([t, new_perm_t], [new_t]))
                    else:
                        new_t = self.create_attr_tensor(t.tensor[inv_perm_arr])
                    actions.append((self.graph.replace_operator_input, (node, i + 1, new_t, True)))
            elif node['node_type'] in (
                ExtendedOperator.SUM,
                ExtendedOperator.ARG_MIN,
                ExtendedOperator.ARG_MAX,
                ExtendedOperator.REDUCE_MIN,
                ExtendedOperator.REDUCE_MAX,
                ExtendedOperator.REDUCE_PROD,
                ExtendedOperator.MEAN,
            ):
                old_axis = op.inputs[1].tensor.tolist()
                new_axis = []
                for t in old_axis:
                    new_t = np.where(inv_perm_arr == t)[0][0]
                    new_axis.append(new_t)
                axis_arr = np.array(new_axis, dtype='int32')
                axis_tensor = self.create_attr_tensor(axis_arr)
                actions.append((self.graph.replace_operator_input, (node, 1, axis_tensor, True)))

            for edge in next_edges:
                source = tensor_node_dict[edge['name']]
                self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])

        # Process actions
        ids = []
        for func, args in actions:
            node = args[0]
            res = func(*args)
            if res is not None:
                ids.extend(res)

        remove_edges = list(set(remove_edges + ids))

        self.graph.graph.delete_edges(remove_edges)
        self.graph.graph.delete_vertices(remove_vertices)

        return num_actions

    @class_conditional(lambda self: self.level >= GraphOptimizer.BRANCH_OPTIMIZE, 0)
    def elementwise_op_reshape_passthrough_pass(self) -> int:
        edges = self.graph.graph.es.select(
            functools.partial(is_reshape_elementwise_op_edge, graph_converter=self.graph.graph)
        )
        pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
        filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.RESHAPE else k[1] for k in pairs)
        unique_nodes = list(set(filtered_nodes))

        actions = []
        remove_edges = []
        remove_vertices = []
        num_actions = 0
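        # Hoist RESHAPE nodes across elementwise ops: candidate pre-/post-op shapes are collected
        # from the surrounding reshapes (with -1 marking the axis the op operates on), inputs are
        # reshaped 'up' to the dominant pre-op shape and outputs are reshaped back 'down',
        # followed by op-specific fixes of axis/pad/slice operands.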
        for node in unique_nodes:
            op = node['op']
            dim_indice = op_input_dims(op)
            input_indices = op_input_indices(op)

            prev_nodes = []
            cand_shapes = dict()
            cand_next_shapes = dict()
            prev_output_indices = []
            num_constant_nodes = 0
            prev_hints = set()
            for i in input_indices:
                prev_node_name = op.inputs[i].name
                prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
                prev_nodes.append(prev_node)
                prev_output_indices.append(prev_node['outputs'].index(prev_node_name))

                if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
                    num_constant_nodes += 1

                if prev_node['node_type'] == ExtendedOperator.RESHAPE:
                    mapping = dict()
                    if not is_simple_reshape(
                        prev_node['op'].inputs[0].shape, prev_node['op'].outputs[0].shape, mapping
                    ):
                        continue

                    new_dim = None
                    if dim_indice is not None:
                        rev_mapping = {v: k for k, v in mapping.items()}
                        if node['node_type'] == ExtendedOperator.PACK:
                            if dim_indice in rev_mapping:
                                tmp_new_dim = rev_mapping[dim_indice]
                            else:
                                if dim_indice - 1 in rev_mapping:
                                    tmp_new_dim = rev_mapping[dim_indice - 1] + 1
                                elif dim_indice + 1 in rev_mapping:
                                    tmp_new_dim = rev_mapping[dim_indice + 1] - 1
                                else:
                                    # TODO: Figure out the rev index
                                    tmp_new_dim = -1
                            tmp_dim_indice = dim_indice
                            new_dim = -1
                            dim_indice = -1
                        else:
                            if dim_indice not in rev_mapping:
                                continue
                            new_dim = rev_mapping[dim_indice]

                    shape = tuple(prev_node['op'].inputs[0].shape)
                    shape = tuple(x if i != new_dim else -1 for i, x in enumerate(shape))
                    if node['node_type'] == ExtendedOperator.PACK and tmp_new_dim >= 0:
                        shape = list(shape)
                        shape.insert(tmp_new_dim, -1)
                        shape = tuple(shape)
                    cand_shapes.setdefault(shape, 0)
                    cand_shapes[shape] += 1

                    next_shape = tuple(prev_node['op'].outputs[0].shape)
                    next_shape = tuple(x if i != dim_indice else -1 for i, x in enumerate(next_shape))
                    if node['node_type'] == ExtendedOperator.PACK:
                        next_shape = list(next_shape)
                        next_shape.insert(tmp_dim_indice, -1)
                        next_shape = tuple(next_shape)
                    cand_next_shapes.setdefault(next_shape, 0)
                    cand_next_shapes[next_shape] += 1

                    if node['node_type'] == ExtendedOperator.PACK:
                        dim_indice = tmp_dim_indice

                    if 'direction' in prev_node['op'].extra_hints:
                        prev_hints.add(prev_node['op'].extra_hints['direction'])

            if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints:
                continue

            next_nodes = []
            next_edges = []
            out_nodes = []
            skip_names = []
            next_hints = set()
            for edge in node.out_edges():
                if edge.index in remove_edges:
                    continue
                next_node = self.graph.graph.vs[edge.target]

                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    out_nodes.append(next_node)
                elif next_node['node_type'] == ExtendedOperator.UNUSED_NODE:
                    skip_names.append(edge['label'])
                else:
                    next_nodes.append(next_node)
                    next_edges.append(edge)

                if next_node['node_type'] == ExtendedOperator.RESHAPE:
                    mapping = dict()
                    if not is_simple_reshape(
                        next_node['op'].inputs[0].shape, next_node['op'].outputs[0].shape, mapping
                    ):
                        continue

                    new_dim = None
                    if dim_indice is not None:
                        if node['node_type'] == ExtendedOperator.UNPACK:
                            if dim_indice in mapping:
                                tmp_new_dim = mapping[dim_indice]
                            else:
                                if dim_indice - 1 in mapping:
                                    tmp_new_dim = mapping[dim_indice - 1] + 1
                                elif dim_indice + 1 in mapping:
                                    tmp_new_dim = mapping[dim_indice + 1] - 1
                                else:
                                    # TODO: Figure out the rev index
                                    tmp_new_dim = -1
                            tmp_dim_indice = dim_indice
                            new_dim = -1
                            dim_indice = -1
                        else:
                            if dim_indice not in mapping:
                                continue
                            new_dim = mapping[dim_indice]

                    shape = tuple(next_node['op'].outputs[0].shape)
                    shape = tuple(x if i != new_dim else -1 for i, x in enumerate(shape))
                    if node['node_type'] == ExtendedOperator.UNPACK and tmp_new_dim >= 0:
                        shape = list(shape)
                        shape.insert(tmp_new_dim, -1)
                        shape = tuple(shape)
                    cand_shapes.setdefault(shape, 0)
                    cand_shapes[shape] += 1

                    next_shape = tuple(next_node['op'].inputs[0].shape)
                    next_shape = tuple(x if i != dim_indice else -1 for i, x in enumerate(next_shape))
                    if node['node_type'] == ExtendedOperator.UNPACK:
                        next_shape = list(next_shape)
                        next_shape.insert(tmp_dim_indice, -1)
                        next_shape = tuple(next_shape)
                    cand_next_shapes.setdefault(next_shape, 0)
                    cand_next_shapes[next_shape] += 1

                    if node['node_type'] == ExtendedOperator.UNPACK:
                        dim_indice = tmp_dim_indice

                    if 'direction' in next_node['op'].extra_hints:
                        next_hints.add(next_node['op'].extra_hints['direction'])

            if len(cand_shapes) == 0:
                continue

            if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints:
                continue

            cur_reshape_size = max(cand_shapes.values())
            cur_next_reshape_size = max(cand_next_shapes.values())
            full_size = len(prev_nodes) + len(next_nodes)

            if cur_reshape_size != cur_next_reshape_size:
                continue

            new_reshape_size = full_size - cur_reshape_size - num_constant_nodes

            # Skip unless the rewrite reduces the number of reshapes; ties are only taken when a
            # direction hint justifies them (BRANCH_OPTIMIZE_EXTENDED and above)
            if (
                len(next_nodes) == 0 or new_reshape_size > cur_reshape_size
            ):  # cur_reshape_size < full_size or cur_next_reshape_size < full_size:
                continue
            elif new_reshape_size == cur_reshape_size:
                skip = True
                if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED:
                    if 'down' in prev_hints or 'up' in next_hints:
                        skip = False
                if skip:
                    continue

            num_actions += 1

            remove_edges.extend([x.index for x in next_edges])
            remove_vertices.extend([x.index for x in out_nodes])

            for n in out_nodes:
                del self.graph.tensor_map[n['outputs'][0]]
                del self.graph.tensor_node_map[n['outputs'][0]]

            prev_shape = max(cand_shapes.items(), key=lambda x: x[1])[0]
            next_shape = max(cand_next_shapes.items(), key=lambda x: x[1])[0]

            tensor_node_dict = {}
            for prev_node, prev_idx, next_idx in zip(prev_nodes, input_indices, prev_output_indices):
                if prev_node['op'] is None:
                    prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
                else:
                    prev_out = prev_node['op'].outputs[next_idx]
                if prev_out.name in tensor_node_dict:
                    prev_new_out, skip = tensor_node_dict[prev_out.name]
                    actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True, skip)))
                    skip += 1
                    tensor_node_dict[prev_out.name] = (prev_new_out, skip)
                else:
                    if node['node_type'] == ExtendedOperator.PACK:
                        tmp_prev_shape = prev_shape
                        prev_shape = [i for i in prev_shape if i != -1]
                    prev_shape_aligned = prev_shape
                    if np.prod(prev_out.shape) != np.prod(prev_shape):
                        new_prev_shape = prev_out.shape
                        if len(prev_out.shape) < len(next_shape):
                            new_prev_shape = [1] * (len(next_shape) - len(prev_out.shape)) + list(prev_out.shape)
                        mapping = {}
                        is_simple_reshape(prev_shape, next_shape, mapping)
                        prev_shape_aligned = np.ones(len(prev_shape), dtype='int32')
                        for pi, ni in mapping.items():
                            prev_shape_aligned[pi] = new_prev_shape[ni]

                    prev_new_out = self.create_transform_tensor(
                        np.reshape(prev_out.tensor, prev_shape_aligned), quantization=prev_out.quantization
                    )
                    tensor_node_dict[prev_out.name] = (prev_new_out, 1)
                    shape_tensor = self.create_attr_tensor(np.array(prev_new_out.shape, dtype='int32'))
                    reshape_op = tfl.ReshapeOperator(
                        [prev_out, shape_tensor], [prev_new_out], newShape=shape_tensor.tensor
                    )
                    reshape_op.extra_hints['direction'] = 'up'
                    self.graph.add_operator(reshape_op)
                    actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))

                    if node['node_type'] == ExtendedOperator.PACK:
                        prev_shape = tmp_prev_shape

            tensor_node_dict = {}
            for i, op_out in enumerate(op.outputs):
                if node['node_type'] == ExtendedOperator.UNPACK:
                    tmp_prev_shape = prev_shape
                    prev_shape = [i for i in prev_shape if i != -1]

                # For unused tensors, we perform inplace shape updates
                if op_out.name in skip_names:
                    new_shape = np.reshape(op_out.tensor, prev_shape).shape
                    op_out.shape = tuple(new_shape)

                    if node['node_type'] == ExtendedOperator.UNPACK:
                        prev_shape = tmp_prev_shape

                    continue

                new_out = self.create_transform_tensor(
                    np.reshape(op_out.tensor, prev_shape), quantization=op_out.quantization
                )
                shape_tensor = self.create_attr_tensor(np.array(op_out.shape, dtype='int32'))

                # Update relations
                if op_out.name in self.graph.tensor_node_map:
                    del self.graph.tensor_node_map[op_out.name]
                self.graph.tensor_node_map[new_out.name] = node['name']
                self.graph.tensor_map[new_out.name] = new_out
                node['outputs'][i] = new_out.name
                op.outputs[i] = new_out

                reshape_op = tfl.ReshapeOperator([new_out, shape_tensor], [op_out], shape_tensor.tensor)
                reshape_op.extra_hints['direction'] = 'down'
                self.graph.add_operator(reshape_op)

                tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])

                if node['node_type'] == ExtendedOperator.UNPACK:
                    prev_shape = tmp_prev_shape

            # OP specific dim handling logic
            if node['node_type'] in (
                ExtendedOperator.CONCATENATION,
                ExtendedOperator.GATHER,
                ExtendedOperator.UNPACK,
                ExtendedOperator.PACK,
            ):
                new_axis = prev_shape.index(-1)
                op.axis = new_axis
            elif node['node_type'] == ExtendedOperator.SPLIT_V:
                new_dim = prev_shape.index(-1)
                new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
                actions.append((self.graph.replace_operator_input, (node, 2, new_dim_tensor, True)))
            elif node['node_type'] == ExtendedOperator.SPLIT:
                new_dim = prev_shape.index(-1)
                new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
                actions.append((self.graph.replace_operator_input, (node, 0, new_dim_tensor, True)))
            elif node['node_type'] in (ExtendedOperator.PAD, ExtendedOperator.PADV2, ExtendedOperator.MIRROR_PAD):
                old_pad = op.inputs[1].tensor
                new_dim = prev_shape.index(-1)
                old_dim = next_shape.index(-1)
                new_pad = np.zeros((len(prev_shape), 2), dtype='int32')
                new_pad[new_dim, :] = old_pad[old_dim, :]
                new_pad_tensor = self.create_attr_tensor(new_pad)
                actions.append((self.graph.replace_operator_input, (node, 1, new_pad_tensor, True)))
            elif node['node_type'] == ExtendedOperator.PRELU:
                old_weight = op.inputs[1].tensor
                if old_weight.ndim != 1:
                    new_dim = prev_shape.index(-1)
                    old_dim = next_shape.index(-1)
                    new_shape = np.ones(len(prev_shape) - 1, dtype='int32')
                    new_shape[new_dim - 1] = old_weight.shape[old_dim - 1]
                    new_shape_t = self.create_attr_tensor(new_shape)
                    new_weight = self.create_transform_tensor(np.reshape(old_weight, new_shape))
                    self.graph.add_operator(tfl.ReshapeOperator([op.inputs[1], new_shape_t], [new_weight], new_shape))
                    actions.append((self.graph.replace_operator_input, (node, 1, new_weight, True)))
            elif node['node_type'] == ExtendedOperator.SLICE:
                new_dim = prev_shape.index(-1)
                old_dim = next_shape.index(-1)

                new_start = np.zeros(len(prev_shape), dtype='int32')
                new_start[new_dim] = op.inputs[1].tensor[old_dim]
                new_start_t = self.create_attr_tensor(new_start)

                new_size = np.array(prev_shape, dtype='int32')
                new_size[new_dim] = op.inputs[2].tensor[old_dim]
                new_size_t = self.create_attr_tensor(new_size)

                actions.append((self.graph.replace_operator_input, (node, 1, new_start_t, True)))
                actions.append((self.graph.replace_operator_input, (node, 2, new_size_t, True)))
            elif node['node_type'] == ExtendedOperator.STRIDED_SLICE:
                new_dim = prev_shape.index(-1)
                old_dim = next_shape.index(-1)

                new_start = np.zeros(len(prev_shape), dtype='int32')
                new_start[new_dim] = op.inputs[1].tensor[old_dim]
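                # When the `begin` tensor is dynamic (no backing buffer), it cannot be rewritten
                # as a constant. Instead, slice out its element at `old_dim` at runtime and
                # rebuild the full `begin` vector by concatenating it with constant paddings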
                if op.inputs[1].buffer is None:
                    new_start_t = self.create_transform_tensor(new_start)
                    starts_mid = new_start[new_dim : new_dim + 1]
                    starts_mid_tensor = self.create_transform_tensor(starts_mid)

                    slice_inputs = [
                        op.inputs[1],
                        self.create_attr_tensor(np.array([old_dim], dtype='int32')),
                        self.create_attr_tensor(np.array([1], dtype='int32')),
                    ]

                    self.graph.add_operator(tfl.SliceOperator(slice_inputs, [starts_mid_tensor]))

                    starts_left = new_start[:new_dim]
                    starts_right = new_start[new_dim + 1 :]
                    starts_tensors = []
                    if len(starts_left) > 0:
                        starts_tensors.append(self.create_attr_tensor(starts_left))
                    starts_tensors.append(starts_mid_tensor)
                    if len(starts_right) > 0:
                        starts_tensors.append(self.create_attr_tensor(starts_right))
                    if len(starts_tensors) > 1:
                        self.graph.add_operator(tfl.ConcatenationOperator(starts_tensors, [new_start_t], 0))
                    else:
                        new_start_t = starts_tensors[0]
                else:
                    new_start_t = self.create_attr_tensor(new_start)

                new_end = np.array(prev_shape, dtype='int32')
                new_end[new_dim] = op.inputs[2].tensor[old_dim]
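                # Same handling as above for a dynamic `end` tensor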
                if op.inputs[2].buffer is None:
                    new_end_t = self.create_transform_tensor(new_end)
                    ends_mid = new_end[new_dim : new_dim + 1]
                    ends_mid_tensor = self.create_transform_tensor(ends_mid)

                    slice_inputs = [
                        op.inputs[2],
                        self.create_attr_tensor(np.array([old_dim], dtype='int32')),
                        self.create_attr_tensor(np.array([1], dtype='int32')),
                    ]

                    self.graph.add_operator(tfl.SliceOperator(slice_inputs, [ends_mid_tensor]))

                    ends_left = new_end[:new_dim]
                    ends_right = new_end[new_dim + 1 :]
                    ends_tensors = []
                    if len(ends_left) > 0:
                        ends_tensors.append(self.create_attr_tensor(ends_left))
                    ends_tensors.append(ends_mid_tensor)
                    if len(ends_right) > 0:
                        ends_tensors.append(self.create_attr_tensor(ends_right))
                    if len(ends_tensors) > 1:
                        self.graph.add_operator(tfl.ConcatenationOperator(ends_tensors, [new_end_t], 0))
                    else:
                        new_end_t = ends_tensors[0]
                else:
                    new_end_t = self.create_attr_tensor(new_end)

                new_stride = np.ones(len(prev_shape), dtype='int32')
                new_stride[new_dim] = op.inputs[3].tensor[old_dim]
                new_stride_t = self.create_attr_tensor(new_stride)

                actions.append((self.graph.replace_operator_input, (node, 1, new_start_t, True)))
                actions.append((self.graph.replace_operator_input, (node, 2, new_end_t, True)))
                actions.append((self.graph.replace_operator_input, (node, 3, new_stride_t, True)))
            elif node['node_type'] == ExtendedOperator.TILE:
                old_shape = op.inputs[1].tensor
                new_dim = prev_shape.index(-1)
                old_dim = next_shape.index(-1)
                new_shape = np.ones(len(prev_shape), dtype='int32')
                new_shape[new_dim] = old_shape[old_dim]
                new_shape_tensor = self.create_attr_tensor(new_shape)
                actions.append((self.graph.replace_operator_input, (node, 1, new_shape_tensor, True)))
            elif node['node_type'] in (
                ExtendedOperator.SUM,
                ExtendedOperator.ARG_MIN,
                ExtendedOperator.ARG_MAX,
                ExtendedOperator.REDUCE_MIN,
                ExtendedOperator.REDUCE_MAX,
                ExtendedOperator.REDUCE_PROD,
                ExtendedOperator.MEAN,
            ):
                new_axis = prev_shape.index(-1)
                axis_arr = np.array([new_axis], dtype='int32')
                axis_tensor = self.create_attr_tensor(axis_arr)
                actions.append((self.graph.replace_operator_input, (node, 1, axis_tensor, True)))
            elif dim_indice is not None:
                raise NotImplementedError(f'{node["node_type"]} has the property `dims` but is not handled')

            for edge in next_edges:
                source = tensor_node_dict[edge['name']]
                self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])

        # Process actions
        ids = []
        for func, args in actions:
            node = args[0]
            res = func(*args)
            if res is not None:
                ids.extend(res)

        remove_edges = list(set(remove_edges + ids))

        self.graph.graph.delete_edges(remove_edges)
        self.graph.graph.delete_vertices(remove_vertices)

        return num_actions

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_bmm_add_pass(self):
        edges = self.graph.graph.es.select(functools.partial(is_bmm_add_edge, graph_converter=self.graph.graph))
        filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]
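        # For FULLY_CONNECTED sources, only fuse when there is no effective bias yet
        # (the bias input is missing or all zeros), so that the addend of the following ADD
        # can become the new bias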
        filtered_pairs = [
            p
            for p in filtered_pairs
            if p[0]['node_type'] != ExtendedOperator.FULLY_CONNECTED
            or len(p[0]['op'].inputs) == 2
            or not np.any(p[0]['op'].inputs[2].tensor)
        ]

        remove_ids = []
        ops = []
        restore_mapping = []
        for bmm, add in filtered_pairs:
            restore_nodes = []
            # For each node that follows a transformable node:
            #  a. if it is an output node, remove it, since it will be reconstructed anyway
            #  b. otherwise, record the edge info so that it can be restored after reconstruction
            for out_edge in add.out_edges():
                next_node = self.graph.graph.vs[out_edge.target]
                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    remove_ids.append(next_node.index)
                    del self.graph.tensor_map[next_node['outputs'][0]]
                    del self.graph.tensor_node_map[next_node['outputs'][0]]
                else:
                    restore_nodes.append((out_edge['name'], next_node['name']))

            # Remove the mappings since these tensors are going to be removed
            for output_name in add['outputs']:
                del self.graph.tensor_map[output_name]
                del self.graph.tensor_node_map[output_name]

            restore_mapping.append(restore_nodes)
            ops.append((bmm, add))
            remove_ids.append(bmm.index)
            remove_ids.append(add.index)

        # Make sure the nodes are topologically sorted
        sorted_ops = [
            (nodes[0]['op'], nodes[1]['op'])
            for nodes in sorted(ops, key=lambda x: int(re.search(r'\d+', x[1]['name'])[0]))
        ]

        # Delete nodes before transformation in the graph
        self.graph.graph.delete_vertices(remove_ids)

        for (bmm, add), mapping in zip(sorted_ops, restore_mapping):
            input_tensor = bmm.inputs[0]
            weight_tensor = bmm.inputs[1]
            bias_tensor = add.inputs[1]
            output_tensor = add.outputs[0]

            ops = []

            if isinstance(bmm, tfl.BatchMatmulOperator):
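                # BATCH_MATMUL takes its right-hand weight as [..., input_dim, output_dim],
                # while FULLY_CONNECTED expects [output_dim, input_dim], so transpose it first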
                weight_t = self.create_transform_tensor(np.transpose(weight_tensor.tensor))
                weight_perm = self.create_attr_tensor(np.array([1, 0], dtype='int32'))
                ops.append(tfl.TransposeOperator([weight_tensor, weight_perm], [weight_t]))
            else:
                weight_t = weight_tensor

            keep_dims = output_tensor.tensor.ndim > 2

            ops.append(
                tfl.FullyConnectedOperator(
                    [input_tensor, weight_t, bias_tensor],
                    [output_tensor],
                    fusedActivationFunction=add.fusedActivationFunction,
                    keepNumDims=keep_dims,
                )
            )

            for op in ops:
                self.graph.add_operator(op, transform=True)

            self.graph.try_restore_edges(mapping)

    @class_conditional(lambda self: self.max_transpose_dims > 0)
    def lower_transpose_dim_pass(self):
        vertices = self.graph.graph.vs.select(
            functools.partial(
                is_high_dim_transpose_node, graph_converter=self.graph.graph, max_transpose_dims=self.max_transpose_dims
            )
        )

        remove_ids = []
        ops = []
        restore_mapping = []
        for trans in vertices:
            restore_nodes = []
            # For each node that follows a transformable node:
            #  a. if it is an output node, remove it, since it will be reconstructed anyway
            #  b. otherwise, record the edge info so that it can be restored after reconstruction
            for out_edge in trans.out_edges():
                next_node = self.graph.graph.vs[out_edge.target]
                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    remove_ids.append(next_node.index)
                    del self.graph.tensor_map[next_node['outputs'][0]]
                    del self.graph.tensor_node_map[next_node['outputs'][0]]
                else:
                    restore_nodes.append((out_edge['name'], next_node['name']))

            # Remove the mappings since these tensors are going to be removed
            for output_name in trans['outputs']:
                del self.graph.tensor_map[output_name]
                del self.graph.tensor_node_map[output_name]

            restore_mapping.append(restore_nodes)
            remove_ids.append(trans.index)

        # Make sure the nodes are topologically sorted
        sorted_ops = [node['op'] for node in sorted(vertices, key=lambda x: int(re.search(r'\d+', x['name'])[0]))]

        # Delete nodes before transformation in the graph
        self.graph.graph.delete_vertices(remove_ids)

        for trans, mapping in zip(sorted_ops, restore_mapping):
            input_tensor = trans.inputs[0]
            perm_tensor = trans.inputs[1]
            output_tensor = trans.outputs[0]

            input_shape = input_tensor.shape
            perm = perm_tensor.tensor
            output_shape = output_tensor.shape

            last_perm = None
            last_dim = None
            cum_dim = None
            new_shape = []
            new_perm = []
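            # Collapse runs of dimensions whose permutation indices stay consecutive (or that
            # have size 1) into a single dimension, so that the high-dim transpose can be
            # expressed as reshape -> lower-rank transpose -> reshape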
            for d, p in zip(input_shape, perm):
                if last_dim is None and last_perm is None:
                    cum_dim = d
                else:
                    if p - last_perm == 1 or d == 1 or cum_dim == 1:
                        cum_dim *= d
                    else:
                        new_shape.append(cum_dim)
                        new_perm.append(last_perm)
                        cum_dim = d

                last_dim = d
                last_perm = p

            new_shape.append(cum_dim)
            new_perm.append(last_perm)

            new_perm_arr = np.argsort(new_perm).astype('int32')

            assert (
                len(new_shape) <= self.max_transpose_dims
            ), f"Don't know how to reduce the number of dims of transpose with input shape {input_shape}, perm {perm}"

            ops = []

            input_reduced = self.create_transform_tensor(
                np.reshape(input_tensor.tensor, new_shape), quantization=input_tensor.quantization
            )
            reduced_shape = self.create_attr_tensor(np.array(new_shape, dtype='int32'))
            ops.append(tfl.ReshapeOperator([input_tensor, reduced_shape], [input_reduced], new_shape))

            transposed = self.create_transform_tensor(
                np.transpose(input_reduced.tensor, new_perm_arr), quantization=input_tensor.quantization
            )
            new_perm_tensor = self.create_attr_tensor(np.array(new_perm_arr, dtype='int32'))
            ops.append(tfl.TransposeOperator([input_reduced, new_perm_tensor], [transposed]))

            output_shape_tensor = self.create_attr_tensor(np.array(output_shape, dtype='int32'))
            ops.append(tfl.ReshapeOperator([transposed, output_shape_tensor], [output_tensor], output_shape))

            for op in ops:
                self.graph.add_operator(op, transform=True)

            self.graph.try_restore_edges(mapping)

    @class_conditional(lambda self: self.group_conv_rewrite)
    def group_conv_rewrite_pass(self):
        vertices = self.graph.graph.vs.select(functools.partial(is_group_conv_node, graph_converter=self.graph.graph))

        remove_ids = []
        ops = []
        restore_mapping = []
        for conv in vertices:
            restore_nodes = []
            # For each node that follows a transformable node:
            #  a. if it is an output node, remove it, since it will be reconstructed anyway
            #  b. otherwise, record the edge info so that it can be restored after reconstruction
            for out_edge in conv.out_edges():
                next_node = self.graph.graph.vs[out_edge.target]
                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    remove_ids.append(next_node.index)
                    del self.graph.tensor_map[next_node['outputs'][0]]
                    del self.graph.tensor_node_map[next_node['outputs'][0]]
                else:
                    restore_nodes.append((out_edge['name'], next_node['name']))

            # Remove the mappings since these tensors are going to be removed
            for output_name in conv['outputs']:
                del self.graph.tensor_map[output_name]
                del self.graph.tensor_node_map[output_name]

            restore_mapping.append(restore_nodes)
            remove_ids.append(conv.index)

        # Make sure the nodes are topologically sorted
        sorted_ops = [node['op'] for node in sorted(vertices, key=lambda x: int(re.search(r'\d+', x['name'])[0]))]

        # Delete nodes before transformation in the graph
        self.graph.graph.delete_vertices(remove_ids)

        for conv, mapping in zip(sorted_ops, restore_mapping):
            input_tensor = conv.inputs[0]
            weight_tensor = conv.inputs[1]
            bias_tensor = conv.inputs[2] if len(conv.inputs) > 2 else None
            output_tensor = conv.outputs[0]

            num_input_channel = input_tensor.shape[3]
            num_weight_channel = weight_tensor.shape[3]
            num_chunks = num_input_channel // num_weight_channel
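            # A grouped CONV_2D is rewritten as: SPLIT the input along the channel axis into
            # `num_chunks` groups, run an ordinary CONV_2D per group with the matching
            # weight/bias slice, then CONCATENATE the partial outputs along the channel axis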

            ops = []

            input_tensors = [
                self.create_transform_tensor(arr, quantization=input_tensor.quantization)
                for arr in np.split(input_tensor.tensor, num_chunks, 3)
            ]
            output_tensors = [
                self.create_transform_tensor(arr, quantization=output_tensor.quantization)
                for arr in np.split(output_tensor.tensor, num_chunks, 3)
            ]
            weights = [
                self.create_attr_tensor(arr, quantization=weight_tensor.quantization)
                for arr in np.split(weight_tensor.tensor, num_chunks, 0)
            ]

            if bias_tensor is not None:
                biases = [
                    self.create_attr_tensor(arr, quantization=bias_tensor.quantization)
                    for arr in np.split(bias_tensor.tensor, num_chunks, 0)
                ]
            else:
                biases = [None] * num_chunks

            dim_tensor = self.create_attr_tensor(np.array(3, dtype='int32'))
            ops.append(tfl.SplitOperator([dim_tensor, input_tensor], input_tensors, num_chunks))

            for it, ot, w, b in zip(input_tensors, output_tensors, weights, biases):
                inputs = [it, w]
                if b is not None:
                    inputs.append(b)
                ops.append(
                    tfl.Conv2dOperator(
                        inputs,
                        [ot],
                        strideH=conv.strideH,
                        strideW=conv.strideW,
                        dilationHFactor=conv.dilationHFactor,
                        dilationWFactor=conv.dilationWFactor,
                        fusedActivationFunction=conv.fusedActivationFunction,
                        padding=conv.padding,
                    )
                )

            ops.append(tfl.ConcatenationOperator(output_tensors, [output_tensor], 3))

            for op in ops:
                self.graph.add_operator(op, transform=True)

            self.graph.try_restore_edges(mapping)

    @class_conditional(lambda self: self.group_conv_rewrite)
    def group_deconv_rewrite_pass(self):
        vertices = self.graph.graph.vs.select(functools.partial(is_group_deconv_node, graph_converter=self.graph.graph))

        remove_ids = []
        ops = []
        restore_mapping = []
        for conv in vertices:
            restore_nodes = []
            # For each node that follows a transformable node:
            #  a. if it is an output node, remove it, since it will be reconstructed anyway
            #  b. otherwise, record the edge info so that it can be restored after reconstruction
            for out_edge in conv.out_edges():
                next_node = self.graph.graph.vs[out_edge.target]
                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    remove_ids.append(next_node.index)
                    del self.graph.tensor_map[next_node['outputs'][0]]
                    del self.graph.tensor_node_map[next_node['outputs'][0]]
                else:
                    restore_nodes.append((out_edge['name'], next_node['name']))

            # Remove the mappings since these tensors are going to be removed
            for output_name in conv['outputs']:
                del self.graph.tensor_map[output_name]
                del self.graph.tensor_node_map[output_name]

            restore_mapping.append(restore_nodes)
            remove_ids.append(conv.index)

        # Make sure the nodes are topologically sorted
        sorted_ops = [node['op'] for node in sorted(vertices, key=lambda x: int(re.search(r'\d+', x['name'])[0]))]

        # Delete nodes before transformation in the graph
        self.graph.graph.delete_vertices(remove_ids)

        for conv, mapping in zip(sorted_ops, restore_mapping):
            input_tensor = conv.inputs[2]
            weight_tensor = conv.inputs[1]
            output_shape_tensor = conv.inputs[0]
            bias_tensor = conv.inputs[3] if len(conv.inputs) > 3 else None
            output_tensor = conv.outputs[0]

            num_output_channel = output_tensor.shape[3]
            num_weight_channel = weight_tensor.shape[0]
            num_chunks = num_output_channel // num_weight_channel
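            # Same strategy as the grouped CONV_2D rewrite: split the input channels into
            # groups, run one TRANSPOSE_CONV per group with a per-group output shape, then
            # concatenate the partial outputs along the channel axis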

            ops = []

            input_tensors = [
                self.create_transform_tensor(arr, quantization=input_tensor.quantization)
                for arr in np.split(input_tensor.tensor, num_chunks, 3)
            ]
            output_tensors = [
                self.create_transform_tensor(arr, quantization=output_tensor.quantization)
                for arr in np.split(output_tensor.tensor, num_chunks, 3)
            ]
            weights = [
                self.create_attr_tensor(arr, quantization=weight_tensor.quantization)
                for arr in np.split(weight_tensor.tensor, num_chunks, 3)
            ]

            if bias_tensor is not None:
                biases = [
                    self.create_attr_tensor(arr, quantization=bias_tensor.quantization)
                    for arr in np.split(bias_tensor.tensor, num_chunks, 0)
                ]
            else:
                biases = [None] * num_chunks

            new_os = output_shape_tensor.tensor.copy()
            new_os[3] = num_weight_channel
            new_ost = self.create_attr_tensor(new_os)
            dim_tensor = self.create_attr_tensor(np.array(3, dtype='int32'))
            ops.append(tfl.SplitOperator([dim_tensor, input_tensor], input_tensors, num_chunks))

            for it, ot, w, b in zip(input_tensors, output_tensors, weights, biases):
                inputs = [new_ost, w, it]
                if b is not None:
                    inputs.append(b)
                ops.append(
                    tfl.TransposeConvOperator(
                        inputs,
                        [ot],
                        padding=conv.padding,
                        strideH=conv.strideH,
                        strideW=conv.strideW,
                    )
                )

            ops.append(tfl.ConcatenationOperator(output_tensors, [output_tensor], 3))

            for op in ops:
                self.graph.add_operator(op, transform=True)

            self.graph.try_restore_edges(mapping)

    @class_conditional(lambda self: self.tflite_micro_rewrite)
    def cat_split_pass(self):
        vertices = self.graph.graph.vs.select(functools.partial(is_large_cat_node, graph_converter=self.graph.graph))

        remove_ids = []
        ops = []
        restore_mapping = []
        for cat in vertices:
            restore_nodes = []
            # For each node that follows a transformable node:
            #  a. if it is an output node, remove it, since it will be reconstructed anyway
            #  b. otherwise, record the edge info so that it can be restored after reconstruction
            for out_edge in cat.out_edges():
                next_node = self.graph.graph.vs[out_edge.target]
                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    remove_ids.append(next_node.index)
                    del self.graph.tensor_map[next_node['outputs'][0]]
                    del self.graph.tensor_node_map[next_node['outputs'][0]]
                else:
                    restore_nodes.append((out_edge['name'], next_node['name']))

            # Remove the mappings since these tensors are going to be removed
            for output_name in cat['outputs']:
                del self.graph.tensor_map[output_name]
                del self.graph.tensor_node_map[output_name]

            restore_mapping.append(restore_nodes)
            remove_ids.append(cat.index)

        # Make sure the nodes are topologically sorted
        sorted_ops = [node['op'] for node in sorted(vertices, key=lambda x: int(re.search(r'\d+', x['name'])[0]))]

        # Delete nodes before transformation in the graph
        self.graph.graph.delete_vertices(remove_ids)

        for cat, mapping in zip(sorted_ops, restore_mapping):
            input_tensors = cat.inputs
            layer_inputs = input_tensors
            output_tensor = cat.outputs[0]

            axis = cat.axis
            last_layer = False

            ops = []
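            # Rewrite the wide concat as a tree of CONCATENATION ops with at most 10 inputs
            # each, so that TFLite Micro never sees a concat wider than that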

            while True:
                layer_outputs = []

                while len(layer_inputs) > 0:
                    curr_inputs = layer_inputs[:10]

                    input_arrs = [t.tensor for t in curr_inputs]
                    output_arr = np.concatenate(input_arrs, axis)

                    if last_layer:
                        curr_output = output_tensor
                    else:
                        curr_output = self.create_transform_tensor(output_arr, quantization=output_tensor.quantization)
                        layer_outputs.append(curr_output)

                    ops.append(tfl.ConcatenationOperator(curr_inputs, [curr_output], axis))

                    layer_inputs = layer_inputs[10:]

                if len(layer_outputs) == 0:
                    break
                elif len(layer_outputs) <= 10:
                    last_layer = True

                layer_inputs = layer_outputs

            for op in ops:
                self.graph.add_operator(op, transform=True)

            self.graph.try_restore_edges(mapping)

    def input_transpose_pass(self):
        nhwc2nchw_perm = np.array([0, 3, 1, 2], dtype='int32')
        nchw2nhwc_perm = np.array([0, 2, 3, 1], dtype='int32')
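        # For every input flagged in `input_transpose`, the input tensor is rewritten in NHWC
        # layout and a TRANSPOSE back to NCHW is inserted right after the input node (or after
        # its QUANTIZE node for quantized graphs), so downstream ops still see NCHW data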

        remove_edges = []
        for name, transpose in zip(self.graph.inputs, self.graph.input_transpose):
            if transpose is True:
                node_name = self.graph.tensor_node_map[name]
                node = self.graph.graph.vs.find(name=node_name)
                assert node['node_type'] == ExtendedOperator.INPUT_NODE

                # For quantized graphs, we insert the transpose op after the quantize op
                next_node = None
                if node.outdegree() == 1:
                    next_node = node.out_edges()[0].target_vertex
                    if next_node['node_type'] != ExtendedOperator.QUANTIZE:
                        next_node = None

                # Transpose input tensor shapes
                input_tensor = self.graph.tensor_map[node['name']]
                input_tensor.tensor = np.transpose(input_tensor.tensor, nchw2nhwc_perm)
                input_tensor.shape = input_tensor.tensor.shape

                # Transpose quantize output tensor shapes
                last_tensor = input_tensor
                last_node = node
                if next_node is not None:
                    last_node = next_node
                    last_tensor = next_node['op'].outputs[0]
                    last_tensor.tensor = np.transpose(last_tensor.tensor, nchw2nhwc_perm)
                    last_tensor.shape = last_tensor.tensor.shape

                # Create new transpose op
                nhwc2nchw_perm_tensor = self.create_attr_tensor(nhwc2nchw_perm)
                transposed = self.create_transform_tensor(
                    np.transpose(last_tensor.tensor, nhwc2nchw_perm), quantization=last_tensor.quantization
                )
                transpose_op = tfl.TransposeOperator([last_tensor, nhwc2nchw_perm_tensor], [transposed])
                transpose_op.extra_hints['direction'] = 'down'
                self.graph.add_operator(transpose_op)

                # Get the newly-generated node
                new_node_name = self.graph.tensor_node_map[transposed.name]
                new_node = self.graph.graph.vs.find(name=new_node_name)

                # Connect the transpose op to the graph
                self.graph.replace_next_tensors(last_node, new_node, transposed.name, [new_node_name])

                # Collect the unused connections
                for edge in last_node.out_edges():
                    target_vertex = edge.target_vertex
                    if target_vertex['name'] != new_node_name:
                        remove_edges.append(edge.index)

        # Remove the collected edges
        self.graph.graph.delete_edges(remove_edges)

    @class_conditional(lambda self: self.quantize_input_output_type is not None)
    def quantize_input_output_type_pass(self):
        remove_edges = []
        remove_vertices = []
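        # Graph inputs are rewritten in the requested dtype, followed by a QUANTIZE op that
        # converts back to the original dtype; graph outputs instead get a trailing QUANTIZE
        # into the requested dtype. Only int8 <-> uint8 is supported, which amounts to a pure
        # zero-point shift of +/-128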
        for i, name in enumerate(self.graph.inputs):
            if self.fuse_input_indices is not None:
                if i not in self.fuse_input_indices:
                    continue

            node_name = self.graph.tensor_node_map[name]
            node = self.graph.graph.vs.find(name=node_name)
            assert node['node_type'] == ExtendedOperator.INPUT_NODE

            # Update input tensor
            input_tensor = self.graph.tensor_map[node['outputs'][0]]
            input_type = str(input_tensor.dtype)
            if input_type == self.quantize_input_output_type:
                continue

            input_arr = input_tensor.tensor.copy()
            input_quantization = copy.deepcopy(input_tensor.quantization)
            if input_type == 'int8' and self.quantize_input_output_type == 'uint8':
                input_tensor.tensor = (input_tensor.tensor.astype('int32') + 128).astype('uint8')
                input_tensor.quantization.zero_point += 128
                input_tensor.dtype = input_tensor.tensor.dtype
            elif input_type == 'uint8' and self.quantize_input_output_type == 'int8':
                input_tensor.tensor = (input_tensor.tensor.astype('int32') - 128).astype('int8')
                input_tensor.quantization.zero_point -= 128
                input_tensor.dtype = input_tensor.tensor.dtype
            else:
                raise AssertionError(
                    f'Unsupported types: input_type: {input_type}, quantize_input_output_type:'
                    f' {self.quantize_input_output_type}'
                )

            # Create new quantize op
            requantized = self.create_transform_tensor(input_arr, quantization=input_quantization)
            quantize_op = tfl.QuantizeOperator([input_tensor], [requantized])
            self.graph.add_operator(quantize_op)

            # Get the newly-generated node
            new_node_name = self.graph.tensor_node_map[requantized.name]
            new_node = self.graph.graph.vs.find(name=new_node_name)

            # Connect the quantize op to the graph
            self.graph.replace_next_tensors(node, new_node, requantized.name, [new_node_name])

            # Collect the unused connections
            for edge in node.out_edges():
                target_vertex = edge.target_vertex
                if target_vertex['name'] != new_node_name:
                    remove_edges.append(edge.index)

        output_mapping = {}
        for i, name in enumerate(self.graph.outputs):
            if self.fuse_output_indices is not None:
                if i not in self.fuse_output_indices:
                    continue

            output_tensor = self.graph.tensor_map[name]
            output_type = str(output_tensor.dtype)
            if output_type == self.quantize_input_output_type:
                continue

            node_name = self.graph.tensor_node_map[name]
            node = self.graph.graph.vs.find(name=node_name)

            for edge in node.out_edges():
                next_node = edge.target_vertex

                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    remove_vertices.append(next_node.index)

            # Update output tensor
            output_arr = output_tensor.tensor.copy()
            output_quantization = copy.deepcopy(output_tensor.quantization)
            if output_type == 'int8' and self.quantize_input_output_type == 'uint8':
                output_arr = (output_arr.astype('int32') + 128).astype('uint8')
                output_quantization.zero_point += 128
            elif output_type == 'uint8' and self.quantize_input_output_type == 'int8':
                output_arr = (output_arr.astype('int32') - 128).astype('int8')
                output_quantization.zero_point -= 128
            else:
                raise AssertionError(
                    f'Unsupported types: output_type: {output_type}, quantize_input_output_type:'
                    f' {self.quantize_input_output_type}'
                )

            requantized = self.create_transform_tensor(output_arr, quantization=output_quantization)
            quantize_op = tfl.QuantizeOperator([output_tensor], [requantized])
            self.graph.add_operator(quantize_op)

            output_mapping[name] = requantized.name

        if len(output_mapping) > 0:
            new_outputs = []
            output_names = []
            for name in self.graph.outputs:
                if name in output_mapping:
                    new_outputs.append(output_mapping[name])
                    output_names.append(output_mapping[name])
                else:
                    new_outputs.append(name)

            self.graph.outputs.clear()
            self.graph.outputs.extend(new_outputs)
            self.graph.add_outputs(output_names)

        # Remove the collected edges & vertices
        self.graph.graph.delete_edges(remove_edges)
        self.graph.graph.delete_vertices(remove_vertices)

    def output_transpose_pass(self):
        nhwc2nchw_perm = np.array([0, 3, 1, 2], dtype='int32')
        nchw2nhwc_perm = np.array([0, 2, 3, 1], dtype='int32')

        if isinstance(self.graph.output_transpose, (list, tuple)):
            assert len(self.graph.output_transpose) == len(self.graph.outputs)
        else:
            self.graph.output_transpose = [self.graph.output_transpose] * len(self.graph.outputs)

        filtered_dict = {}
        for i, (name, transpose) in enumerate(zip(self.graph.outputs, self.graph.output_transpose)):
            if name in filtered_dict:
                old_transpose = filtered_dict[name]
                assert (
                    transpose == old_transpose
                ), f"outputs {i} points to an exising tensor {name}, but their property `output_transpose` is different"
            else:
                filtered_dict[name] = transpose

        prev_modify_node_indices = {}
        prev_modify_next_indices = {}
        next_modify_node_indices = {}
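        # Two cases are handled below: if the output comes from a DEQUANTIZE node, the
        # NCHW->NHWC transpose is inserted before that node (`prev_modify_*`); otherwise the
        # producing op's output is rewritten in NHWC directly and a transpose back to NCHW is
        # added for any remaining non-output consumers (`next_modify_node_indices`)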
        for name, transpose in filtered_dict.items():
            if name in self.graph.tensor_map:
                tensor = self.graph.tensor_map[name]
                if transpose is None:
                    transpose = len(tensor.shape) == 4
            else:
                transpose = False

            for i, n in enumerate(self.graph.outputs):
                if name == n:
                    self.graph.output_transpose[i] = transpose

            if transpose:
                node_name = self.graph.tensor_node_map[name]
                node = self.graph.graph.vs.find(name=node_name)
                tensor_idx = node['outputs'].index(name)

                prev_node = None
                if node['node_type'] == ExtendedOperator.DEQUANTIZE:
                    prev_node_name = self.graph.tensor_node_map[node['op'].inputs[0].name]
                    prev_node = self.graph.graph.vs.find(name=prev_node_name)

                if prev_node is None:
                    next_modify_node_indices.setdefault(node, set())
                    next_modify_node_indices[node].add(tensor_idx)
                else:
                    prev_modify_node_indices.setdefault(node, set())
                    prev_modify_node_indices[node].add(0)
                    prev_modify_next_indices.setdefault(node, set())
                    prev_modify_next_indices[node].add(tensor_idx)

        remove_edges = []
        remove_vertices = []
        actions = []
        for node, index in prev_modify_node_indices.items():
            next_indices = prev_modify_next_indices[node]
            op = node['op']
            tensor_names = [node['outputs'][i] for i in index]

            next_nodes = {}
            for edge in node.out_edges():
                if edge['label'] not in tensor_names:
                    continue

                if edge.index in remove_edges:
                    continue

                tensor_idx = tensor_names.index(edge['label'])
                next_node = self.graph.graph.vs[edge.target]

                if next_node['node_type'] not in (ExtendedOperator.OUTPUT_NODE, ExtendedOperator.UNUSED_NODE):
                    next_nodes.setdefault(tensor_idx, [])
                    next_nodes[tensor_idx].append(next_node)

            prev_nodes = []
            prev_output_indices = []
            for i in index:
                prev_node_name = op.inputs[i].name
                prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
                prev_nodes.append(prev_node)
                prev_output_indices.append(prev_node['outputs'].index(prev_node_name))

            tensor_node_dict = {}
            for prev_node, prev_idx, next_idx in zip(prev_nodes, index, prev_output_indices):
                if prev_node['op'] is None:
                    prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
                else:
                    prev_out = prev_node['op'].outputs[next_idx]
                if prev_out.name in tensor_node_dict:
                    prev_new_out, skip = tensor_node_dict[prev_out.name]
                    actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True, skip)))
                    skip += 1
                    tensor_node_dict[prev_out.name] = (prev_new_out, skip)
                else:
                    perm_tensor = self.create_attr_tensor(nchw2nhwc_perm)
                    prev_new_out = self.create_transform_tensor(
                        np.transpose(prev_out.tensor, nchw2nhwc_perm), quantization=prev_out.quantization
                    )
                    tensor_node_dict[prev_out.name] = (prev_new_out, 1)
                    prev_transpose_op = tfl.TransposeOperator([prev_out, perm_tensor], [prev_new_out])
                    prev_transpose_op.extra_hints['direction'] = 'up'
                    self.graph.add_operator(prev_transpose_op)
                    actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))

            tensor_mapping = {}
            for i in next_indices:
                t = op.outputs[i]
                t.tensor = np.transpose(t.tensor, nchw2nhwc_perm)
                t.shape = t.tensor.shape

                if i in next_nodes:
                    new_t = self.create_transform_tensor(np.transpose(t.tensor, nhwc2nchw_perm))
                    perm_t = self.create_attr_tensor(nhwc2nchw_perm)

                    next_transpose_op = tfl.TransposeOperator([t, perm_t], [new_t])
                    next_transpose_op.extra_hints['direction'] = 'down'
                    self.graph.add_operator(next_transpose_op)

                    tensor_mapping[t.name] = new_t

            for nodes in next_nodes.values():
                for n in nodes:
                    next_op = n['op']
                    for i, t in enumerate(next_op.inputs):
                        if t.name in tensor_mapping:
                            actions.append((self.graph.replace_operator_input, (n, i, tensor_mapping[t.name])))

        for node, index in next_modify_node_indices.items():
            op = node['op']
            tensor_names = [node['outputs'][i] for i in index]
            out_nodes = []
            next_nodes = []
            next_edges = []
            for edge in node.out_edges():
                if edge['label'] not in tensor_names:
                    continue

                if edge.index in remove_edges:
                    continue

                next_node = self.graph.graph.vs[edge.target]
                tensor_idx = tensor_names.index(edge['label'])

                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    out_nodes.append(next_node)
                elif next_node['node_type'] != ExtendedOperator.UNUSED_NODE:
                    next_nodes.append(next_node)
                    next_edges.append(edge)

            remove_vertices.extend([x.index for x in out_nodes])
            remove_edges.extend([x.index for x in next_edges])

            for n in out_nodes:
                del self.graph.tensor_map[n['outputs'][0]]
                del self.graph.tensor_node_map[n['outputs'][0]]

            tensor_node_dict = {}
            for i, op_out in enumerate(op.outputs):
                if i not in index:
                    continue

                op_out.tensor = np.transpose(op_out.tensor, nchw2nhwc_perm)
                op_out.shape = op_out.tensor.shape

                perm_tensor = self.create_attr_tensor(nchw2nhwc_perm)
                new_out = self.create_transform_tensor(
                    np.transpose(op_out.tensor, nhwc2nchw_perm), quantization=op_out.quantization
                )

                # Update relations
                if op_out.name in self.graph.tensor_node_map:
                    del self.graph.tensor_node_map[op_out.name]
                self.graph.tensor_node_map[new_out.name] = node['name']
                self.graph.tensor_map[new_out.name] = new_out
                node['outputs'][i] = new_out.name
                op.outputs[i] = new_out

                next_transpose_op = tfl.TransposeOperator([new_out, perm_tensor], [op_out])
                next_transpose_op.extra_hints['direction'] = 'up'
                self.graph.add_operator(next_transpose_op)

                tensor_node_dict[op_out.name] = (
                    self.graph.graph.vs.find(name=self.graph.tensor_node_map[new_out.name]),
                    new_out.name,
                )

            # Connect next edges and replace next tensors
            for edge in next_edges:
                old_name = edge['name']
                source, new_name = tensor_node_dict[old_name]
                target = edge.target_vertex
                self.graph.graph.add_edge(source, target, name=new_name, label=new_name)

                op = target['op']
                for i, op_input in enumerate(op.inputs):
                    if op_input.name == old_name:
                        op.inputs[i] = self.graph.tensor_map[new_name]
                        break

        # Process actions
        ids = []
        for func, args in actions:
            node = args[0]
            res = func(*args)
            if res is not None:
                ids.extend(res)

        remove_edges = list(set(remove_edges + ids))

        self.graph.graph.delete_edges(remove_edges)
        self.graph.graph.delete_vertices(remove_vertices)

    def connect_unused_tensors_pass(self):
        filtered_nodes = self.graph.graph.vs.select(
            functools.partial(is_multi_output_op_node, graph_converter=self.graph.graph)
        )
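        # Attach the unused outputs of multi-output ops to special UNUSED_NODE vertices so
        # that they stay connected in the graph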

        list_unpack_names = set([i for s in self.graph.iterable_map.values() for i in s])
        all_tensors = set(self.graph.graph.es['label'])
        names = []
        for node in filtered_nodes:
            output_names = node['outputs']

            # Recognizes the pattern SPLIT -> (RESHAPE, ..., RESHAPE)
            if not list_unpack_names.isdisjoint(set(output_names)):
                output_names = []
                outdegree = 0
                for edge in node.out_edges():
                    target_vertex = edge.target_vertex
                    if target_vertex['node_type'] == ExtendedOperator.RESHAPE:
                        outdegree += target_vertex.outdegree()
                        output_names.append(target_vertex['outputs'][0])

                # Only nodes with partially unused tensors are supported
                if outdegree == 0:
                    continue

            for out in output_names:
                if out not in all_tensors:
                    names.append(out)

        self.graph.add_outputs(names, ExtendedOperator.UNUSED_NODE)

    def output_list_unpack_pass(self):
        output_names = []
        unpacked_outputs = []
        for name in self.graph.outputs:
            if name in self.graph.iterable_map:
                names = self.graph.get_list_expanded_names(name)
                unpacked_outputs.extend(names)
                output_names.extend(names)
            else:
                unpacked_outputs.append(name)

        self.graph.outputs.clear()
        self.graph.outputs.extend(unpacked_outputs)
        self.graph.add_outputs(output_names)

    @class_conditional(lambda self: self.fuse_quant)
    def fuse_quant_dequant_nodes(self):
        edges = self.graph.graph.es.select(functools.partial(is_quant_dequant_edge, graph_converter=self.graph.graph))
        filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]

        remove_vertices = []
        input_mapping = {}
        output_mapping = {}
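        # A QUANTIZE right after an input node is folded into the input (the graph input
        # becomes the quantized tensor), and a DEQUANTIZE right before an output node is
        # folded into the output. The input/output name lists are remapped accordingly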
        for prev, next in filtered_pairs:
            if prev['node_type'] == ExtendedOperator.INPUT_NODE:
                input_name = prev['outputs'][0]
                if self.fuse_input_indices is not None:
                    input_idx = self.graph.inputs.index(input_name)
                    if input_idx not in self.fuse_input_indices:
                        continue
                remove_vertices.append(prev)
                next['node_type'] = prev['node_type']
                next['op'] = None
                input_mapping.setdefault(input_name, [])
                input_mapping[input_name].extend(next['outputs'])
            else:
                if prev['op'] is not None:
                    prev_name = prev['op'].outputs[0].name
                    if self.fuse_output_indices is not None:
                        output_idx = self.graph.outputs.index(prev_name)
                        if output_idx not in self.fuse_output_indices:
                            continue
                    prev['node_type'] = next['node_type']
                    new_name = prev['op'].inputs[0].name
                    prev['op'] = None
                    output_mapping.setdefault(prev_name, [])
                    output_mapping[prev_name].append(new_name)
                remove_vertices.append(next)

        self.graph.graph.delete_vertices(remove_vertices)

        if len(input_mapping) > 0:
            new_inputs = []
            for name in self.graph.inputs:
                if name in input_mapping:
                    new_inputs.extend(input_mapping[name])
                else:
                    new_inputs.append(name)

            self.graph.inputs.clear()
            self.graph.inputs.extend(new_inputs)

        if len(output_mapping) > 0:
            new_outputs = []
            for name in self.graph.outputs:
                if name in output_mapping:
                    new_outputs.extend(output_mapping[name])
                else:
                    new_outputs.append(name)

            self.graph.outputs.clear()
            self.graph.outputs.extend(new_outputs)

    def optimize(self):
        # Input/output passes
        self.output_list_unpack_pass()
        self.input_transpose_pass()
        self.output_transpose_pass()

        # Connect unused tensors with special nodes
        self.connect_unused_tensors_pass()

        # Transpose, Reshape and NO-OP cleanup
        self.branch_reshape_expand_pass()
        self.fuse_simple_reshape_pass()
        self.branch_transpose_expand_pass()
        self.fuse_simple_transpose_pass()
        self.fuse_simple_gather_pass()
        for branch in (False, True):
            self.remove_noop_pass(branch)
        self.fuse_wrapped_reshape_within_transpose_pass()

        # Buffer folding, which is needed by the fusion passes below
        for _ in range(2):
            self.fold_reshape_buffer()
            self.fold_transpose_buffer()

        # Move `transpose` ops for the rewrite quantizable pass
        self.elementwise_op_transpose_passthrough_pass(quantizable_ops_only=True)
        self.branch_transpose_expand_pass()
        self.fuse_simple_transpose_pass()

        # Fuse reciprocal and sqrt
        self.fuse_reciprocal_sqrt()

        # Map quantizable ops to quantized kernels
        self.elementwise_op_quantize_passthrough_pass()

        # Remove consecutive dequantize and quantize nodes
        self.fuse_dequant_quant_pass(q_first=False)

        # OP fusion passes before transformation
        self.fuse_conv_fc_bn()
        self.fuse_activation()
        self.fuse_requantize()
        self.fuse_bn_conv()

        # Convert TinyNeuralNetwork ops to TFLite ops
        self.transform_graph()

        # OP fusion passes after transformation
        self.fuse_bmm_add_pass()
        self.fuse_activation()

        # Transpose and reshape cleanup
        self.branch_reshape_expand_pass()
        self.branch_transpose_expand_pass()
        self.fuse_simple_transpose_pass()
        self.fuse_simple_reshape_pass()

        # Branch transpose & reshape cleanup
        for i in range(11):
            t_count = self.elementwise_op_transpose_passthrough_pass()
            self.branch_transpose_expand_pass()
            self.fuse_simple_transpose_pass()

            r_count = self.elementwise_op_reshape_passthrough_pass()
            self.branch_reshape_expand_pass()
            self.fuse_simple_reshape_pass()

            c_count = self.elementwise_reshape_transpose_passthrough_pass()
            self.branch_transpose_expand_pass()
            self.fuse_simple_transpose_pass()

            if t_count + r_count + c_count == 0:
                log.debug(f'elem p/t pass finished in {i + 1} steps')
                break

        # Other cleanups
        self.fuse_simple_slice_pass()
        for branch in (False, True):
            self.remove_noop_pass(branch)
        self.fuse_wrapped_reshape_within_transpose_pass()

        # Buffer folding
        for _ in range(2):
            self.fold_reshape_buffer()
            self.fold_transpose_buffer()

        # Transpose and reshape cleanup
        for _ in range(2):
            self.transpose_to_reshape_pass()
            self.branch_reshape_expand_pass()
            self.fuse_simple_reshape_pass()
            self.fuse_simple_transpose_pass()

        self.lower_transpose_dim_pass()

        # Some advanced fusion logic
        self.fuse_conv2d_gather()

        # Remove consecutive dequantize and quantize nodes
        self.fuse_dequant_quant_pass(q_first=True)

        # Fuse reciprocal and sqrt
        self.fuse_reciprocal_sqrt()

        # Remove additional tile nodes before elementwise ops
        self.remove_tile_before_binary_elementwise_ops()

        # Fuse activation
        self.fuse_activation()

        # Fuse quant/dequant nodes
        self.fuse_quant_dequant_nodes()

        # Input output quantize type
        self.quantize_input_output_type_pass()

        # Fuse same padding
        self.fuse_same_padding()
        self.fuse_same_padding_slicing()

        self.fuse_gather_conv2d()

        # Group conv & deconv
        self.group_conv_rewrite_pass()
        self.group_deconv_rewrite_pass()

        # TFLite micro specific
        self.cat_split_pass()
        self.split_requantize()

        # Group the same tensors into one
        self.group_tensors_pass()

        # Final cleanup
        self.cleanup_dead_nodes()


def is_bn_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        source_vertex['node_type']
        in (ExtendedOperator.GENERIC_CONV, ExtendedOperator.GENERIC_DECONV, ExtendedOperator.FULLY_CONNECTED)
        and target_vertex['node_type'] == ExtendedOperator.BATCH_NORM
        and source_vertex.outdegree() == 1
        and target_vertex['op'].inputs[1].buffer is not None
        and target_vertex['op'].inputs[2].buffer is not None
        and source_vertex['op'].inputs[1].buffer is not None
        and (
            target_vertex['op'].fusedActivationFunction == ActivationFunctionType.NONE
            or source_vertex['op'].fusedActivationFunction
            in (ActivationFunctionType.NONE, target_vertex['op'].fusedActivationFunction)
        )
    )


def is_rev_bn_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        target_vertex['node_type'] == ExtendedOperator.GENERIC_CONV
        and source_vertex['node_type'] == ExtendedOperator.BATCH_NORM
        and source_vertex.outdegree() == 1
        and source_vertex['op'].inputs[1].buffer is not None
        and source_vertex['op'].inputs[2].buffer is not None
        and target_vertex['op'].inputs[1].buffer is not None
        and source_vertex['op'].fusedActivationFunction == ActivationFunctionType.NONE
    )


def is_padding_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        source_vertex['node_type'] in (ExtendedOperator.PAD, ExtendedOperator.PADV2)
        and (
            len(source_vertex['op'].inputs) == 2
            or (
                len(source_vertex['op'].inputs) == 3
                and source_vertex['op'].inputs[2].dtype == np.dtype('float32')
                and (
                    (
                        source_vertex['op'].inputs[2].tensor[0] == 0.0
                        and target_vertex['node_type'] != ExtendedOperator.MAX_POOL_2D
                    )
                    or (
                        source_vertex['op'].inputs[2].tensor[0] == np.finfo(np.float32).min
                        and target_vertex['node_type'] == ExtendedOperator.MAX_POOL_2D
                    )
                )
            )
        )
        and target_vertex['node_type']
        in (
            ExtendedOperator.CONV_2D,
            ExtendedOperator.CONV_3D,
            ExtendedOperator.DEPTHWISE_CONV_2D,
            ExtendedOperator.MAX_POOL_2D,
        )
        and source_vertex.outdegree() == 1
        and target_vertex['op'].padding == Padding.VALID
    )


def is_slicing_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        target_vertex['node_type'] in (ExtendedOperator.SLICE, ExtendedOperator.STRIDED_SLICE)
        and (
            len(target_vertex['op'].inputs) == 3
            or (len(target_vertex['op'].inputs) == 4 and np.all(target_vertex['op'].inputs[3].tensor == 1))
        )
        and source_vertex['node_type']
        in (
            ExtendedOperator.TRANSPOSE_CONV,
            ExtendedOperator.CONV_3D_TRANSPOSE,
        )
        and source_vertex.outdegree() == 1
        and source_vertex['op'].padding == Padding.VALID
    )


def is_requantize_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        source_vertex['node_type']
        in (
            ExtendedOperator.FULLY_CONNECTED,
            ExtendedOperator.GENERIC_CONV,
            ExtendedOperator.ADD,
            ExtendedOperator.SUB,
            ExtendedOperator.MUL,
            ExtendedOperator.DIV,
            ExtendedOperator.MAX_POOL_2D,
            ExtendedOperator.AVERAGE_POOL_2D,
            ExtendedOperator.GENERIC_DECONV,
        )
        and source_vertex['op'].outputs[0].quantization is not None
        and target_vertex['node_type'] == ExtendedOperator.QUANTIZE
    )


def is_activ_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        source_vertex['node_type']
        in (
            ExtendedOperator.FULLY_CONNECTED,
            ExtendedOperator.GENERIC_CONV,
            ExtendedOperator.ADD,
            ExtendedOperator.SUB,
            ExtendedOperator.MUL,
            ExtendedOperator.DIV,
            ExtendedOperator.MAX_POOL_2D,
            ExtendedOperator.AVERAGE_POOL_2D,
            ExtendedOperator.GENERIC_DECONV,
        )
        and target_vertex['node_type'] in (ExtendedOperator.RELU, ExtendedOperator.RELU6)
        and source_vertex['op'].fusedActivationFunction == ActivationFunctionType.NONE
        and source_vertex.outdegree() == 1
    )


def is_requantize_node(vertex: ig.Vertex, graph_converter: ig.Graph):
    return (
        vertex['node_type'] == ExtendedOperator.QUANTIZE
        and vertex['op'].inputs[0].quantization is not None
        and vertex['op'].outputs[0].quantization is not None
    )


def is_large_cat_node(vertex: ig.Vertex, graph_converter: ig.Graph):
    return vertex['node_type'] == ExtendedOperator.CONCATENATION and len(vertex['op'].inputs) > 10


def is_high_dim_transpose_node(vertex: ig.Vertex, graph_converter: ig.Graph, max_transpose_dims: int):
    return (
        vertex['node_type'] == ExtendedOperator.TRANSPOSE
        and vertex['op'].inputs[1].tensor.size > max_transpose_dims
    )


def is_group_conv_node(vertex: ig.Vertex, graph_converter: ig.Graph):
    return (
        vertex['node_type'] == ExtendedOperator.CONV_2D
        and vertex['op'].inputs[0].shape[3] != vertex['op'].inputs[1].shape[3]
    )


def is_group_deconv_node(vertex: ig.Vertex, graph_converter: ig.Graph):
    return (
        vertex['node_type'] == ExtendedOperator.TRANSPOSE_CONV
        and vertex['op'].outputs[0].shape[3] != vertex['op'].inputs[1].shape[0]
    )


def is_transformable_node(vertex: ig.Vertex, graph_converter: ig.Graph):
    return vertex['node_type'] <= ExtendedOperator.BATCH_NORM and vertex.outdegree() >= 1


def is_transformable_transpose_node(vertex: ig.Vertex, graph_converter: ig.Graph):
    return (
        vertex['node_type'] == ExtendedOperator.TRANSPOSE
        and vertex.outdegree() >= 1
        and is_transpose_same_to_reshape_op(vertex['op'])
    )


def is_multi_output_op_node(vertex: ig.Vertex, graph_converter: ig.Graph):
    return vertex['node_type'] >= 0 and len(vertex['outputs']) > 1 and vertex.outdegree() > 0


def is_quantize_elementwise_op_edge(edge: ig.Edge, graph_converter: ig.Graph, with_lstm: bool):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        (
            source_vertex['node_type'] == ExtendedOperator.DEQUANTIZE
            and is_quantizable_rewrite_op(target_vertex['node_type'], target_vertex['op'], with_lstm)
        )
        or (
            target_vertex['node_type'] == ExtendedOperator.QUANTIZE
            and is_quantizable_rewrite_op(source_vertex['node_type'], source_vertex['op'], with_lstm)
        )
    ) and target_vertex['op'].inputs[0].name in source_vertex['outputs']


def is_transpose_reshape_op_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        (
            source_vertex['node_type'] == ExtendedOperator.TRANSPOSE
            and target_vertex['node_type'] == ExtendedOperator.RESHAPE
        )
        or (
            target_vertex['node_type'] == ExtendedOperator.TRANSPOSE
            and source_vertex['node_type'] == ExtendedOperator.RESHAPE
        )
    ) and target_vertex['op'].inputs[0].name in source_vertex['outputs']


def is_transpose_elementwise_op_edge(edge: ig.Edge, graph_converter: ig.Graph, quantizable_ops_only: bool):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]

    if quantizable_ops_only:
        is_unary = is_elementwise_unary_quantizable_op
        is_binary = is_elementwise_binary_quantizable_op
    else:
        is_unary = is_elementwise_unary_op
        is_binary = is_elementwise_binary_op

    return (
        (
            source_vertex['node_type'] == ExtendedOperator.TRANSPOSE
            and (
                is_unary(target_vertex['node_type'], target_vertex['op'])
                or is_binary(target_vertex['node_type'], target_vertex['op'])
            )
        )
        or (
            target_vertex['node_type'] == ExtendedOperator.TRANSPOSE
            and (
                is_unary(source_vertex['node_type'], source_vertex['op'])
                or is_binary(source_vertex['node_type'], source_vertex['op'])
            )
        )
    ) and (
        (
            target_vertex['node_type'] != ExtendedOperator.SPLIT
            and target_vertex['op'].inputs[0].name in source_vertex['outputs']
        )
        or (
            target_vertex['node_type'] == ExtendedOperator.SPLIT
            and target_vertex['op'].inputs[1].name in source_vertex['outputs']
        )
    )


def is_reshape_elementwise_op_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        (
            source_vertex['node_type'] == ExtendedOperator.RESHAPE
            and (
                is_elementwise_unary_op(target_vertex['node_type'], target_vertex['op'])
                or is_elementwise_binary_op(target_vertex['node_type'], target_vertex['op'])
            )
        )
        or (
            target_vertex['node_type'] == ExtendedOperator.RESHAPE
            and (
                is_elementwise_unary_op(source_vertex['node_type'], source_vertex['op'])
                or is_elementwise_binary_op(source_vertex['node_type'], source_vertex['op'])
            )
        )
    ) and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name


def is_elementwise_reduce_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
    return (
        op_code
        in (
            ExtendedOperator.SUM,
            ExtendedOperator.ARG_MIN,
            ExtendedOperator.ARG_MAX,
            ExtendedOperator.REDUCE_MIN,
            ExtendedOperator.REDUCE_MAX,
            ExtendedOperator.REDUCE_PROD,
        )
        and len(op.inputs[0].shape) == len(op.outputs[0].shape)
    ) or (
        op_code == ExtendedOperator.MEAN
        and len(op.inputs[0].shape) == len(op.outputs[0].shape)
        and (
            len(op.inputs[0].shape) != 4
            or (
                not np.array_equal(op.inputs[1].tensor, np.array([1, 2], dtype='int32'))
                and not np.array_equal(op.inputs[1].tensor, np.array([2, 1], dtype='int32'))
            )
        )
    )


def is_elementwise_unary_quantizable_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
    return op_code in (
        ExtendedOperator.SOFTMAX,
        ExtendedOperator.LOG_SOFTMAX,
    )


def is_elementwise_binary_quantizable_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
    return False


def is_elementwise_unary_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
    return op_code in (
        ExtendedOperator.RELU,
        ExtendedOperator.SIN,
        ExtendedOperator.COS,
        ExtendedOperator.TANH,
        ExtendedOperator.ELU,
        ExtendedOperator.PRELU,
        ExtendedOperator.EXP,
        ExtendedOperator.LOG,
        ExtendedOperator.NEG,
        ExtendedOperator.FLOOR,
        ExtendedOperator.RELU6,
        ExtendedOperator.QUANTIZE,
        ExtendedOperator.DEQUANTIZE,
        ExtendedOperator.SQRT,
        ExtendedOperator.RSQRT,
        ExtendedOperator.CAST,
        ExtendedOperator.LOGISTIC,
        ExtendedOperator.HARD_SWISH,
        ExtendedOperator.LEAKY_RELU,
        ExtendedOperator.SPLIT,
        ExtendedOperator.SPLIT_V,
        ExtendedOperator.UNPACK,
        ExtendedOperator.PAD,
        ExtendedOperator.PADV2,
        ExtendedOperator.MIRROR_PAD,
        ExtendedOperator.SLICE,
        ExtendedOperator.STRIDED_SLICE,
        ExtendedOperator.TILE,
        ExtendedOperator.GATHER,
        ExtendedOperator.ABS,
    ) or is_elementwise_reduce_op(op_code, op)


def is_quantizable_rewrite_op(op_code: ExtendedOperator, op: tfl.BaseOperator, with_lstm: bool):
    return op_code in (
        ExtendedOperator.BATCH_MATMUL,
        ExtendedOperator.SOFTMAX,
        ExtendedOperator.LOG_SOFTMAX,
        ExtendedOperator.ABS,
        ExtendedOperator.SUM,
        ExtendedOperator.DIV,
        ExtendedOperator.RSQRT,
        ExtendedOperator.MAXIMUM,
        ExtendedOperator.MINIMUM,
    ) or (with_lstm and op_code == ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM)


def is_elementwise_binary_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
    return (
        op_code
        in (
            ExtendedOperator.CONCATENATION,
            ExtendedOperator.PACK,
            ExtendedOperator.ADD,
            ExtendedOperator.SUB,
            ExtendedOperator.MUL,
            ExtendedOperator.DIV,
            ExtendedOperator.MAXIMUM,
            ExtendedOperator.MINIMUM,
            ExtendedOperator.SQUARED_DIFFERENCE,
        )
        and len(op.inputs) >= 2
    )


def is_non_passthrough_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
    return op_code in (
        ExtendedOperator.CONV_2D,
        ExtendedOperator.AVERAGE_POOL_2D,
        ExtendedOperator.DEPTHWISE_CONV_2D,
        ExtendedOperator.MAX_POOL_2D,
    )


def is_ending_with_noop_edge(edge: ig.Edge, graph_converter: ig.Graph, branch: bool = False):
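    """Checks whether the target op of the edge is effectively an identity
    (a shape-preserving RESHAPE/PAD/TILE/SLICE/STRIDED_SLICE/CONCATENATION, a
    TRANSPOSE with an identity permutation, a GATHER with consecutive indices and
    an unchanged shape, or a CAST to the same dtype), so the op can be eliminated.
    With `branch=True` the source node may have more than one consumer."""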
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]

    if branch:
        source_cond_var = source_vertex.outdegree() >= 1
    else:
        source_cond_var = source_vertex.outdegree() == 1

    return (
        source_cond_var
        and target_vertex.outdegree() >= 1
        and target_vertex['op'] is not None
        and target_vertex['op'].inputs[0].name in source_vertex['outputs']
        and (
            (
                target_vertex['node_type'] == ExtendedOperator.RESHAPE
                and target_vertex['op'].inputs[0].shape == target_vertex['op'].outputs[0].shape
            )
            or (
                target_vertex['node_type'] == ExtendedOperator.TRANSPOSE
                and (np.diff(target_vertex['op'].inputs[1].tensor) == 1).all()
            )
            or (
                target_vertex['node_type']
                in (ExtendedOperator.PAD, ExtendedOperator.PADV2, ExtendedOperator.MIRROR_PAD)
                and target_vertex['op'].inputs[0].shape == target_vertex['op'].outputs[0].shape
            )
            or (
                target_vertex['node_type'] == ExtendedOperator.TILE
                and target_vertex['op'].inputs[0].shape == target_vertex['op'].outputs[0].shape
            )
            or (
                target_vertex['node_type'] in (ExtendedOperator.SLICE, ExtendedOperator.STRIDED_SLICE)
                and target_vertex['op'].inputs[0].shape == target_vertex['op'].outputs[0].shape
            )
            or (
                target_vertex['node_type'] == ExtendedOperator.CONCATENATION
                and len(target_vertex['op'].inputs) == 1
                and len(target_vertex['op'].outputs) == 1
                and target_vertex['op'].inputs[0].shape == target_vertex['op'].outputs[0].shape
            )
            or (
                target_vertex['node_type'] == ExtendedOperator.GATHER
                and target_vertex['op'].inputs[0].shape == target_vertex['op'].outputs[0].shape
                and (np.diff(target_vertex['op'].inputs[1].tensor) == 1).all()
            )
            or (
                target_vertex['node_type'] == ExtendedOperator.CAST
                and target_vertex['op'].inDataType == target_vertex['op'].outDataType
            )
        )
    )


def is_bmm_add_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]

    out_dim_idx = None
    if source_vertex['node_type'] == ExtendedOperator.BATCH_MATMUL:
        out_dim_idx = -1
    elif source_vertex['node_type'] == ExtendedOperator.FULLY_CONNECTED:
        out_dim_idx = 0

    return (
        out_dim_idx is not None
        and target_vertex['node_type'] == ExtendedOperator.ADD
        and source_vertex['op'].inputs[0].tensor.ndim >= 2
        and source_vertex['op'].inputs[1].tensor.ndim == 2
        and target_vertex['op'].inputs[1].tensor.ndim == 1
        and target_vertex['op'].inputs[1].shape[0] == source_vertex['op'].inputs[1].shape[out_dim_idx]
        and source_vertex.outdegree() == 1
        and target_vertex.outdegree() >= 1
        and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
    )


def is_wrapped_reshape_within_transpose_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        (
            (
                target_vertex['node_type'] == ExtendedOperator.TRANSPOSE
                and source_vertex['node_type'] == ExtendedOperator.RESHAPE
            )
            or (
                source_vertex['node_type'] == ExtendedOperator.TRANSPOSE
                and target_vertex['node_type'] == ExtendedOperator.RESHAPE
            )
        )
        and source_vertex.outdegree() == 1
        and target_vertex.outdegree() >= 1
        and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
    )


def is_slice_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        source_vertex['node_type'] in (ExtendedOperator.SLICE, ExtendedOperator.STRIDED_SLICE)
        and source_vertex.outdegree() == 1
        and target_vertex['node_type'] in (ExtendedOperator.SLICE, ExtendedOperator.STRIDED_SLICE)
        and target_vertex.outdegree() >= 1
        and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
        and source_vertex['op'].inputs[1].buffer is not None
        and source_vertex['op'].inputs[2].buffer is not None
    )


def is_transpose_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        source_vertex['node_type'] == ExtendedOperator.TRANSPOSE
        and source_vertex.outdegree() == 1
        and target_vertex['node_type'] == ExtendedOperator.TRANSPOSE
        and target_vertex.outdegree() >= 1
        and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
    )


def is_gather_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        source_vertex['node_type'] == ExtendedOperator.GATHER
        and source_vertex.outdegree() == 1
        and target_vertex['node_type'] == ExtendedOperator.GATHER
        and target_vertex.outdegree() >= 1
        and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
        and source_vertex['op'].axis == target_vertex['op'].axis
        and source_vertex['op'].batchDims == target_vertex['op'].batchDims
    )


def is_reshape_branch_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        source_vertex['node_type'] == ExtendedOperator.RESHAPE
        and source_vertex.outdegree() > 1
        and target_vertex['node_type'] == ExtendedOperator.RESHAPE
        and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
    )


def is_transpose_branch_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        source_vertex['node_type'] == ExtendedOperator.TRANSPOSE
        and source_vertex.outdegree() > 1
        and target_vertex['node_type'] == ExtendedOperator.TRANSPOSE
        and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
    )


def is_dequant_quant_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph, q_first: bool):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]

    if q_first:
        cond = (
            source_vertex['node_type'] == ExtendedOperator.QUANTIZE
            and target_vertex['node_type'] == ExtendedOperator.DEQUANTIZE
        )

    else:
        cond = (
            source_vertex['node_type'] == ExtendedOperator.DEQUANTIZE
            and target_vertex['node_type'] == ExtendedOperator.QUANTIZE
        )

    return (
        cond
        and source_vertex.outdegree() == 1
        and target_vertex.outdegree() >= 1
        and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
    )


def is_reshape_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        source_vertex['node_type'] == ExtendedOperator.RESHAPE
        and source_vertex.outdegree() == 1
        and target_vertex['node_type'] == ExtendedOperator.RESHAPE
        and target_vertex.outdegree() >= 1
        and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
    )


def is_constant_transpose_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        source_vertex['node_type'] == ExtendedOperator.CONSTANT_NODE
        and target_vertex['node_type'] == ExtendedOperator.TRANSPOSE
        and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
        and target_vertex.outdegree() >= 1
    )


def is_constant_reshape_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        source_vertex['node_type'] == ExtendedOperator.CONSTANT_NODE
        and target_vertex['node_type'] == ExtendedOperator.RESHAPE
        and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
        and target_vertex.outdegree() >= 1
    )


def is_quant_dequant_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        source_vertex['node_type'] == ExtendedOperator.INPUT_NODE
        and target_vertex['node_type'] == ExtendedOperator.QUANTIZE
        and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
    ) or (
        source_vertex['node_type'] == ExtendedOperator.DEQUANTIZE
        and target_vertex['node_type'] == ExtendedOperator.OUTPUT_NODE
    )


def is_transpose_same_to_reshape_op(op: tfl.BaseOperator):
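    """Returns True when the TRANSPOSE only moves size-1 axes around, i.e. the
    flattened data order is unchanged and the op can be replaced by a RESHAPE."""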
    num_elements = np.prod(op.inputs[0].shape)

    input_shape = np.array(op.inputs[0].shape, dtype='int32')
    output_shape = np.array(op.outputs[0].shape, dtype='int32')

    if np.array_equal(input_shape[input_shape != 1], output_shape[output_shape != 1]):
        input_tensor = np.arange(num_elements).reshape(input_shape)
        perm = op.inputs[1].tensor
        new_tensor = np.transpose(input_tensor, perm)

        return np.array_equal(new_tensor.flatten(), input_tensor.flatten())
    else:
        return False


def is_conv2d_gather_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]

    return (
        source_vertex['node_type'] == ExtendedOperator.CONV_2D
        and target_vertex['node_type'] == ExtendedOperator.GATHER
        and source_vertex.outdegree() == 1
        and target_vertex['op'].inputs[1].buffer is not None
        and target_vertex['op'].axis == 3
        and source_vertex['op'].inputs[1].tensor.shape[0] == target_vertex['op'].inputs[1].tensor.shape[0]
    )


def is_gather_conv2d_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]

    return (
        source_vertex['node_type'] == ExtendedOperator.GATHER
        and target_vertex['node_type'] == ExtendedOperator.CONV_2D
        and source_vertex.outdegree() == 1
        and source_vertex['op'].inputs[1].buffer is not None
        and source_vertex['op'].axis == 3
        and source_vertex['op'].inputs[1].tensor.shape[0] == target_vertex['op'].inputs[1].tensor.shape[3]
    )


def is_reciprocal_sqrt_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]

    return (
        source_vertex['node_type'] == ExtendedOperator.SQRT
        and target_vertex['node_type'] == ExtendedOperator.DIV
        and source_vertex.outdegree() == 1
    )


def is_tile_binary_op_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]

    return (
        source_vertex['node_type'] == ExtendedOperator.TILE
        and target_vertex['node_type']
        in (
            ExtendedOperator.ADD,
            ExtendedOperator.SUB,
            ExtendedOperator.MUL,
            ExtendedOperator.DIV,
        )
        and source_vertex.outdegree() == 1
    )


def op_input_dims(op: tfl.BaseOperator):
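    """Returns the single dimension index the op operates along (e.g. the
    concat/pack/gather axis, the padded, sliced, tiled or reduced axis), or None
    when it cannot be determined or more than one axis is affected."""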
    dim_indices = None

    if isinstance(op, (tfl.ConcatenationOperator, tfl.GatherOperator, tfl.PackOperator, tfl.UnpackOperator)):
        dim_indices = op.axis
    elif isinstance(op, tfl.SplitOperator):
        dim_indices = op.inputs[0].tensor.item()
    elif isinstance(op, tfl.SplitVOperator):
        dim_indices = op.inputs[2].tensor.item()
    elif isinstance(op, (tfl.PadOperator, tfl.Padv2Operator, tfl.MirrorPadOperator)):
        pads = np.sum(op.inputs[1].tensor, axis=-1)
        nonzero_idx = np.nonzero(pads)[0]
        # TODO: support multi indices
        if nonzero_idx.size == 1:
            dim_indices = nonzero_idx[0]
    elif isinstance(op, tfl.PreluOperator):
        w_shape = np.array(op.inputs[1].shape, dtype='int32')
        nonzero_idx = np.nonzero(w_shape != 1)[0]
        if nonzero_idx.size == 1:
            dim_indices = nonzero_idx[0] + 1
    elif isinstance(op, (tfl.SliceOperator, tfl.StridedSliceOperator, tfl.TileOperator)):
        old_shape = np.array(op.inputs[0].shape)
        new_shape = np.array(op.outputs[0].shape)
        diff = new_shape - old_shape
        nonzero_idx = np.nonzero(diff)[0]
        # TODO: support multi indices
        if nonzero_idx.size == 1:
            dim_indices = nonzero_idx[0]
    elif isinstance(
        op,
        (
            tfl.SumOperator,
            tfl.MeanOperator,
            tfl.ArgMinOperator,
            tfl.ArgMaxOperator,
            tfl.ReduceMinOperator,
            tfl.ReduceMaxOperator,
            tfl.ReduceProdOperator,
        ),
    ):
        # TODO: support multi indices
        if op.inputs[1].tensor.size == 1:
            dim_indices = op.inputs[1].tensor[0]
    return dim_indices


def op_input_indices(op: tfl.BaseOperator):
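    """Returns the indices of the op's data inputs: all inputs for
    CONCATENATION/PACK, input 1 for SPLIT (input 0 is the axis), both operands for
    BATCH_MATMUL/MINIMUM/MAXIMUM and for binary arithmetic ops unless one operand
    is a length-1 vector, and input 0 for everything else."""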
    if isinstance(op, (tfl.ConcatenationOperator, tfl.PackOperator)):
        input_indices = range(len(op.inputs))
    elif isinstance(op, tfl.SplitOperator):
        input_indices = (1,)
    elif isinstance(op, (tfl.BatchMatmulOperator, tfl.MinimumOperator, tfl.MaximumOperator)):
        input_indices = range(2)
    elif isinstance(
        op, (tfl.AddOperator, tfl.SubOperator, tfl.MulOperator, tfl.DivOperator, tfl.SquaredDifferenceOperator)
    ):
        if len(op.inputs[1].shape) == 1 and op.inputs[1].shape[0] == 1:
            input_indices = range(1)
        elif len(op.inputs[0].shape) == 1 and op.inputs[0].shape[0] == 1:
            input_indices = (1,)
        else:
            input_indices = range(2)
    else:
        input_indices = range(1)

    return input_indices


def fuse_bn_weight(eps, scale, var, weight, transpose):
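    """Folds the batch-norm scale into the conv/FC weight:
    w' = w * scale / sqrt(var + eps), broadcast over axis 1 of the weight when
    `transpose` is set (transposed convolutions) and over axis 0 otherwise."""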
    if transpose:
        shape = [1, -1] + [1] * (len(weight.shape) - 2)
    else:
        shape = [-1, 1] + [1] * (len(weight.shape) - 2)

    inv = 1 / np.sqrt(var + eps)

    return weight * (scale * inv).reshape(shape)


def fuse_bn_bias(eps, scale, var, mean, bn_b, activ_b):
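    """Folds the batch-norm statistics into the bias:
    b' = (b - mean) * scale / sqrt(var + eps) + beta, where b is the existing
    bias (`activ_b`, treated as 0 when absent) and beta is `bn_b`."""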
    inv = 1 / np.sqrt(var + eps)
    if activ_b is not None:
        if activ_b.shape != mean.shape and activ_b.ndim == 1 and activ_b.size == 1:
            activ_b = activ_b.repeat(mean.size)
        return (activ_b - mean) * inv * scale + bn_b
    else:
        return (-mean) * inv * scale + bn_b


def fuse_rev_bn_weight(eps, scale, var, weight):
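    """Folds the scale of a preceding batch-norm into the weight of the op that
    follows it: w' = w * scale / sqrt(var + eps), broadcast over axis 1 of the
    weight."""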
    shape = [1, -1] + [1] * (len(weight.shape) - 2)

    inv = 1 / np.sqrt(var + eps)

    return weight * (scale * inv).reshape(shape)


def fuse_rev_bn_bias(eps, scale, var, mean, bn_b, activ_b, weight):
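    """Computes the bias produced when a preceding batch-norm is folded into the
    following op: the BN offset (bn_b - mean * scale / sqrt(var + eps)) is pushed
    through the weight (summed over its trailing axes, element-wise when the
    weight has a single input channel) and added to the existing bias."""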
    reduced_dims = tuple([i for i in range(len(weight.shape)) if i > 1])

    inv = 1 / np.sqrt(var + eps)
    fused_b = bn_b - mean * inv * scale

    if weight.shape[1] == 1 and mean.shape[0] > 1:
        offset_b = (weight.sum(reduced_dims) * fused_b.reshape(-1, 1)).reshape(-1)
    else:
        offset_b = np.matmul(weight.sum(reduced_dims), fused_b.reshape(-1, 1)).reshape(-1)

    if activ_b is not None:
        if activ_b.shape != mean.shape and activ_b.ndim == 1 and activ_b.size == 1:
            activ_b = activ_b.repeat(mean.size)
        return activ_b + offset_b
    else:
        return offset_b


def fuse_slices(seq: typing.Iterable[ig.Vertex]):
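    """Merges a chain of SLICE/STRIDED_SLICE vertices into a single
    (start, end, strides) triple expressed in the coordinates of the first op's
    input tensor."""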
    cur_start = None
    cur_end = None
    cur_strides = None
    for node in seq:
        assert node['node_type'] in (ExtendedOperator.SLICE, ExtendedOperator.STRIDED_SLICE)
        next_start = node['op'].inputs[1].tensor
        if cur_strides is None:
            cur_strides = np.ones_like(next_start, dtype='int32')
        if cur_start is None:
            cur_start = np.zeros_like(next_start, dtype='int32')
        if node['node_type'] == ExtendedOperator.SLICE:
            next_size = node['op'].inputs[2].tensor
            next_end = cur_start + (next_start + next_size) * cur_strides
            next_strides = np.ones_like(next_start, dtype='int32')
        else:
            next_end = node['op'].inputs[2].tensor
            next_end = cur_start + next_end * cur_strides
            next_strides = node['op'].inputs[3].tensor
        if cur_end is None:
            cur_start = next_start
            cur_end = next_end
            cur_strides = next_strides
        else:
            cur_start += next_start * cur_strides
            cur_end = np.min((cur_end, next_end), axis=0)
            cur_strides = cur_strides * next_strides
    return cur_start, cur_end, cur_strides


def fuse_transpose_perms(seq: typing.Iterable[ig.Vertex]):
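    """Composes the permutations of a chain of TRANSPOSE (or GATHER) vertices into
    a single permutation by repeatedly indexing the accumulated permutation with
    the next one."""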
    cur_perm = None
    for node in seq:
        assert node['node_type'] in (ExtendedOperator.TRANSPOSE, ExtendedOperator.GATHER)
        next_perm = node['op'].inputs[1].tensor
        if cur_perm is None:
            cur_perm = next_perm
        else:
            cur_perm = cur_perm[next_perm]
    return cur_perm


def fuse_transpose_perms_extended(seq: typing.Iterable[ig.Vertex]):
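    """Composes the permutations of a transpose-reshape-transpose sequence into a
    single permutation, mapping axes across the RESHAPE by matching dimension
    sizes between its input and output shapes."""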
    cur_perm = None
    # Reverse the sequence if dim is expanding
    if seq[1]['node_type'] == ExtendedOperator.RESHAPE:
        if len(seq[1]['op'].inputs[0].shape) < len(seq[1]['op'].outputs[0].shape):
            seq = list(reversed(list(seq)))
    for node in seq:
        if node['node_type'] == ExtendedOperator.TRANSPOSE:
            next_perm = node['op'].inputs[1].tensor
            if cur_perm is None:
                cur_perm = next_perm
            else:
                cur_perm = cur_perm[next_perm]
        elif node['node_type'] == ExtendedOperator.RESHAPE:
            if len(seq[1]['op'].inputs[0].shape) > len(seq[1]['op'].outputs[0].shape):
                old_shape = node['op'].inputs[0].shape
                new_shape = node['op'].outputs[0].shape
            else:
                new_shape = node['op'].inputs[0].shape
                old_shape = node['op'].outputs[0].shape

            if old_shape != new_shape:
                if len(old_shape) != len(new_shape):
                    new_shape_padded = list(new_shape) + [None] * (len(old_shape) - len(new_shape))
                    next_perm = []
                    new_idx = 0
                    while new_idx < len(new_shape):
                        for old, item in zip(old_shape, cur_perm):
                            if old == new_shape_padded[new_idx] and item not in next_perm:
                                next_perm.append(item)
                                new_idx += 1
                    cur_perm = np.argsort(next_perm)
                else:
                    mapping = {}
                    for i in range(len(new_shape)):
                        mapping.setdefault(new_shape[i], [])
                        mapping[new_shape[i]].append(i)
                    next_perm = [0] * len(old_shape)
                    for i in range(len(old_shape)):
                        next_perm[i] = mapping[old_shape[i]].pop(0)
                    cur_perm = cur_perm[next_perm]

    return cur_perm


def fuse_connected_edges(
    filtered_pairs: typing.List[typing.Iterable[ig.Vertex]],
) -> typing.List[typing.Iterable[ig.Vertex]]:
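    """Repeatedly joins sequences whose last vertex is the first vertex of another
    sequence, so that the returned list only contains maximal chains."""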
    while True:
        heads = {n[0]: i for i, n in enumerate(filtered_pairs)}
        tails = {n[-1]: i for i, n in enumerate(filtered_pairs)}
        connectables = heads.keys() & tails.keys()
        if len(connectables) > 0:
            curr_filtered = []
            for seq in filtered_pairs:
                head_connectable = seq[0] in connectables
                preserve = head_connectable and filtered_pairs[tails[seq[0]]][0] in connectables
                if preserve:
                    curr_filtered.append(seq)
                elif not head_connectable:
                    if seq[-1] in connectables:
                        curr_filtered.append(seq + filtered_pairs[heads[seq[-1]]][1:])
                    else:
                        curr_filtered.append(seq)
            filtered_pairs = curr_filtered
        else:
            break

    return filtered_pairs


def is_simple_reshape(orig_shape, new_shape, mapping: typing.Optional[typing.Dict[int, int]] = None):
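    """Returns True when `new_shape` differs from `orig_shape` only by inserting
    or removing size-1 dimensions, e.g. (2, 1, 3) -> (2, 3, 1). When `mapping` is
    given, it is filled with the index correspondence of the matching dimensions."""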
    if orig_shape == new_shape:
        if mapping is not None:
            for i in range(len(orig_shape)):
                mapping[i] = i
        return True

    i = 0
    j = 0

    while True:
        if i == len(orig_shape) and j == len(new_shape):
            break
        elif i == len(orig_shape):
            if new_shape[j] == 1:
                j += 1
            else:
                break
        elif j == len(new_shape):
            if orig_shape[i] == 1:
                i += 1
            else:
                break
        elif orig_shape[i] == new_shape[j]:
            if mapping is not None:
                mapping[i] = j
            i += 1
            j += 1
        elif orig_shape[i] == 1:
            i += 1
        elif new_shape[j] == 1:
            j += 1
        else:
            break

    if i != len(orig_shape) or j != len(new_shape):
        return False
    else:
        return True


def reshape_mapping(shape_1, shape_2):
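    """Computes how the axes of `shape_1` map onto the axes of `shape_2` for a
    reshape: returns grouped index lists whose dimension products match, plus the
    same groups with size-1 axes dropped,
    e.g. reshape_mapping((2, 3, 4), (6, 4))[:2] == ([[0, 1], [2]], [[0], [1]])."""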
    i = 0
    j = 0
    acc_l = 1
    start_l = 0
    acc_r = 1
    start_r = 0
    mapping_l = []
    mapping_r = []
    sign = None
    while i < len(shape_1) or j < len(shape_2):
        if i < len(shape_1) and j < len(shape_2):
            if start_l == i and start_r == j and shape_1[i] == shape_2[j]:
                mapping_l.append([i])
                mapping_r.append([j])
                acc_l = 1
                acc_r = 1
                i += 1
                j += 1
                start_l = i
                start_r = j
                sign = None
            else:
                if sign in ('l', None):
                    acc_l = shape_1[i] * acc_l
                if sign in ('r', None):
                    acc_r = shape_2[j] * acc_r
                if acc_l == acc_r:
                    mapping_l.append(list(range(start_l, i + 1)))
                    mapping_r.append(list(range(start_r, j + 1)))
                    acc_l = 1
                    acc_r = 1
                    i += 1
                    j += 1
                    start_l = i
                    start_r = j
                    sign = None
                elif acc_l < acc_r:
                    sign = 'l'
                    i += 1
                else:
                    sign = 'r'
                    j += 1
        elif i < len(shape_1):
            assert shape_1[i] == 1
            mapping_l[-1].append(i)
            i += 1
        else:
            assert shape_2[j] == 1
            mapping_r[-1].append(j)
            j += 1
    non_one_mapping_l = []
    non_one_mapping_r = []
    for ml, mr in zip(mapping_l, mapping_r):
        new_ml = [i for i in ml if shape_1[i] != 1]
        new_mr = [j for j in mr if shape_2[j] != 1]
        if len(new_ml) > 0 and len(new_mr) > 0:
            non_one_mapping_l.append(new_ml)
            non_one_mapping_r.append(new_mr)
    return mapping_l, mapping_r, non_one_mapping_l, non_one_mapping_r


def elinimate_sequences(
    graph_converter: CommonGraph,
    filtered_pairs: typing.List[typing.Iterable[ig.Vertex]],
    remove_first_pred: typing.Union[bool, typing.Callable] = False,
    remove_first_node_action: typing.Optional[typing.Callable] = None,
    remove_last_pred: typing.Union[bool, typing.Callable] = True,
    remove_last_node_action: typing.Optional[typing.Callable] = None,
    skip_pred: typing.Union[bool, typing.Callable] = False,
    input_idx: int = 0,
    force_forward_input: bool = False,
):
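    """Removes the given sequences of vertices from the graph and reconnects the
    surrounding nodes/tensors. `remove_first_pred` / `remove_last_pred` (bools or
    callables) decide whether the first and last node of each sequence are removed
    too; when a node is kept, the corresponding `*_node_action` callback may
    return extra rewrite actions to be applied afterwards."""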
    remove_ids = []
    actions = []
    for seq in filtered_pairs:
        first_node = seq[0]
        last_node = seq[-1]

        if type(skip_pred) is bool:
            skip = skip_pred
        elif skip_pred is not None:
            skip = skip_pred(seq)

        if skip:
            continue

        if type(remove_first_pred) is bool:
            remove_first = remove_first_pred
            custom_data = None
        elif remove_first_pred is not None:
            remove_first, custom_data = remove_first_pred(seq)

        if type(remove_last_pred) is bool:
            remove_last = remove_last_pred
            custom_data_last = None
        elif remove_last_pred is not None:
            remove_last, custom_data_last = remove_last_pred(seq)

        # If the first node can also be eliminated, then set the previous node as the first node
        if remove_first:
            first_node = graph_converter.graph.vs.find(
                name=graph_converter.tensor_node_map[first_node['op'].inputs[input_idx].name]
            )

        if not remove_last:
            last_node = seq[-2]

        if first_node == seq[0]:
            next_idx = 1
        else:
            next_idx = 0
        output_name = seq[next_idx]['op'].inputs[input_idx].name
        output_idx = first_node['outputs'].index(output_name)

        # We use the forward input tensor under the following circumstances:
        # 1. the previous node before the sequence is an input node
        # 2. the first node has multiple outputs and the last node doesn't connect to any output node
        use_forward_input = False
        if first_node['node_type'] == ExtendedOperator.INPUT_NODE:
            use_forward_input = True

        branch = first_node.outdegree() > 1

        has_output_nodes = False
        for edge in last_node.out_edges():
            target_vertex = edge.target_vertex
            if target_vertex['node_type'] in (ExtendedOperator.OUTPUT_NODE, ExtendedOperator.UNUSED_NODE):
                if use_forward_input:
                    # Cannot optimize away ops between i/o nodes
                    skip = True
                else:
                    has_output_nodes = True
                break

        if branch:
            output_outdegree = 0
            for edge in first_node.out_edges():
                target_vertex = edge.target_vertex
                if target_vertex == seq[next_idx]:
                    continue
                if target_vertex['node_type'] in (ExtendedOperator.OUTPUT_NODE, ExtendedOperator.UNUSED_NODE):
                    if has_output_nodes and edge['label'] == output_name:
                        output_outdegree += 1
                        break
                else:
                    names = [t.name for t in target_vertex['op'].inputs]
                    if output_name in names:
                        output_outdegree += 1
                        break

            if not has_output_nodes:
                use_forward_input = True
            elif output_outdegree > 0:
                skip = True

        if force_forward_input and not use_forward_input:
            if not has_output_nodes:
                use_forward_input = True
            else:
                skip = True

        if skip:
            continue

        if use_forward_input:
            # Find out the output of the first node in the sequence
            new_output = first_node['outputs'][output_idx]
            assert new_output in graph_converter.tensor_map

            # For each successor of the last node, connect it to the first node,
            # replacing its tensors when needed
            graph_converter.replace_next_tensors(last_node, first_node, new_output)
        else:
            # Find out the output of the last node in the sequence
            new_output = last_node['outputs'][0]
            assert new_output in graph_converter.tensor_map

            # For each successor of the last node, connect it to the first node
            graph_converter.connect_next_tensors(last_node, first_node, new_output)

            # Update graph, prepare to drop the output tensor of the intermediate nodes and use the output tensor of
            # the last node instead
            first_node['outputs'][output_idx] = new_output
            if first_node['op'] is not None:
                first_node['op'].outputs[output_idx] = graph_converter.tensor_map[new_output]
            graph_converter.tensor_node_map[new_output] = first_node['name']

        # When the first node is a constant node, we need to set the buffer back
        if first_node['node_type'] == ExtendedOperator.CONSTANT_NODE and not use_forward_input:
            if seq[0]['node_type'] == ExtendedOperator.CONSTANT_NODE:
                old_tensor = graph_converter.tensor_map[seq[0]['name']]
            else:
                old_tensor = seq[0]['op'].inputs[input_idx]
            new_tensor = seq[-1]['op'].outputs[0]
            new_tensor.buffer = old_tensor.buffer

        if remove_first and remove_last:
            # Push the sequence to the removing list
            remove_ids.extend([x.index for x in seq])
        else:
            # Collect actions for the first and/or last nodes that are kept
            start_index = 0
            end_index = len(seq)
            if not remove_first:
                start_index = 1
                if remove_first_node_action is not None:
                    action = remove_first_node_action(first_node, last_node, custom_data)
                    if action is not None:
                        actions.extend(action)

            if not remove_last:
                end_index = len(seq) - 1
                if remove_last_node_action is not None:
                    action = remove_last_node_action(first_node, last_node, custom_data_last)
                    if action is not None:
                        actions.extend(action)

            # Push the rest of the sequence (excluding any preserved first/last node) to the removing list
            remove_ids.extend([x.index for x in seq[start_index:end_index]])

    for func, args in actions:
        func(*args)

    graph_converter.graph.delete_vertices(remove_ids)


def expand_op_outputs_in_branches(
    nodes: typing.List[ig.Vertex],
    new_op_func: typing.Callable[[ig.Vertex, ig.Vertex, ig.Vertex], None],
    graph_converter: CommonGraph,
):
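    """For every node whose output feeds multiple consumers, keeps one consumer
    (preferring output nodes) on the original op and lets `new_op_func` emit
    actions that insert a cloned op for each remaining consumer branch."""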
    actions = []
    for node in nodes:
        preserve_node = None
        prev_node_name = node['op'].inputs[0].name
        prev_node = graph_converter.graph.vs.find(name=graph_converter.tensor_node_map[prev_node_name])

        # Collect next nodes and choose one to preserve
        next_nodes = []
        for edge in node.out_edges():
            next_node = graph_converter.graph.vs[edge.target]
            if preserve_node is None or next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                preserve_node = next_node
            next_nodes.append(next_node)

        # For the filtered nodes, use the cloned op as the previous op
        filtered_nodes = list(set(next_nodes) - set([preserve_node]))
        for next_node in filtered_nodes:
            actions.extend(new_op_func(node, prev_node, next_node))

    # Process actions
    for func, args in actions:
        func(*args)


def get_same_padding_args(input_shape, filter_shape, strides, dilation):
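    """Computes explicit [begin, end] padding pairs that reproduce 'SAME' padding
    for the given spatial input/filter shapes, strides and dilation; the batch and
    channel axes get [0, 0], e.g.
    get_same_padding_args([5, 5], [3, 3], [1, 1], [1, 1]) == [[0, 0], [1, 1], [1, 1], [0, 0]]."""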
    dim = len(input_shape)
    padding = [0] * dim

    for i in range(dim):
        if input_shape[i] % strides[i] == 0:
            padding[i] = max(1 - strides[i] + (filter_shape[i] - 1) * dilation[i], 0)
        else:
            padding[i] = max(1 + (filter_shape[i] - 1) * dilation[i] - (input_shape[i] % strides[i]), 0)

    pad_args = [[0, 0]] + [[x // 2, x - x // 2] for x in padding] + [[0, 0]]

    return pad_args
