def elementwise_op_reshape_passthrough_pass()

in tinynn/converter/operators/optimize.py [0:0]


    def elementwise_op_reshape_passthrough_pass(self) -> int:
        """Move RESHAPE ops across adjacent ops so that the op runs in the reshaped layout.

        Every op reached through an edge accepted by ``is_reshape_elementwise_op_edge``
        is examined together with the RESHAPE ops that produce its inputs ("prev" side)
        and consume its outputs ("next" side). Each reshape accepted by
        ``is_simple_reshape`` votes for:

        * a candidate shape for the op (the shape on the far side of the reshape), and
        * the op's current boundary shape (the shape on the op side of the reshape).

        In both, the axis the op works on (if any) is marked with ``-1``.

        The op is rewritten when the number of reshapes that would have to be newly
        inserted does not exceed the number of agreeing reshapes. At break-even this
        requires the extended branch-optimization level plus a direction hint. A
        rewrite does the following:

        * inserts a RESHAPE (``extra_hints['direction'] = 'up'``) in front of each input,
          reshaping it to the candidate shape;
        * replaces each used output by a new tensor in the candidate shape, followed by a
          RESHAPE (``extra_hints['direction'] = 'down'``) back to the original shape;
        * re-targets axis / pad / slice / tile / reduction parameters of dim-aware ops to
          the position of ``-1`` in the candidate shape.

        Input replacements, edge deletions and vertex deletions are batched and applied
        once after all candidate ops have been processed.

        Returns:
            int: the number of ops that were rewritten.

        Raises:
            NotImplementedError: if an op reports an axis via ``op_input_dims`` but has no
                dim-specific handling below.
        """
        edges = self.graph.graph.es.select(
            functools.partial(is_reshape_elementwise_op_edge, graph_converter=self.graph.graph)
        )
        pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
        # Each selected edge joins a RESHAPE and the op of interest; keep the non-RESHAPE end.
        filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.RESHAPE else k[1] for k in pairs)
        unique_nodes = list(set(filtered_nodes))

        # Input replacements, deferred until every node has been processed.
        actions = []
        # Kept as a set so the per-edge membership test below is O(1).
        remove_edges = set()
        remove_vertices = []
        num_actions = 0
        for node in unique_nodes:
            op = node['op']
            # Axis the op operates on, or None for purely element-wise ops
            # (inferred from the `dim_indice is not None` checks below; see `op_input_dims`).
            dim_indice = op_input_dims(op)
            input_indices = op_input_indices(op)

            prev_nodes = []
            # Votes: far-side shape of a neighbouring reshape (candidate shape for the op) -> count.
            cand_shapes = dict()
            # Votes: op-side shape of a neighbouring reshape (the op's current boundary shape) -> count.
            cand_next_shapes = dict()
            prev_output_indices = []
            num_constant_nodes = 0
            prev_hints = set()

            # --- Producers of the op's inputs ---
            for i in input_indices:
                prev_node_name = op.inputs[i].name
                prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
                prev_nodes.append(prev_node)
                prev_output_indices.append(prev_node['outputs'].index(prev_node_name))

                if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
                    num_constant_nodes += 1

                if prev_node['node_type'] == ExtendedOperator.RESHAPE:
                    # NOTE(review): `mapping` appears to map reshape-input dims to reshape-output
                    # dims (it is inverted below to go from the op's axis back to the input side).
                    mapping = dict()
                    if not is_simple_reshape(
                        prev_node['op'].inputs[0].shape, prev_node['op'].outputs[0].shape, mapping
                    ):
                        continue

                    new_dim = None
                    if dim_indice is not None:
                        rev_mapping = {v: k for k, v in mapping.items()}
                        if node['node_type'] == ExtendedOperator.PACK:
                            # PACK inputs lack the pack axis, so locate it next to a mapped neighbour.
                            if dim_indice in rev_mapping:
                                tmp_new_dim = rev_mapping[dim_indice]
                            else:
                                if dim_indice - 1 in rev_mapping:
                                    tmp_new_dim = rev_mapping[dim_indice - 1] + 1
                                elif dim_indice + 1 in rev_mapping:
                                    tmp_new_dim = rev_mapping[dim_indice + 1] - 1
                                else:
                                    # TODO: Figure out the rev index
                                    tmp_new_dim = -1
                            # Disable the -1 substitution below (the axis marker is inserted
                            # instead); `dim_indice` is restored once both shapes are built.
                            tmp_dim_indice = dim_indice
                            new_dim = -1
                            dim_indice = -1
                        else:
                            if dim_indice not in rev_mapping:
                                continue
                            new_dim = rev_mapping[dim_indice]

                    shape = tuple(prev_node['op'].inputs[0].shape)
                    shape = tuple(x if i != new_dim else -1 for i, x in enumerate(shape))
                    if node['node_type'] == ExtendedOperator.PACK and tmp_new_dim >= 0:
                        shape = list(shape)
                        shape.insert(tmp_new_dim, -1)
                        shape = tuple(shape)
                    cand_shapes.setdefault(shape, 0)
                    cand_shapes[shape] += 1

                    next_shape = tuple(prev_node['op'].outputs[0].shape)
                    next_shape = tuple(x if i != dim_indice else -1 for i, x in enumerate(next_shape))
                    if node['node_type'] == ExtendedOperator.PACK:
                        next_shape = list(next_shape)
                        next_shape.insert(tmp_dim_indice, -1)
                        next_shape = tuple(next_shape)
                    cand_next_shapes.setdefault(next_shape, 0)
                    cand_next_shapes[next_shape] += 1

                    if node['node_type'] == ExtendedOperator.PACK:
                        dim_indice = tmp_dim_indice

                    if 'direction' in prev_node['op'].extra_hints:
                        prev_hints.add(prev_node['op'].extra_hints['direction'])

            # At the extended level, don't pass through an input reshape hinted 'up' (this pass
            # tags the reshapes it inserts in front of an op that way) — presumably to keep
            # reshapes from oscillating between runs; confirm.
            if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints:
                continue

            # --- Consumers of the op's outputs ---
            next_nodes = []
            next_edges = []
            out_nodes = []
            skip_names = []
            next_hints = set()
            for edge in node.out_edges():
                # Edges already scheduled for deletion by an earlier rewrite.
                if edge.index in remove_edges:
                    continue
                next_node = self.graph.graph.vs[edge.target]

                if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                    out_nodes.append(next_node)
                elif next_node['node_type'] == ExtendedOperator.UNUSED_NODE:
                    skip_names.append(edge['label'])
                else:
                    next_nodes.append(next_node)
                    next_edges.append(edge)

                if next_node['node_type'] == ExtendedOperator.RESHAPE:
                    mapping = dict()
                    if not is_simple_reshape(
                        next_node['op'].inputs[0].shape, next_node['op'].outputs[0].shape, mapping
                    ):
                        continue

                    new_dim = None
                    if dim_indice is not None:
                        if node['node_type'] == ExtendedOperator.UNPACK:
                            # UNPACK outputs lack the unpack axis, so locate it next to a mapped neighbour.
                            if dim_indice in mapping:
                                tmp_new_dim = mapping[dim_indice]
                            else:
                                if dim_indice - 1 in mapping:
                                    tmp_new_dim = mapping[dim_indice - 1] + 1
                                elif dim_indice + 1 in mapping:
                                    tmp_new_dim = mapping[dim_indice + 1] - 1
                                else:
                                    # TODO: Figure out the rev index
                                    tmp_new_dim = -1
                            # Same temporary override as for PACK above.
                            tmp_dim_indice = dim_indice
                            new_dim = -1
                            dim_indice = -1
                        else:
                            if dim_indice not in mapping:
                                continue
                            new_dim = mapping[dim_indice]

                    shape = tuple(next_node['op'].outputs[0].shape)
                    shape = tuple(x if i != new_dim else -1 for i, x in enumerate(shape))
                    if node['node_type'] == ExtendedOperator.UNPACK and tmp_new_dim >= 0:
                        shape = list(shape)
                        shape.insert(tmp_new_dim, -1)
                        shape = tuple(shape)
                    cand_shapes.setdefault(shape, 0)
                    cand_shapes[shape] += 1

                    next_shape = tuple(next_node['op'].inputs[0].shape)
                    next_shape = tuple(x if i != dim_indice else -1 for i, x in enumerate(next_shape))
                    if node['node_type'] == ExtendedOperator.UNPACK:
                        next_shape = list(next_shape)
                        next_shape.insert(tmp_dim_indice, -1)
                        next_shape = tuple(next_shape)
                    cand_next_shapes.setdefault(next_shape, 0)
                    cand_next_shapes[next_shape] += 1

                    if node['node_type'] == ExtendedOperator.UNPACK:
                        dim_indice = tmp_dim_indice

                    if 'direction' in next_node['op'].extra_hints:
                        next_hints.add(next_node['op'].extra_hints['direction'])

            # --- Decide whether rewriting this op pays off ---
            if len(cand_shapes) == 0:
                continue

            # Mirror of the 'up' check above for output-side reshapes.
            if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints:
                continue

            cur_reshape_size = max(cand_shapes.values())
            cur_next_reshape_size = max(cand_next_shapes.values())
            full_size = len(prev_nodes) + len(next_nodes)

            if cur_reshape_size != cur_next_reshape_size:
                continue

            # Reshapes that would have to be inserted: every neighbour except the agreeing
            # reshapes and constant producers.
            new_reshape_size = full_size - cur_reshape_size - num_constant_nodes

            # Skip if not wrapped by reshapes, or if more reshapes would be added than matched.
            if len(next_nodes) == 0 or new_reshape_size > cur_reshape_size:
                continue
            elif new_reshape_size == cur_reshape_size:
                # Break-even: only proceed at the extended level when a neighbouring reshape
                # carries a 'down' hint (input side) or an 'up' hint (output side).
                skip = True
                if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED:
                    if 'down' in prev_hints or 'up' in next_hints:
                        skip = False
                if skip:
                    continue

            prev_shape = max(cand_shapes.items(), key=lambda x: x[1])[0]
            next_shape = max(cand_next_shapes.items(), key=lambda x: x[1])[0]

            # A dim-aware op needs its axis marked with -1 in the chosen shape. The PACK/UNPACK
            # path leaves it unmarked when the axis could not be mapped (tmp_new_dim == -1);
            # proceeding would mutate the graph and then fail in `np.reshape` or
            # `prev_shape.index(-1)`, so bail out before touching anything.
            if dim_indice is not None and -1 not in prev_shape:
                continue

            # --- Commit: from here on the graph is modified ---
            num_actions += 1

            remove_edges.update(x.index for x in next_edges)
            remove_vertices.extend([x.index for x in out_nodes])

            for n in out_nodes:
                del self.graph.tensor_map[n['outputs'][0]]
                del self.graph.tensor_node_map[n['outputs'][0]]

            # Insert an 'up' reshape in front of every input.
            tensor_node_dict = {}
            for prev_node, prev_idx, next_idx in zip(prev_nodes, input_indices, prev_output_indices):
                if prev_node['op'] is None:
                    prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
                else:
                    prev_out = prev_node['op'].outputs[next_idx]
                if prev_out.name in tensor_node_dict:
                    # The same tensor feeds several inputs: reuse the reshaped tensor made for it.
                    prev_new_out, num_skip = tensor_node_dict[prev_out.name]
                    actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True, num_skip)))
                    num_skip += 1
                    tensor_node_dict[prev_out.name] = (prev_new_out, num_skip)
                else:
                    if node['node_type'] == ExtendedOperator.PACK:
                        # PACK inputs lack the pack axis: drop its -1 marker for this reshape.
                        tmp_prev_shape = prev_shape
                        prev_shape = [i for i in prev_shape if i != -1]
                    prev_shape_aligned = prev_shape
                    # Element counts differ for e.g. broadcast operands; this is also always true
                    # while `prev_shape` still holds the -1 marker, since its product is negative.
                    if np.prod(prev_out.shape) != np.prod(prev_shape):
                        # Rebuild the target shape dim by dim from `prev_out`, left-padding its
                        # shape with 1s up to the rank of `next_shape`.
                        new_prev_shape = prev_out.shape
                        if len(prev_out.shape) < len(next_shape):
                            new_prev_shape = [1] * (len(next_shape) - len(prev_out.shape)) + list(prev_out.shape)
                        mapping = {}
                        is_simple_reshape(prev_shape, next_shape, mapping)
                        prev_shape_aligned = np.ones(len(prev_shape), dtype='int32')
                        for pi, ni in mapping.items():
                            prev_shape_aligned[pi] = new_prev_shape[ni]

                    prev_new_out = self.create_transform_tensor(
                        np.reshape(prev_out.tensor, prev_shape_aligned), quantization=prev_out.quantization
                    )
                    tensor_node_dict[prev_out.name] = (prev_new_out, 1)
                    shape_tensor = self.create_attr_tensor(np.array(prev_new_out.shape, dtype='int32'))
                    reshape_op = tfl.ReshapeOperator(
                        [prev_out, shape_tensor], [prev_new_out], newShape=shape_tensor.tensor
                    )
                    reshape_op.extra_hints['direction'] = 'up'
                    self.graph.add_operator(reshape_op)
                    actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))

                    if node['node_type'] == ExtendedOperator.PACK:
                        prev_shape = tmp_prev_shape

            # Retarget every output to the new shape, followed by a 'down' reshape back to the
            # original output tensor. Maps original output name -> vertex of that reshape.
            tensor_node_dict = {}
            for i, op_out in enumerate(op.outputs):
                if node['node_type'] == ExtendedOperator.UNPACK:
                    # UNPACK outputs lack the unpack axis: drop its -1 marker.
                    tmp_prev_shape = prev_shape
                    prev_shape = [i for i in prev_shape if i != -1]

                # For unused tensors, we perform inplace shape updates
                if op_out.name in skip_names:
                    new_shape = np.reshape(op_out.tensor, prev_shape).shape
                    op_out.shape = tuple(new_shape)

                    if node['node_type'] == ExtendedOperator.UNPACK:
                        prev_shape = tmp_prev_shape

                    continue

                new_out = self.create_transform_tensor(
                    np.reshape(op_out.tensor, prev_shape), quantization=op_out.quantization
                )
                shape_tensor = self.create_attr_tensor(np.array(op_out.shape, dtype='int32'))

                # Update relations
                if op_out.name in self.graph.tensor_node_map:
                    del self.graph.tensor_node_map[op_out.name]
                self.graph.tensor_node_map[new_out.name] = node['name']
                self.graph.tensor_map[new_out.name] = new_out
                node['outputs'][i] = new_out.name
                op.outputs[i] = new_out

                reshape_op = tfl.ReshapeOperator([new_out, shape_tensor], [op_out], shape_tensor.tensor)
                reshape_op.extra_hints['direction'] = 'down'
                self.graph.add_operator(reshape_op)

                tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])

                if node['node_type'] == ExtendedOperator.UNPACK:
                    prev_shape = tmp_prev_shape

            # OP specific dim handling logic: the new axis is where -1 sits in `prev_shape`,
            # the old one is where it sits in `next_shape`.
            if node['node_type'] in (
                ExtendedOperator.CONCATENATION,
                ExtendedOperator.GATHER,
                ExtendedOperator.UNPACK,
                ExtendedOperator.PACK,
            ):
                new_axis = prev_shape.index(-1)
                op.axis = new_axis
            elif node['node_type'] == ExtendedOperator.SPLIT_V:
                # The axis is input #2 of SPLIT_V.
                new_dim = prev_shape.index(-1)
                new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
                actions.append((self.graph.replace_operator_input, (node, 2, new_dim_tensor, True)))
            elif node['node_type'] == ExtendedOperator.SPLIT:
                # The axis is input #0 of SPLIT.
                new_dim = prev_shape.index(-1)
                new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
                actions.append((self.graph.replace_operator_input, (node, 0, new_dim_tensor, True)))
            elif node['node_type'] in (ExtendedOperator.PAD, ExtendedOperator.PADV2, ExtendedOperator.MIRROR_PAD):
                # Only the padding of the op axis is carried over; all other dims get zero padding.
                # NOTE(review): assumes `op_input_dims` only reports a dim when a single axis is
                # padded — confirm.
                old_pad = op.inputs[1].tensor
                new_dim = prev_shape.index(-1)
                old_dim = next_shape.index(-1)
                new_pad = np.zeros((len(prev_shape), 2), dtype='int32')
                new_pad[new_dim, :] = old_pad[old_dim, :]
                new_pad_tensor = self.create_attr_tensor(new_pad)
                actions.append((self.graph.replace_operator_input, (node, 1, new_pad_tensor, True)))
            elif node['node_type'] == ExtendedOperator.PRELU:
                old_weight = op.inputs[1].tensor
                if old_weight.ndim != 1:
                    # The alpha tensor has one dim fewer than the input (hence the `- 1` offsets);
                    # it is reshaped with its own RESHAPE op.
                    new_dim = prev_shape.index(-1)
                    old_dim = next_shape.index(-1)
                    new_shape = np.ones(len(prev_shape) - 1, dtype='int32')
                    new_shape[new_dim - 1] = old_weight.shape[old_dim - 1]
                    new_shape_t = self.create_attr_tensor(new_shape)
                    new_weight = self.create_transform_tensor(np.reshape(old_weight, new_shape))
                    self.graph.add_operator(tfl.ReshapeOperator([op.inputs[1], new_shape_t], [new_weight], new_shape))
                    actions.append((self.graph.replace_operator_input, (node, 1, new_weight, True)))
            elif node['node_type'] == ExtendedOperator.SLICE:
                # begin: 0 everywhere except the op axis; size: full extent except the op axis.
                new_dim = prev_shape.index(-1)
                old_dim = next_shape.index(-1)

                new_start = np.zeros(len(prev_shape), dtype='int32')
                new_start[new_dim] = op.inputs[1].tensor[old_dim]
                new_start_t = self.create_attr_tensor(new_start)

                new_size = np.array(prev_shape, dtype='int32')
                new_size[new_dim] = op.inputs[2].tensor[old_dim]
                new_size_t = self.create_attr_tensor(new_size)

                actions.append((self.graph.replace_operator_input, (node, 1, new_start_t, True)))
                actions.append((self.graph.replace_operator_input, (node, 2, new_size_t, True)))
            elif node['node_type'] == ExtendedOperator.STRIDED_SLICE:
                new_dim = prev_shape.index(-1)
                old_dim = next_shape.index(-1)

                new_start = np.zeros(len(prev_shape), dtype='int32')
                new_start[new_dim] = op.inputs[1].tensor[old_dim]
                if op.inputs[1].buffer is None:
                    # Non-constant begin: build it at runtime as
                    # concat(zeros_left, slice(begin, [old_dim], [1]), zeros_right).
                    new_start_t = self.create_transform_tensor(new_start)
                    starts_mid = new_start[new_dim : new_dim + 1]
                    starts_mid_tensor = self.create_transform_tensor(starts_mid)

                    slice_inputs = [
                        op.inputs[1],
                        self.create_attr_tensor(np.array([old_dim], dtype='int32')),
                        self.create_attr_tensor(np.array([1], dtype='int32')),
                    ]

                    self.graph.add_operator(tfl.SliceOperator(slice_inputs, [starts_mid_tensor]))

                    starts_left = new_start[:new_dim]
                    starts_right = new_start[new_dim + 1 :]
                    starts_tensors = []
                    if len(starts_left) > 0:
                        starts_tensors.append(self.create_attr_tensor(starts_left))
                    starts_tensors.append(starts_mid_tensor)
                    if len(starts_right) > 0:
                        starts_tensors.append(self.create_attr_tensor(starts_right))
                    if len(starts_tensors) > 1:
                        self.graph.add_operator(tfl.ConcatenationOperator(starts_tensors, [new_start_t], 0))
                    else:
                        new_start_t = starts_tensors[0]
                else:
                    new_start_t = self.create_attr_tensor(new_start)

                new_end = np.array(prev_shape, dtype='int32')
                new_end[new_dim] = op.inputs[2].tensor[old_dim]
                if op.inputs[2].buffer is None:
                    # Non-constant end: same runtime construction as for begin.
                    new_end_t = self.create_transform_tensor(new_end)
                    ends_mid = new_end[new_dim : new_dim + 1]
                    ends_mid_tensor = self.create_transform_tensor(ends_mid)

                    slice_inputs = [
                        op.inputs[2],
                        self.create_attr_tensor(np.array([old_dim], dtype='int32')),
                        self.create_attr_tensor(np.array([1], dtype='int32')),
                    ]

                    self.graph.add_operator(tfl.SliceOperator(slice_inputs, [ends_mid_tensor]))

                    ends_left = new_end[:new_dim]
                    ends_right = new_end[new_dim + 1 :]
                    ends_tensors = []
                    if len(ends_left) > 0:
                        ends_tensors.append(self.create_attr_tensor(ends_left))
                    ends_tensors.append(ends_mid_tensor)
                    if len(ends_right) > 0:
                        ends_tensors.append(self.create_attr_tensor(ends_right))
                    if len(ends_tensors) > 1:
                        self.graph.add_operator(tfl.ConcatenationOperator(ends_tensors, [new_end_t], 0))
                    else:
                        new_end_t = ends_tensors[0]
                else:
                    new_end_t = self.create_attr_tensor(new_end)

                new_stride = np.ones(len(prev_shape), dtype='int32')
                new_stride[new_dim] = op.inputs[3].tensor[old_dim]
                new_stride_t = self.create_attr_tensor(new_stride)

                actions.append((self.graph.replace_operator_input, (node, 1, new_start_t, True)))
                actions.append((self.graph.replace_operator_input, (node, 2, new_end_t, True)))
                actions.append((self.graph.replace_operator_input, (node, 3, new_stride_t, True)))
            elif node['node_type'] == ExtendedOperator.TILE:
                # multiples: 1 everywhere except the op axis.
                old_shape = op.inputs[1].tensor
                new_dim = prev_shape.index(-1)
                old_dim = next_shape.index(-1)
                new_shape = np.ones(len(prev_shape), dtype='int32')
                new_shape[new_dim] = old_shape[old_dim]
                new_shape_tensor = self.create_attr_tensor(new_shape)
                actions.append((self.graph.replace_operator_input, (node, 1, new_shape_tensor, True)))
            elif node['node_type'] in (
                ExtendedOperator.SUM,
                ExtendedOperator.ARG_MIN,
                ExtendedOperator.ARG_MAX,
                ExtendedOperator.REDUCE_MIN,
                ExtendedOperator.REDUCE_MAX,
                ExtendedOperator.REDUCE_PROD,
                ExtendedOperator.MEAN,
            ):
                # Reductions: the axis tensor (input #1) is replaced by the single new axis.
                new_axis = prev_shape.index(-1)
                axis_arr = np.array([new_axis], dtype='int32')
                axis_tensor = self.create_attr_tensor(axis_arr)
                actions.append((self.graph.replace_operator_input, (node, 1, axis_tensor, True)))
            elif dim_indice is not None:
                raise NotImplementedError(f'{node["node_type"]} has the property `dims` but is not handled')

            # Reconnect the remaining consumers to the 'down' reshapes that now produce the
            # original output tensors; the old edges were queued in `remove_edges`.
            for edge in next_edges:
                source = tensor_node_dict[edge['name']]
                self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])

        # Process actions; non-None results are collected as edge ids to delete as well.
        ids = []
        for func, args in actions:
            res = func(*args)
            if res is not None:
                ids.extend(res)

        self.graph.graph.delete_edges(list(remove_edges.union(ids)))
        self.graph.graph.delete_vertices(remove_vertices)

        return num_actions