def group_deconv_rewrite_pass()

in tinynn/converter/operators/optimize.py [0:0]


    def group_deconv_rewrite_pass(self):
        """Rewrite each grouped TransposeConv node as Split -> N x TransposeConv -> Concatenation.

        Nodes are selected with ``is_group_deconv_node``. Each matched node is removed from the
        graph, together with any OUTPUT_NODE vertices it feeds directly. Edges to other consumers
        are recorded first and restored with ``try_restore_edges`` after the replacement
        operators have been added.

        For each node:

        * The input is split along axis 3 into ``num_chunks`` slices.
        * The weight is split along axis 3 and the bias (if present) along axis 0.
        * Each slice goes through its own TransposeConv, reusing the original padding and strides.
        * The per-group results are concatenated along axis 3 into the original output tensor.
        """
        graph = self.graph.graph

        # Process matched nodes in the order of the number embedded in their names, which is
        # used as a proxy for topological order. The restore mappings are gathered in this SAME
        # order below, so each rewritten op is always paired with its own consumer edges.
        vertices = sorted(
            graph.vs.select(functools.partial(is_group_deconv_node, graph_converter=graph)),
            key=lambda x: int(re.search(r'\d+', x['name'])[0]),
        )

        remove_ids = []
        # (original op, [(edge name, consumer node name), ...]) for each matched node, in order
        rewrite_jobs = []
        for conv in vertices:
            restore_nodes = []
            # For each node that is next of a transformable node,
            #  a. if it is an output node, remove it anyway since it will always be reconstructed
            #  b. otherwise, record the info of the edge so that we may restore it after reconstruction
            for out_edge in conv.out_edges():
                next_node = graph.vs[out_edge.target]
                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    remove_ids.append(next_node.index)
                    del self.graph.tensor_map[next_node['outputs'][0]]
                    del self.graph.tensor_node_map[next_node['outputs'][0]]
                else:
                    restore_nodes.append((out_edge['name'], next_node['name']))

            # Remove the mapping since they are going to be removed
            for output_name in conv['outputs']:
                del self.graph.tensor_map[output_name]
                del self.graph.tensor_node_map[output_name]

            # Capture the op object now, before the vertices are deleted below.
            rewrite_jobs.append((conv['op'], restore_nodes))
            remove_ids.append(conv.index)

        # Delete nodes before transformation in the graph
        graph.delete_vertices(remove_ids)

        for conv, mapping in rewrite_jobs:
            # Input order of the TransposeConv op as read here: [output_shape, weight, input, (bias)]
            input_tensor = conv.inputs[2]
            weight_tensor = conv.inputs[1]
            output_shape_tensor = conv.inputs[0]
            bias_tensor = conv.inputs[3] if len(conv.inputs) > 3 else None
            output_tensor = conv.outputs[0]

            # Each per-group op emits weight.shape[0] channels (see new_os below), so the ratio
            # of total output channels to that count is the number of groups.
            num_output_channel = output_tensor.shape[3]
            num_weight_channel = weight_tensor.shape[0]
            num_chunks = num_output_channel // num_weight_channel

            ops = []

            input_tensors = [
                self.create_transform_tensor(arr, quantization=input_tensor.quantization)
                for arr in np.split(input_tensor.tensor, num_chunks, 3)
            ]
            output_tensors = [
                self.create_transform_tensor(arr, quantization=output_tensor.quantization)
                for arr in np.split(output_tensor.tensor, num_chunks, 3)
            ]
            # NOTE(review): the original quantization params are reused unchanged for every
            # weight/bias slice; confirm this is correct for per-channel quantized weights.
            weights = [
                self.create_attr_tensor(arr, quantization=weight_tensor.quantization)
                for arr in np.split(weight_tensor.tensor, num_chunks, 3)
            ]

            if bias_tensor is not None:
                biases = [
                    self.create_attr_tensor(arr, quantization=bias_tensor.quantization)
                    for arr in np.split(bias_tensor.tensor, num_chunks, 0)
                ]
            else:
                biases = [None] * num_chunks

            # The shared output-shape tensor for the per-group ops only differs in the channel dim.
            new_os = output_shape_tensor.tensor.copy()
            new_os[3] = num_weight_channel
            new_ost = self.create_attr_tensor(new_os)
            dim_tensor = self.create_attr_tensor(np.array(3, dtype='int32'))
            ops.append(tfl.SplitOperator([dim_tensor, input_tensor], input_tensors, num_chunks))

            for it, ot, w, b in zip(input_tensors, output_tensors, weights, biases):
                inputs = [new_ost, w, it]
                if b is not None:
                    inputs.append(b)
                # NOTE(review): only padding and strides are carried over from the original op;
                # confirm no other operator options need propagating.
                ops.append(
                    tfl.TransposeConvOperator(
                        inputs,
                        [ot],
                        padding=conv.padding,
                        strideH=conv.strideH,
                        strideW=conv.strideW,
                    )
                )

            ops.append(tfl.ConcatenationOperator(output_tensors, [output_tensor], 3))

            for op in ops:
                self.graph.add_operator(op, transform=True)

            self.graph.try_restore_edges(mapping)