def output_transpose_pass()

in tinynn/converter/operators/optimize.py [0:0]


    def output_transpose_pass(self):
        """Insert Transpose ops so selected graph outputs use permutation ``[0, 2, 3, 1]``.

        ``self.graph.output_transpose`` is either a single value applied to every
        entry of ``self.graph.outputs`` or a list/tuple with one entry per output.
        A tuple is converted to a list so it can be updated. After this pass
        ``output_transpose`` holds the resolved setting for every output:

        * ``None`` becomes ``True`` when the output tensor is in
          ``self.graph.tensor_map`` and has a 4-element shape, else ``False``.
        * Outputs whose name is not in ``tensor_map`` are forced to ``False``.

        Each output that is transposed is rewritten in one of two ways:

        * Producer is a DEQUANTIZE node ("prev" strategy): a Transpose with
          ``nchw2nhwc_perm`` is inserted in front of the DEQUANTIZE input.
          The DEQUANTIZE output tensor is transposed in place. Every consumer
          other than output/unused nodes is fed through a new Transpose with
          ``nhwc2nchw_perm``.
        * Any other producer ("next" strategy): the output tensor's data and
          shape are transposed in place with ``nchw2nhwc_perm``. The producer
          instead emits a fresh tensor in the original layout, and a Transpose
          from that fresh tensor to the original output tensor is added.
          Remaining consumers are reconnected to the fresh tensor. The old
          output-node vertices and consumer edges are deleted.

        Input replacements are queued in ``actions`` and applied after all nodes
        have been processed; edge and vertex deletions happen last.

        Raises:
            AssertionError: if ``output_transpose`` is a list/tuple whose length
                differs from ``len(self.graph.outputs)``, or if the same tensor
                name appears several times in ``outputs`` with different
                ``output_transpose`` values.
        """
        nhwc2nchw_perm = np.array([0, 3, 1, 2], dtype='int32')
        nchw2nhwc_perm = np.array([0, 2, 3, 1], dtype='int32')

        # Normalize `output_transpose` to a mutable list with one entry per output.
        if isinstance(self.graph.output_transpose, (list, tuple)):
            assert len(self.graph.output_transpose) == len(
                self.graph.outputs
            ), "the length of `output_transpose` must match the number of outputs"
            if isinstance(self.graph.output_transpose, tuple):
                # Tuples do not support the per-index assignment done below.
                self.graph.output_transpose = list(self.graph.output_transpose)
        else:
            self.graph.output_transpose = [self.graph.output_transpose] * len(self.graph.outputs)

        # Collapse duplicate output names; duplicates must agree on their setting.
        filtered_dict = {}
        for i, (name, transpose) in enumerate(zip(self.graph.outputs, self.graph.output_transpose)):
            if name in filtered_dict:
                old_transpose = filtered_dict[name]
                assert (
                    transpose == old_transpose
                ), f"outputs {i} points to an existing tensor {name}, but their property `output_transpose` is different"
            else:
                filtered_dict[name] = transpose

        # Resolve each output's setting and choose a rewrite strategy for it.
        # Keys are graph vertices; values are sets of tensor indices.
        prev_modify_node_indices = {}  # DEQUANTIZE node -> input slots to transpose
        prev_modify_next_indices = {}  # DEQUANTIZE node -> output indices that are graph outputs
        next_modify_node_indices = {}  # other producer -> output indices that are graph outputs
        for name, transpose in filtered_dict.items():
            if name in self.graph.tensor_map:
                tensor = self.graph.tensor_map[name]
                if transpose is None:
                    transpose = len(tensor.shape) == 4
            else:
                transpose = False

            for i, n in enumerate(self.graph.outputs):
                if name == n:
                    self.graph.output_transpose[i] = transpose

            if transpose:
                node_name = self.graph.tensor_node_map[name]
                node = self.graph.graph.vs.find(name=node_name)
                tensor_idx = node['outputs'].index(name)

                prev_node = None
                if node['node_type'] == ExtendedOperator.DEQUANTIZE:
                    prev_node_name = self.graph.tensor_node_map[node['op'].inputs[0].name]
                    prev_node = self.graph.graph.vs.find(name=prev_node_name)

                if prev_node is None:
                    next_modify_node_indices.setdefault(node, set()).add(tensor_idx)
                else:
                    # Input slot 0 of the DEQUANTIZE node gets a Transpose in front of it.
                    prev_modify_node_indices.setdefault(node, set()).add(0)
                    prev_modify_next_indices.setdefault(node, set()).add(tensor_idx)

        remove_edges = []
        remove_vertices = []
        actions = []

        # "prev" strategy: transpose before DEQUANTIZE, transpose back for consumers.
        for node, index in prev_modify_node_indices.items():
            next_indices = prev_modify_next_indices[node]
            op = node['op']
            # NOTE(review): `index` holds input slots but is used here to select
            # output names, and `next_nodes` is keyed by position in
            # `tensor_names` yet looked up with output indices below. Both agree
            # only because DEQUANTIZE uses slot/index 0 on each side.
            tensor_names = [node['outputs'][i] for i in index]

            # Consumers of the affected outputs, excluding output/unused nodes.
            next_nodes = {}
            for edge in node.out_edges():
                if edge['label'] not in tensor_names:
                    continue

                if edge.index in remove_edges:
                    continue

                tensor_idx = tensor_names.index(edge['label'])
                next_node = self.graph.graph.vs[edge.target]

                if next_node['node_type'] not in (ExtendedOperator.OUTPUT_NODE, ExtendedOperator.UNUSED_NODE):
                    next_nodes.setdefault(tensor_idx, []).append(next_node)

            # Find the producer of each input that will be transposed.
            prev_nodes = []
            prev_output_indices = []
            for i in index:
                input_name = op.inputs[i].name
                prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[input_name])
                prev_nodes.append(prev_node)
                prev_output_indices.append(prev_node['outputs'].index(input_name))

            # One Transpose per distinct input tensor; rewire the input slot to its output.
            tensor_node_dict = {}
            for prev_node, input_slot, prev_out_idx in zip(prev_nodes, index, prev_output_indices):
                if prev_node['op'] is None:
                    # Producer vertex without an op (presumably a graph input): use its first output.
                    prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
                else:
                    prev_out = prev_node['op'].outputs[prev_out_idx]
                if prev_out.name in tensor_node_dict:
                    # Reuse the Transpose already created for this tensor.
                    prev_new_out, skip = tensor_node_dict[prev_out.name]
                    actions.append((self.graph.replace_operator_input, (node, input_slot, prev_new_out, True, skip)))
                    skip += 1
                    tensor_node_dict[prev_out.name] = (prev_new_out, skip)
                else:
                    perm_tensor = self.create_attr_tensor(nchw2nhwc_perm)
                    prev_new_out = self.create_transform_tensor(
                        np.transpose(prev_out.tensor, nchw2nhwc_perm), quantization=prev_out.quantization
                    )
                    tensor_node_dict[prev_out.name] = (prev_new_out, 1)
                    prev_transpose_op = tfl.TransposeOperator([prev_out, perm_tensor], [prev_new_out])
                    prev_transpose_op.extra_hints['direction'] = 'up'
                    self.graph.add_operator(prev_transpose_op)
                    actions.append((self.graph.replace_operator_input, (node, input_slot, prev_new_out, True)))

            # Transpose the output tensors in place. Consumers get a Transpose
            # back to the original layout.
            tensor_mapping = {}
            for i in next_indices:
                t = op.outputs[i]
                t.tensor = np.transpose(t.tensor, nchw2nhwc_perm)
                t.shape = t.tensor.shape

                if i in next_nodes:
                    new_t = self.create_transform_tensor(np.transpose(t.tensor, nhwc2nchw_perm))
                    perm_t = self.create_attr_tensor(nhwc2nchw_perm)

                    next_transpose_op = tfl.TransposeOperator([t, perm_t], [new_t])
                    next_transpose_op.extra_hints['direction'] = 'down'
                    self.graph.add_operator(next_transpose_op)

                    tensor_mapping[t.name] = new_t

            for nodes in next_nodes.values():
                for n in nodes:
                    next_op = n['op']
                    for i, t in enumerate(next_op.inputs):
                        if t.name in tensor_mapping:
                            actions.append((self.graph.replace_operator_input, (n, i, tensor_mapping[t.name])))

        # "next" strategy: append a Transpose after the producer of each output.
        for node, index in next_modify_node_indices.items():
            op = node['op']
            tensor_names = [node['outputs'][i] for i in index]
            out_nodes = []
            next_edges = []
            for edge in node.out_edges():
                if edge['label'] not in tensor_names:
                    continue

                if edge.index in remove_edges:
                    continue

                next_node = self.graph.graph.vs[edge.target]

                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    out_nodes.append(next_node)
                elif next_node['node_type'] != ExtendedOperator.UNUSED_NODE:
                    next_edges.append(edge)

            remove_vertices.extend([x.index for x in out_nodes])
            remove_edges.extend([x.index for x in next_edges])

            # NOTE(review): if two output nodes carry the same tensor name, the
            # second `del` raises KeyError. Confirm outputs are unique per tensor.
            for n in out_nodes:
                del self.graph.tensor_map[n['outputs'][0]]
                del self.graph.tensor_node_map[n['outputs'][0]]

            tensor_node_dict = {}
            for i, op_out in enumerate(op.outputs):
                if i not in index:
                    continue

                # The existing output tensor takes the transposed layout ...
                op_out.tensor = np.transpose(op_out.tensor, nchw2nhwc_perm)
                op_out.shape = op_out.tensor.shape

                # ... and the producer now writes a fresh tensor in the original layout.
                perm_tensor = self.create_attr_tensor(nchw2nhwc_perm)
                new_out = self.create_transform_tensor(
                    np.transpose(op_out.tensor, nhwc2nchw_perm), quantization=op_out.quantization
                )

                # Update relations
                if op_out.name in self.graph.tensor_node_map:
                    del self.graph.tensor_node_map[op_out.name]
                self.graph.tensor_node_map[new_out.name] = node['name']
                self.graph.tensor_map[new_out.name] = new_out
                node['outputs'][i] = new_out.name
                op.outputs[i] = new_out

                next_transpose_op = tfl.TransposeOperator([new_out, perm_tensor], [op_out])
                next_transpose_op.extra_hints['direction'] = 'up'
                self.graph.add_operator(next_transpose_op)

                tensor_node_dict[op_out.name] = (
                    self.graph.graph.vs.find(name=self.graph.tensor_node_map[new_out.name]),
                    new_out.name,
                )

            # Reconnect the remaining consumers to the fresh (original-layout) tensors.
            for edge in next_edges:
                old_name = edge['name']
                source, new_name = tensor_node_dict[old_name]
                target = edge.target_vertex
                self.graph.graph.add_edge(source, target, name=new_name, label=new_name)

                target_op = target['op']
                for input_idx, op_input in enumerate(target_op.inputs):
                    if op_input.name == old_name:
                        target_op.inputs[input_idx] = self.graph.tensor_map[new_name]
                        break

        # Apply the queued input replacements. Any values they return are added
        # to the edges deleted below.
        ids = []
        for func, args in actions:
            res = func(*args)
            if res is not None:
                ids.extend(res)

        remove_edges = list(set(remove_edges + ids))

        self.graph.graph.delete_edges(remove_edges)
        self.graph.graph.delete_vertices(remove_vertices)