def elementwise_op_transpose_passthrough_pass()

in tinynn/converter/operators/optimize.py


    def elementwise_op_transpose_passthrough_pass(self, quantizable_ops_only: bool = False) -> int:
        edges = self.graph.graph.es.select(
            functools.partial(
                is_transpose_elementwise_op_edge,
                graph_converter=self.graph.graph,
                quantizable_ops_only=quantizable_ops_only,
            )
        )

        pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
        if quantizable_ops_only:
            all_edges = self.graph.graph.es.select(
                functools.partial(
                    is_transpose_elementwise_op_edge,
                    graph_converter=self.graph.graph,
                    quantizable_ops_only=False,
                )
            )

            all_pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in all_edges)

            forward_d = dict(all_pairs)
            backward_d = {v: k for k, v in forward_d.items()}

            filtered_nodes = []
            for s, e in pairs:
                if s['node_type'] == ExtendedOperator.TRANSPOSE:
                    pn = backward_d.get(s, None)
                    if pn is not None:
                        filtered_nodes.append(pn)
                    else:
                        log.warning('Cannot passthrough transpose upward around requantizable ops')
                else:
                    pn = forward_d.get(e, None)
                    if pn is not None:
                        filtered_nodes.append(pn)
                    else:
                        log.warning('Cannot passthrough transpose downward around requantizable ops')
        else:
            filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.TRANSPOSE else k[1] for k in pairs)
        unique_nodes = list(set(filtered_nodes))

        actions = []
        remove_edges = []
        remove_vertices = []
        num_actions = 0
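        # For each candidate elementwise op, hoist the most common permutation across it: inverse ('up')
        # transposes are inserted on its inputs and matching ('down') transposes on its outputs, and the
        # surrounding edges are rewired accordingly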
        for node in unique_nodes:
            op = node['op']
            input_indices = op_input_indices(op)

            prev_nodes = []
            cand_perms = dict()
            prev_output_indices = []
            num_constant_nodes = 0
            num_reshape_transpose = 0
            prev_hints = set()
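            # Scan the producers of every input: collect candidate permutations from upstream transposes
            # (adjusted for PACK), count constant inputs, and note reshape+transpose chains and their
            # direction hints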
            for i in input_indices:
                prev_node_name = op.inputs[i].name
                prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
                prev_nodes.append(prev_node)
                prev_output_indices.append(prev_node['outputs'].index(prev_node_name))

                if prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
                    perm = tuple(prev_node['op'].inputs[1].tensor.tolist())

                    if node['node_type'] == ExtendedOperator.PACK:
                        perm = [i if i < op.axis else i + 1 for i in perm]
                        perm.insert(op.axis, op.axis)
                        perm = tuple(perm)

                    cand_perms.setdefault(perm, 0)
                    cand_perms[perm] += 1
                    if 'direction' in prev_node['op'].extra_hints:
                        prev_hints.add(prev_node['op'].extra_hints['direction'])

                if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
                    num_constant_nodes += 1

                if prev_node['node_type'] == ExtendedOperator.RESHAPE:
                    prev_prev_node_name = self.graph.tensor_node_map[prev_node['op'].inputs[0].name]
                    prev_prev_node = self.graph.graph.vs.find(name=prev_prev_node_name)
                    if prev_prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
                        num_reshape_transpose += 1
                        if 'direction' in prev_prev_node['op'].extra_hints:
                            prev_hints.add(prev_prev_node['op'].extra_hints['direction'])

            if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints:
                continue

            next_nodes = []
            next_edges = []
            out_nodes = []
            skip_names = []
            next_hints = set()
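            # Scan the consumers: split off graph outputs and unused tensors, and collect candidate
            # permutations (inverted via argsort, adjusted for UNPACK) and direction hints from
            # downstream transposes and reshape+transpose chains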
            for edge in node.out_edges():
                if edge.index in remove_edges:
                    continue
                next_node = self.graph.graph.vs[edge.target]

                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    out_nodes.append(next_node)
                elif next_node['node_type'] == ExtendedOperator.UNUSED_NODE:
                    skip_names.append(edge['label'])
                else:
                    next_nodes.append(next_node)
                    next_edges.append(edge)

                if next_node['node_type'] == ExtendedOperator.TRANSPOSE:
                    perm = tuple(np.argsort(next_node['op'].inputs[1].tensor).tolist())

                    if node['node_type'] == ExtendedOperator.UNPACK:
                        perm = [i if i < op.axis else i + 1 for i in perm]
                        perm.insert(op.axis, op.axis)
                        perm = tuple(perm)

                    cand_perms.setdefault(perm, 0)
                    cand_perms[perm] += 1
                    if 'direction' in next_node['op'].extra_hints:
                        next_hints.add(next_node['op'].extra_hints['direction'])

                if next_node['node_type'] == ExtendedOperator.RESHAPE:
                    o_nodes = [e.target_vertex for e in next_node.out_edges()]
                    if len(o_nodes) == 1 and o_nodes[0]['node_type'] == ExtendedOperator.TRANSPOSE:
                        num_reshape_transpose += 1
                        if 'direction' in o_nodes[0]['op'].extra_hints:
                            next_hints.add(o_nodes[0]['op'].extra_hints['direction'])

            if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints:
                continue

            cur_transpose_size = sum(cand_perms.values()) + num_reshape_transpose
            new_transpose_size = (
                len(prev_nodes) + len(next_nodes) - num_constant_nodes - cur_transpose_size + num_reshape_transpose
            )

            # Skip the rewrite unless it pays off:
            # a. if the transpose count would grow, skip outright, unless
            #    `bypass_elementwise_passthrough_constraint` is set (then hints may still allow it)
            # b. if the transpose count stays the same, only proceed when the optimize level is at least
            #    BRANCH_OPTIMIZE_EXTENDED and a 'down' hint upstream or an 'up' hint downstream is present
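            # e.g. with two transposed inputs, one constant input and one plain consumer (and no
            # reshape+transpose chains): cur_transpose_size = 2, new_transpose_size = 3 + 1 - 1 - 2 + 0 = 1,
            # so the rewrite removes one transpose and is applied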
            is_increasing = new_transpose_size > cur_transpose_size
            is_not_decreasing = new_transpose_size >= cur_transpose_size
            is_same = new_transpose_size == cur_transpose_size
            if len(next_nodes) == 0:
                continue
            else:
                if self.bypass_elementwise_passthrough_constraint:
                    condition = is_not_decreasing
                else:
                    if is_increasing:
                        continue
                    condition = is_same

                if condition:
                    skip = True
                    if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED:
                        if 'down' in prev_hints or 'up' in next_hints:
                            skip = False
                    if skip:
                        continue

            num_actions += 1

            remove_edges.extend([x.index for x in next_edges])
            remove_vertices.extend([x.index for x in out_nodes])

            for n in out_nodes:
                del self.graph.tensor_map[n['outputs'][0]]
                del self.graph.tensor_node_map[n['outputs'][0]]

            perm = max(cand_perms.items(), key=lambda x: x[1])[0]
            perm_arr = np.array(perm, dtype='int32')
            inv_perm_arr = np.argsort(perm_arr).astype('int32')
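            # np.argsort of a permutation yields its inverse, i.e. perm_arr[inv_perm_arr] == [0, 1, ...].
            # PACK adds a dimension at op.axis and UNPACK removes one, so the permutation used on the
            # other side of the op has to be adjusted for that axis below.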

            if node['node_type'] == ExtendedOperator.UNPACK:
                inv_perm_arr_post = inv_perm_arr[inv_perm_arr != op.axis]
                inv_perm_arr_post[inv_perm_arr_post > op.axis] -= 1

                perm_arr_post = np.argsort(inv_perm_arr_post).astype('int32')
            elif node['node_type'] == ExtendedOperator.PACK:
                perm_arr_post = perm_arr
                inv_perm_arr_post = inv_perm_arr

                perm_arr = perm_arr_post[perm_arr_post != op.axis]
                perm_arr[perm_arr > op.axis] -= 1

                inv_perm_arr = np.argsort(perm_arr).astype('int32')
            else:
                perm_arr_post = perm_arr
                inv_perm_arr_post = inv_perm_arr

            tensor_node_dict = {}
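            # Prepend a shared 'up' transpose (applying inv_perm_arr) to every distinct input; inputs with
            # lower rank are reshaped to the full rank first, and repeated inputs reuse the already
            # transposed tensor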
            for prev_node, prev_idx, next_idx in zip(prev_nodes, input_indices, prev_output_indices):
                if prev_node['op'] is None:
                    prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
                else:
                    prev_out = prev_node['op'].outputs[next_idx]
                if prev_out.name in tensor_node_dict:
                    prev_new_out, skip = tensor_node_dict[prev_out.name]
                    actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True, skip)))
                    skip += 1
                    tensor_node_dict[prev_out.name] = (prev_new_out, skip)
                else:
                    perm_tensor = self.create_attr_tensor(inv_perm_arr)
                    if len(prev_out.shape) != perm_tensor.tensor.size:
                        new_shape = [1] * (perm_tensor.tensor.size - len(prev_out.shape)) + list(prev_out.shape)
                        prev_out_reshaped = self.create_transform_tensor(
                            np.reshape(prev_out.tensor, new_shape), quantization=prev_out.quantization
                        )
                        new_shape_tensor = self.create_attr_tensor(np.array(new_shape, dtype='int32'))
                        self.graph.add_operator(
                            tfl.ReshapeOperator([prev_out, new_shape_tensor], [prev_out_reshaped], new_shape)
                        )
                        prev_out = prev_out_reshaped
                    prev_new_out = self.create_transform_tensor(
                        np.transpose(prev_out.tensor, inv_perm_arr), quantization=prev_out.quantization
                    )
                    tensor_node_dict[prev_out.name] = (prev_new_out, 1)
                    transpose_op = tfl.TransposeOperator([prev_out, perm_tensor], [prev_new_out])
                    transpose_op.extra_hints['direction'] = 'up'
                    self.graph.add_operator(transpose_op)
                    actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))

            tensor_node_dict = {}
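            # Re-route each used output through a 'down' transpose: the op now writes a transposed internal
            # tensor, and the original tensor (with its original layout) is restored right behind it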
            for i, op_out in enumerate(op.outputs):
                # For unused tensors, we perform inplace shape updates
                if op_out.name in skip_names:
                    orig_shape = np.array(op_out.shape, dtype='int32')
                    new_shape = orig_shape[inv_perm_arr]
                    op_out.shape = tuple(new_shape.tolist())
                    continue

                perm_tensor = self.create_attr_tensor(perm_arr_post)
                new_out = self.create_transform_tensor(
                    np.transpose(op_out.tensor, inv_perm_arr_post), quantization=op_out.quantization
                )

                # Update relations
                if op_out.name in self.graph.tensor_node_map:
                    del self.graph.tensor_node_map[op_out.name]
                self.graph.tensor_node_map[new_out.name] = node['name']
                self.graph.tensor_map[new_out.name] = new_out
                node['outputs'][i] = new_out.name
                op.outputs[i] = new_out

                transpose_op = tfl.TransposeOperator([new_out, perm_tensor], [op_out])
                transpose_op.extra_hints['direction'] = 'down'
                self.graph.add_operator(transpose_op)

                tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])

            # Op-specific dim handling: remap axis/dim attributes and dimension-indexed inputs through the
            # permutation (e.g. with inv_perm_arr == [0, 3, 1, 2], an original axis 1 is remapped to 2)
            if node['node_type'] in (ExtendedOperator.CONCATENATION, ExtendedOperator.GATHER, ExtendedOperator.UNPACK):
                old_axis = op.axis
                new_axis = np.where(inv_perm_arr == old_axis)[0][0]
                op.axis = new_axis
            elif node['node_type'] == ExtendedOperator.PACK:
                old_axis = op.axis
                new_axis = np.where(inv_perm_arr_post == old_axis)[0][0]
                op.axis = new_axis
            elif node['node_type'] == ExtendedOperator.SPLIT_V:
                old_dim = op.inputs[2].tensor
                new_dim = np.where(inv_perm_arr == old_dim)[0][0]
                new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
                actions.append((self.graph.replace_operator_input, (node, 2, new_dim_tensor, True)))
            elif node['node_type'] == ExtendedOperator.SPLIT:
                old_dim = op.inputs[0].tensor
                new_dim = np.where(inv_perm_arr == old_dim)[0][0]
                new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
                actions.append((self.graph.replace_operator_input, (node, 0, new_dim_tensor, True)))
            elif node['node_type'] in (
                ExtendedOperator.PAD,
                ExtendedOperator.PADV2,
                ExtendedOperator.MIRROR_PAD,
                ExtendedOperator.TILE,
            ):
                old_pad = op.inputs[1].tensor
                new_pad = self.create_attr_tensor(old_pad[inv_perm_arr])
                actions.append((self.graph.replace_operator_input, (node, 1, new_pad, True)))
            elif node['node_type'] == ExtendedOperator.PRELU:
                old_weight = op.inputs[1].tensor
                if old_weight.ndim != 1:
                    assert old_weight.ndim + 1 == len(inv_perm_arr)
                    new_perm = np.argsort(np.argsort(inv_perm_arr[1:]))
                    new_perm_t = self.create_attr_tensor(np.array(new_perm, dtype='int32'))
                    new_weight = self.create_transform_tensor(np.transpose(old_weight, new_perm))
                    self.graph.add_operator(tfl.TransposeOperator([op.inputs[1], new_perm_t], [new_weight]))
                    actions.append((self.graph.replace_operator_input, (node, 1, new_weight, True)))
            elif node['node_type'] in (ExtendedOperator.SLICE, ExtendedOperator.STRIDED_SLICE):
                for i, t in enumerate(op.inputs[1:]):
                    if t.buffer is None:
                        new_perm_t = self.create_attr_tensor(np.array(inv_perm_arr, dtype='int32'))
                        new_t = self.create_transform_tensor(t.tensor[inv_perm_arr])
                        self.graph.add_operator(tfl.TransposeOperator([t, new_perm_t], [new_t]))
                    else:
                        new_t = self.create_attr_tensor(t.tensor[inv_perm_arr])
                    actions.append((self.graph.replace_operator_input, (node, i + 1, new_t, True)))
            elif node['node_type'] in (
                ExtendedOperator.SUM,
                ExtendedOperator.ARG_MIN,
                ExtendedOperator.ARG_MAX,
                ExtendedOperator.REDUCE_MIN,
                ExtendedOperator.REDUCE_MAX,
                ExtendedOperator.REDUCE_PROD,
                ExtendedOperator.MEAN,
            ):
                old_axis = op.inputs[1].tensor.tolist()
                new_axis = []
                for t in old_axis:
                    new_t = np.where(inv_perm_arr == t)[0][0]
                    new_axis.append(new_t)
                axis_arr = np.array(new_axis, dtype='int32')
                axis_tensor = self.create_attr_tensor(axis_arr)
                actions.append((self.graph.replace_operator_input, (node, 1, axis_tensor, True)))

            for edge in next_edges:
                source = tensor_node_dict[edge['name']]
                self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])

        # Process actions
        ids = []
        for func, args in actions:
            node = args[0]
            res = func(*args)
            if res is not None:
                ids.extend(res)

        remove_edges = list(set(remove_edges + ids))

        self.graph.graph.delete_edges(remove_edges)
        self.graph.graph.delete_vertices(remove_vertices)

        return num_actions
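
For reference, the permutation bookkeeping above relies on np.argsort turning a permutation into its
inverse, and on the fact that elementwise ops commute with transposes. A minimal NumPy sketch (independent
of the converter classes; the ReLU stand-in and the NCHW -> NHWC permutation are just illustrative choices):

    import numpy as np

    perm = np.array([0, 2, 3, 1], dtype='int32')    # e.g. NCHW -> NHWC
    inv_perm = np.argsort(perm).astype('int32')     # inverse permutation: perm[inv_perm] == [0, 1, 2, 3]

    x = np.random.rand(1, 4, 5, 6).astype('float32')

    # Elementwise ops commute with transposes ...
    lhs = np.transpose(np.maximum(x, 0), perm)
    rhs = np.maximum(np.transpose(x, perm), 0)
    assert np.array_equal(lhs, rhs)

    # ... and a permutation followed by its inverse is the identity, which is why a transpose can be
    # hoisted across an elementwise op and its inverse re-inserted on the other side without changing
    # the computed values.
    assert np.array_equal(np.transpose(np.transpose(x, perm), inv_perm), x)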