def elementwise_reshape_transpose_passthrough_pass()

in tinynn/converter/operators/optimize.py [0:0]


    def elementwise_reshape_transpose_passthrough_pass(self) -> int:
        """Push transpose ops through adjacent reshape ops so transposes can merge/cancel.

        Scans the graph for (transpose, reshape) op pairs (via
        ``is_transpose_reshape_op_edge``) and, for each candidate reshape node,
        tries to move the surrounding transposes to the other side of the
        reshape: a new transpose is inserted before every input (hinted
        ``direction='up'``) and after every output (hinted ``direction='down'``),
        and the reshape's target shape is permuted accordingly. The rewrite is
        only performed when it does not increase the total number of transpose
        ops (or, at BRANCH_OPTIMIZE_EXTENDED level, when direction hints suggest
        further cancellation downstream/upstream).

        Returns:
            int: the number of reshape nodes actually rewritten.
        """
        # Collect graph edges that connect a TRANSPOSE node and a reshape-like node.
        edges = self.graph.graph.es.select(
            functools.partial(is_transpose_reshape_op_edge, graph_converter=self.graph.graph)
        )
        pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
        # Keep the non-transpose endpoint of each pair (i.e. the reshape op node).
        filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.TRANSPOSE else k[1] for k in pairs)
        unique_nodes = list(set(filtered_nodes))

        actions = []           # deferred graph mutations, applied after the scan
        remove_edges = []      # edge indices to delete at the end
        remove_vertices = []   # vertex indices (output nodes) to delete at the end
        processed_nodes = set()  # transpose/consumer nodes already claimed by an earlier rewrite
        num_actions = 0
        for node in unique_nodes:
            # Nodes touched by this candidate; only committed once the rewrite is accepted.
            pending_processed_nodes = set()

            op = node['op']
            input_indices = op_input_indices(op)
            l_shape = op.inputs[0].shape
            r_shape = op.outputs[0].shape
            # Scalars have no dims to permute — nothing to pass through.
            if len(l_shape) == 0 or len(r_shape) == 0:
                continue
            # Map groups of input dims to groups of output dims for this reshape.
            l_map, r_map, _, _ = reshape_mapping(l_shape, r_shape)
            # mode 'up': reshape collapses dims (many->1); 'down': expands (1->many);
            # '?': mixed/unsupported mapping.
            mode = None
            need_chain = False
            for l_val, r_val in zip(l_map, r_map):
                if len(l_val) > 1 and len(r_val) == 1:
                    if mode in (None, 'up'):
                        mode = 'up'
                    else:
                        mode = '?'
                        break
                elif len(r_val) > 1 and len(l_val) == 1:
                    if mode in (None, 'down'):
                        mode = 'down'
                    else:
                        mode = '?'
                        break
                elif len(r_val) > 1 and len(l_val) > 1:
                    if len(r_val) != len(l_val) or r_val != l_val:
                        # TODO: Support this case
                        mode = '?'
                        break
                    else:
                        # Identical many-to-many group: flatten it later and verify
                        # the permutation keeps the group contiguous.
                        need_chain = True

            # Pure 1:1 mapping — either direction works; 'down' is the default.
            if mode is None:
                mode = 'down'

            # TODO: Support multi-multi mappings
            if mode == '?':
                # reset hints if passthrough is not possible
                for i in input_indices:
                    prev_node_name = op.inputs[i].name
                    prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
                    if prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
                        if 'direction' in prev_node['op'].extra_hints:
                            prev_node['op'].extra_hints.pop('direction')
                for edge in node.out_edges():
                    if edge.index in remove_edges:
                        continue
                    next_node = self.graph.graph.vs[edge.target]

                    if 'direction' in next_node['op'].extra_hints:
                        next_node['op'].extra_hints.pop('direction')
                continue

            # Dim-index groups that must stay consecutive under the chosen permutation.
            check_consecutive_indices = []
            if need_chain:
                # Flatten identical many-to-many groups into 1:1 entries, remembering
                # the original group so we can check it survives the permutation intact.
                new_l_map = []
                new_r_map = []
                for l_val, r_val in zip(l_map, r_map):
                    if len(l_val) > 1 and len(r_val) > 1:
                        if mode == 'down':
                            check_consecutive_indices.append(l_val)
                        else:
                            check_consecutive_indices.append(r_val)
                        for l_item in l_val:
                            new_l_map.append([l_item])
                        for r_item in r_val:
                            new_r_map.append([r_item])
                    else:
                        new_l_map.append(l_val)
                        new_r_map.append(r_val)

                l_map = new_l_map
                r_map = new_r_map

            # --- Gather producer-side transposes and vote on candidate permutations ---
            prev_nodes = []
            cand_perms = dict()      # perm -> count, in the direction we want to propagate
            cand_rev_perms = dict()  # perm -> count, opposite direction (would be cancelled)
            prev_output_indices = []
            num_constant_nodes = 0
            prev_hints = set()       # 'direction' hints seen on producer transposes
            skip = False
            for i in input_indices:
                prev_node_name = op.inputs[i].name
                prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
                prev_nodes.append(prev_node)
                prev_output_indices.append(prev_node['outputs'].index(prev_node_name))

                if prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
                    # Don't rewrite a transpose twice in one pass.
                    if prev_node['name'] in processed_nodes:
                        skip = True
                        break
                    pending_processed_nodes.add(prev_node['name'])
                    if mode == 'down':
                        perm = tuple(prev_node['op'].inputs[1].tensor.tolist())
                        cand_perms.setdefault(perm, 0)
                        cand_perms[perm] += 1
                    elif mode == 'up':
                        # Use the inverse permutation (argsort) for the 'up' direction.
                        perm = tuple(np.argsort(prev_node['op'].inputs[1].tensor).tolist())
                        cand_rev_perms.setdefault(perm, 0)
                        cand_rev_perms[perm] += 1
                    if 'direction' in prev_node['op'].extra_hints:
                        prev_hints.add(prev_node['op'].extra_hints['direction'])

                if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
                    # Constants don't need a runtime transpose inserted before them.
                    num_constant_nodes += 1

            # 'up' hint on a producer means another pass is already moving it upward.
            if skip or (self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints):
                continue

            # --- Gather consumer-side nodes and their transposes ---
            next_nodes = []
            next_edges = []
            out_nodes = []      # graph OUTPUT_NODE sinks attached to this op
            next_hints = set()  # 'direction' hints seen on consumer transposes
            for edge in node.out_edges():
                if edge.index in remove_edges:
                    continue
                next_node = self.graph.graph.vs[edge.target]

                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    out_nodes.append(next_node)
                else:
                    if next_node['name'] in processed_nodes:
                        skip = True
                        break
                    pending_processed_nodes.add(next_node['name'])
                    next_nodes.append(next_node)
                    next_edges.append(edge)

                if next_node['node_type'] == ExtendedOperator.TRANSPOSE:
                    if mode == 'down':
                        perm = tuple(np.argsort(next_node['op'].inputs[1].tensor).tolist())
                        cand_rev_perms.setdefault(perm, 0)
                        cand_rev_perms[perm] += 1
                    elif mode == 'up':
                        perm = tuple(next_node['op'].inputs[1].tensor.tolist())
                        cand_perms.setdefault(perm, 0)
                        cand_perms[perm] += 1
                    if 'direction' in next_node['op'].extra_hints:
                        next_hints.add(next_node['op'].extra_hints['direction'])

            # 'down' hint on a consumer means another pass is already moving it downward.
            if skip or (self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints):
                continue

            # Transpose count before vs. after the rewrite: existing matching
            # transposes disappear, and one new transpose is added per non-constant
            # input and per consumer edge.
            cur_transpose_size = sum(cand_perms.values()) + sum(cand_rev_perms.values())
            new_transpose_size = len(prev_nodes) + len(next_nodes) - sum(cand_perms.values()) - num_constant_nodes

            # Skip if the number of transpose nodes is not decreasing
            if len(cand_perms) == 0 or len(next_nodes) == 0 or new_transpose_size > cur_transpose_size:
                continue
            elif new_transpose_size == cur_transpose_size:
                # Break-even: only proceed if hints suggest further cancellation.
                skip = True
                if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED:
                    if 'down' in prev_hints or 'up' in next_hints:
                        skip = False

            if skip:
                continue

            # Most frequent candidate permutation wins.
            perm = max(cand_perms.items(), key=lambda x: x[1])[0]
            perm_arr = np.array(perm, dtype='int32')

            # Flattened many-to-many groups must map to a contiguous, order-preserving
            # index range under the permutation; otherwise the reshape can't follow.
            skip = False
            for check_idx in check_consecutive_indices:
                if mode == 'down':
                    target_idx = perm_arr[check_idx]
                elif mode == 'up':
                    # Positions of check_idx within perm_arr (i.e. inverse lookup).
                    perm_sorter = perm_arr.argsort()
                    target_idx = perm_sorter[np.searchsorted(perm_arr, check_idx, sorter=perm_sorter)]
                normalized_src = [x - check_idx[0] for x in check_idx]
                normalized_tgt = [x - target_idx[0] for x in target_idx]
                if normalized_src != normalized_tgt:
                    skip = True
                    break

            if skip:
                continue

            # --- Commit the rewrite for this node ---
            num_actions += 1

            remove_edges.extend([x.index for x in next_edges])
            remove_vertices.extend([x.index for x in out_nodes])

            for pending_processed_node in pending_processed_nodes:
                processed_nodes.add(pending_processed_node)

            # Output sinks are recreated implicitly when the consumer edges are
            # reconnected, so drop their tensor bookkeeping here.
            for n in out_nodes:
                del self.graph.tensor_map[n['outputs'][0]]
                del self.graph.tensor_node_map[n['outputs'][0]]

            # Translate the chosen permutation across the reshape mapping:
            # inv_perm_arr is applied before the op, post_perm_arr after it.
            if mode == 'down':
                inv_perm_arr = np.argsort(perm_arr).astype('int32')
                l_dict = dict(zip([x[0] for x in l_map], r_map))
                indices = map(lambda x: l_dict[x], inv_perm_arr.tolist())
                inv_post_perm = list(itertools.chain.from_iterable(indices))
                inv_post_perm_arr = np.array(inv_post_perm, dtype='int32')
                post_perm_arr = np.argsort(inv_post_perm_arr).astype('int32')
            elif mode == 'up':
                r_dict = dict(zip([x[0] for x in r_map], l_map))
                indices = map(lambda x: r_dict[x], perm)
                inv_perm = list(itertools.chain.from_iterable(indices))
                inv_perm_arr = np.array(inv_perm, dtype='int32')
                post_perm_arr = np.argsort(perm_arr).astype('int32')
                inv_post_perm_arr = np.argsort(post_perm_arr).astype('int32')

            # Insert a transpose (hinted 'up') before every input of the op.
            for prev_node, prev_idx, next_idx in zip(prev_nodes, input_indices, prev_output_indices):
                if prev_node['op'] is None:
                    # Graph-input/constant vertex without a wrapped op: look the
                    # tensor up directly.
                    prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
                else:
                    prev_out = prev_node['op'].outputs[next_idx]
                perm_tensor = self.create_attr_tensor(inv_perm_arr)
                prev_new_out = self.create_transform_tensor(
                    np.transpose(prev_out.tensor, inv_perm_arr), quantization=prev_out.quantization
                )
                transpose_op = tfl.TransposeOperator([prev_out, perm_tensor], [prev_new_out])
                transpose_op.extra_hints['direction'] = 'up'
                self.graph.add_operator(transpose_op)
                actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))

            # Insert a transpose (hinted 'down') after every output; the op now
            # produces the permuted tensor and the new transpose restores the
            # original output tensor for downstream consumers.
            tensor_node_dict = {}
            for i, op_out in enumerate(op.outputs):
                perm_tensor = self.create_attr_tensor(post_perm_arr)
                new_out = self.create_transform_tensor(
                    np.transpose(op_out.tensor, inv_post_perm_arr), quantization=op_out.quantization
                )

                # Update relations
                if op_out.name in self.graph.tensor_node_map:
                    del self.graph.tensor_node_map[op_out.name]
                self.graph.tensor_node_map[new_out.name] = node['name']
                self.graph.tensor_map[new_out.name] = new_out
                node['outputs'][i] = new_out.name
                op.outputs[i] = new_out

                transpose_op = tfl.TransposeOperator([new_out, perm_tensor], [op_out])
                transpose_op.extra_hints['direction'] = 'down'
                self.graph.add_operator(transpose_op)

                tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])

            # OP specific dim handling logic
            # Permute the reshape's target-shape input to match the new dim order.
            old_shape = op.inputs[1].tensor
            new_shape = self.create_attr_tensor(old_shape[inv_post_perm_arr])
            actions.append((self.graph.replace_operator_input, (node, 1, new_shape, True)))
            op.newShape = new_shape.tensor

            # Reconnect each former consumer edge to the matching new transpose node.
            for edge in next_edges:
                source = tensor_node_dict[edge['name']]
                self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])

        # Process actions
        ids = []
        for func, args in actions:
            node = args[0]
            res = func(*args)
            if res is not None:
                # replace_operator_input may return extra edge ids to delete.
                ids.extend(res)

        remove_edges = list(set(remove_edges + ids))

        self.graph.graph.delete_edges(remove_edges)
        self.graph.graph.delete_vertices(remove_vertices)

        return num_actions