in tinynn/converter/operators/optimize.py [0:0]
def elementwise_op_reshape_passthrough_pass(self) -> int:
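# Pushes Reshape ops across elementwise-like ops: when enough of an op's
# neighbours are simple reshapes, the reshapes are moved to the other side of
# the op (and dimension-dependent attributes are remapped) so that adjacent
# reshapes can later be fused or eliminated. Returns the number of rewrites.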
edges = self.graph.graph.es.select(
functools.partial(is_reshape_elementwise_op_edge, graph_converter=self.graph.graph)
)
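# Each selected edge connects a RESHAPE node with an elementwise-like op;
# keep the elementwise endpoint of every pair and deduplicate the candidates.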
pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.RESHAPE else k[1] for k in pairs)
unique_nodes = list(set(filtered_nodes))
actions = []
remove_edges = []
remove_vertices = []
num_actions = 0
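# Inspect every candidate op and decide whether moving the surrounding
# reshapes across it pays off; graph edits are recorded in `actions` and
# applied after the loop.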
for node in unique_nodes:
op = node['op']
dim_indice = op_input_dims(op)
input_indices = op_input_indices(op)
prev_nodes = []
cand_shapes = dict()
cand_next_shapes = dict()
prev_output_indices = []
num_constant_nodes = 0
prev_hints = set()
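# Input side: find the producer of every input, count constant producers and,
# for producers that are simple reshapes, vote on the shapes seen before and
# after the reshape (the op dimension, if any, is marked with -1).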
for i in input_indices:
prev_node_name = op.inputs[i].name
prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
prev_nodes.append(prev_node)
prev_output_indices.append(prev_node['outputs'].index(prev_node_name))
if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
num_constant_nodes += 1
if prev_node['node_type'] == ExtendedOperator.RESHAPE:
mapping = dict()
if not is_simple_reshape(
prev_node['op'].inputs[0].shape, prev_node['op'].outputs[0].shape, mapping
):
continue
new_dim = None
if dim_indice is not None:
rev_mapping = {v: k for k, v in mapping.items()}
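# PACK inserts a new dimension on its output, so the packed axis may not map
# directly onto the reshape's input; fall back to the neighbouring axes, or
# give up on tracking the axis (-1) when no mapping is found.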
if node['node_type'] == ExtendedOperator.PACK:
if dim_indice in rev_mapping:
tmp_new_dim = rev_mapping[dim_indice]
else:
if dim_indice - 1 in rev_mapping:
tmp_new_dim = rev_mapping[dim_indice - 1] + 1
elif dim_indice + 1 in rev_mapping:
tmp_new_dim = rev_mapping[dim_indice + 1] - 1
else:
# TODO: Figure out the rev index
tmp_new_dim = -1
tmp_dim_indice = dim_indice
new_dim = -1
dim_indice = -1
else:
if dim_indice not in rev_mapping:
continue
new_dim = rev_mapping[dim_indice]
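# Candidate shapes: the reshape's input shape is the shape the op would work
# on after the rewrite, its output shape is the one restored afterwards; the
# op dimension is marked with -1 in both, and the most frequent pair wins.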
shape = tuple(prev_node['op'].inputs[0].shape)
shape = tuple(x if i != new_dim else -1 for i, x in enumerate(shape))
if node['node_type'] == ExtendedOperator.PACK and tmp_new_dim >= 0:
shape = list(shape)
shape.insert(tmp_new_dim, -1)
shape = tuple(shape)
cand_shapes.setdefault(shape, 0)
cand_shapes[shape] += 1
next_shape = tuple(prev_node['op'].outputs[0].shape)
next_shape = tuple(x if i != dim_indice else -1 for i, x in enumerate(next_shape))
if node['node_type'] == ExtendedOperator.PACK:
next_shape = list(next_shape)
next_shape.insert(tmp_dim_indice, -1)
next_shape = tuple(next_shape)
cand_next_shapes.setdefault(next_shape, 0)
cand_next_shapes[next_shape] += 1
if node['node_type'] == ExtendedOperator.PACK:
dim_indice = tmp_dim_indice
if 'direction' in prev_node['op'].extra_hints:
prev_hints.add(prev_node['op'].extra_hints['direction'])
if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints:
continue
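# Output side: classify consumers (graph outputs, unused tensors, real ops)
# and, for consumers that are simple reshapes, vote on shapes the same way.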
next_nodes = []
next_edges = []
out_nodes = []
skip_names = []
next_hints = set()
for edge in node.out_edges():
if edge.index in remove_edges:
continue
next_node = self.graph.graph.vs[edge.target]
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
out_nodes.append(next_node)
elif next_node['node_type'] == ExtendedOperator.UNUSED_NODE:
skip_names.append(edge['label'])
else:
next_nodes.append(next_node)
next_edges.append(edge)
if next_node['node_type'] == ExtendedOperator.RESHAPE:
mapping = dict()
if not is_simple_reshape(
next_node['op'].inputs[0].shape, next_node['op'].outputs[0].shape, mapping
):
continue
new_dim = None
if dim_indice is not None:
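# UNPACK removes the packed dimension from its outputs; mirror the PACK
# handling above when remapping the axis through the downstream reshape.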
if node['node_type'] == ExtendedOperator.UNPACK:
if dim_indice in mapping:
tmp_new_dim = mapping[dim_indice]
else:
if dim_indice - 1 in mapping:
tmp_new_dim = mapping[dim_indice - 1] + 1
elif dim_indice + 1 in mapping:
tmp_new_dim = mapping[dim_indice + 1] - 1
else:
# TODO: Figure out the rev index
tmp_new_dim = -1
tmp_dim_indice = dim_indice
new_dim = -1
dim_indice = -1
else:
if dim_indice not in mapping:
continue
new_dim = mapping[dim_indice]
shape = tuple(next_node['op'].outputs[0].shape)
shape = tuple(x if i != new_dim else -1 for i, x in enumerate(shape))
if node['node_type'] == ExtendedOperator.UNPACK and tmp_new_dim >= 0:
shape = list(shape)
shape.insert(tmp_new_dim, -1)
shape = tuple(shape)
cand_shapes.setdefault(shape, 0)
cand_shapes[shape] += 1
next_shape = tuple(next_node['op'].inputs[0].shape)
next_shape = tuple(x if i != dim_indice else -1 for i, x in enumerate(next_shape))
if node['node_type'] == ExtendedOperator.UNPACK:
next_shape = list(next_shape)
next_shape.insert(tmp_dim_indice, -1)
next_shape = tuple(next_shape)
cand_next_shapes.setdefault(next_shape, 0)
cand_next_shapes[next_shape] += 1
if node['node_type'] == ExtendedOperator.UNPACK:
dim_indice = tmp_dim_indice
if 'direction' in next_node['op'].extra_hints:
next_hints.add(next_node['op'].extra_hints['direction'])
if len(cand_shapes) == 0:
continue
if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints:
continue
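# Cost check: the reshapes that would have to be inserted on the remaining
# branches (new_reshape_size) must not outnumber the reshapes being absorbed
# (cur_reshape_size); constant inputs are not counted against the rewrite.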
cur_reshape_size = max(cand_shapes.values())
cur_next_reshape_size = max(cand_next_shapes.values())
full_size = len(prev_nodes) + len(next_nodes)
if cur_reshape_size != cur_next_reshape_size:
continue
new_reshape_size = full_size - cur_reshape_size - num_constant_nodes
# Skip when there are no real downstream consumers or when the rewrite
# would introduce more reshapes than it removes
if (
len(next_nodes) == 0 or new_reshape_size > cur_reshape_size
):
continue
elif new_reshape_size == cur_reshape_size:
skip = True
if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED:
if 'down' in prev_hints or 'up' in next_hints:
skip = False
if skip:
continue
num_actions += 1
remove_edges.extend([x.index for x in next_edges])
remove_vertices.extend([x.index for x in out_nodes])
for n in out_nodes:
del self.graph.tensor_map[n['outputs'][0]]
del self.graph.tensor_node_map[n['outputs'][0]]
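# The most voted shapes win: prev_shape is what the op will operate on,
# next_shape is what gets restored for downstream consumers.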
prev_shape = max(cand_shapes.items(), key=lambda x: x[1])[0]
next_shape = max(cand_next_shapes.items(), key=lambda x: x[1])[0]
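# Input-side rewiring: feed each input through a new 'up' Reshape to
# prev_shape, reusing the reshaped tensor when the same input feeds the op
# more than once; broadcastable inputs get their target shape aligned first.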
tensor_node_dict = {}
for prev_node, prev_idx, next_idx in zip(prev_nodes, input_indices, prev_output_indices):
if prev_node['op'] is None:
prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
else:
prev_out = prev_node['op'].outputs[next_idx]
if prev_out.name in tensor_node_dict:
prev_new_out, skip = tensor_node_dict[prev_out.name]
actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True, skip)))
skip += 1
tensor_node_dict[prev_out.name] = (prev_new_out, skip)
else:
if node['node_type'] == ExtendedOperator.PACK:
tmp_prev_shape = prev_shape
prev_shape = [i for i in prev_shape if i != -1]
prev_shape_aligned = prev_shape
if np.prod(prev_out.shape) != np.prod(prev_shape):
new_prev_shape = prev_out.shape
if len(prev_out.shape) < len(next_shape):
new_prev_shape = [1] * (len(next_shape) - len(prev_out.shape)) + list(prev_out.shape)
mapping = {}
is_simple_reshape(prev_shape, next_shape, mapping)
prev_shape_aligned = np.ones(len(prev_shape), dtype='int32')
for pi, ni in mapping.items():
prev_shape_aligned[pi] = new_prev_shape[ni]
prev_new_out = self.create_transform_tensor(
np.reshape(prev_out.tensor, prev_shape_aligned), quantization=prev_out.quantization
)
tensor_node_dict[prev_out.name] = (prev_new_out, 1)
shape_tensor = self.create_attr_tensor(np.array(prev_new_out.shape, dtype='int32'))
reshape_op = tfl.ReshapeOperator(
[prev_out, shape_tensor], [prev_new_out], newShape=shape_tensor.tensor
)
reshape_op.extra_hints['direction'] = 'up'
self.graph.add_operator(reshape_op)
actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))
if node['node_type'] == ExtendedOperator.PACK:
prev_shape = tmp_prev_shape
tensor_node_dict = {}
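# Output-side rewiring: the op now produces tensors of the new shape; a
# trailing 'down' Reshape restores the original shape for existing consumers,
# while unused outputs only get their shape metadata updated in place.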
for i, op_out in enumerate(op.outputs):
if node['node_type'] == ExtendedOperator.UNPACK:
tmp_prev_shape = prev_shape
prev_shape = [i for i in prev_shape if i != -1]
# For unused tensors, we perform inplace shape updates
if op_out.name in skip_names:
new_shape = np.reshape(op_out.tensor, prev_shape).shape
op_out.shape = tuple(new_shape)
if node['node_type'] == ExtendedOperator.UNPACK:
prev_shape = tmp_prev_shape
continue
new_out = self.create_transform_tensor(
np.reshape(op_out.tensor, prev_shape), quantization=op_out.quantization
)
shape_tensor = self.create_attr_tensor(np.array(op_out.shape, dtype='int32'))
# Update the tensor/node mappings so the new intermediate tensor replaces
# the original output of the op
if op_out.name in self.graph.tensor_node_map:
del self.graph.tensor_node_map[op_out.name]
self.graph.tensor_node_map[new_out.name] = node['name']
self.graph.tensor_map[new_out.name] = new_out
node['outputs'][i] = new_out.name
op.outputs[i] = new_out
reshape_op = tfl.ReshapeOperator([new_out, shape_tensor], [op_out], shape_tensor.tensor)
reshape_op.extra_hints['direction'] = 'down'
self.graph.add_operator(reshape_op)
tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])
if node['node_type'] == ExtendedOperator.UNPACK:
prev_shape = tmp_prev_shape
# Op-specific handling: remap dimension-dependent attributes
# (axis, paddings, slice bounds, etc.) onto the new shape
if node['node_type'] in (
ExtendedOperator.CONCATENATION,
ExtendedOperator.GATHER,
ExtendedOperator.UNPACK,
ExtendedOperator.PACK,
):
new_axis = prev_shape.index(-1)
op.axis = new_axis
elif node['node_type'] == ExtendedOperator.SPLIT_V:
new_dim = prev_shape.index(-1)
new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
actions.append((self.graph.replace_operator_input, (node, 2, new_dim_tensor, True)))
elif node['node_type'] == ExtendedOperator.SPLIT:
new_dim = prev_shape.index(-1)
new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
actions.append((self.graph.replace_operator_input, (node, 0, new_dim_tensor, True)))
elif node['node_type'] in (ExtendedOperator.PAD, ExtendedOperator.PADV2, ExtendedOperator.MIRROR_PAD):
old_pad = op.inputs[1].tensor
new_dim = prev_shape.index(-1)
old_dim = next_shape.index(-1)
new_pad = np.zeros((len(prev_shape), 2), dtype='int32')
new_pad[new_dim, :] = old_pad[old_dim, :]
new_pad_tensor = self.create_attr_tensor(new_pad)
actions.append((self.graph.replace_operator_input, (node, 1, new_pad_tensor, True)))
elif node['node_type'] == ExtendedOperator.PRELU:
old_weight = op.inputs[1].tensor
if old_weight.ndim != 1:
new_dim = prev_shape.index(-1)
old_dim = next_shape.index(-1)
new_shape = np.ones(len(prev_shape) - 1, dtype='int32')
new_shape[new_dim - 1] = old_weight.shape[old_dim - 1]
new_shape_t = self.create_attr_tensor(new_shape)
new_weight = self.create_transform_tensor(np.reshape(old_weight, new_shape))
self.graph.add_operator(tfl.ReshapeOperator([op.inputs[1], new_shape_t], [new_weight], new_shape))
actions.append((self.graph.replace_operator_input, (node, 1, new_weight, True)))
elif node['node_type'] == ExtendedOperator.SLICE:
new_dim = prev_shape.index(-1)
old_dim = next_shape.index(-1)
new_start = np.zeros(len(prev_shape), dtype='int32')
new_start[new_dim] = op.inputs[1].tensor[old_dim]
new_start_t = self.create_attr_tensor(new_start)
new_size = np.array(prev_shape, dtype='int32')
new_size[new_dim] = op.inputs[2].tensor[old_dim]
new_size_t = self.create_attr_tensor(new_size)
actions.append((self.graph.replace_operator_input, (node, 1, new_start_t, True)))
actions.append((self.graph.replace_operator_input, (node, 2, new_size_t, True)))
elif node['node_type'] == ExtendedOperator.STRIDED_SLICE:
new_dim = prev_shape.index(-1)
old_dim = next_shape.index(-1)
new_start = np.zeros(len(prev_shape), dtype='int32')
new_start[new_dim] = op.inputs[1].tensor[old_dim]
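# The starts tensor is dynamic (no static buffer): slice the entry for the
# old dimension out of the original tensor and rebuild the new starts vector
# by concatenating the constant parts around it.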
if op.inputs[1].buffer is None:
new_start_t = self.create_transform_tensor(new_start)
starts_mid = new_start[new_dim : new_dim + 1]
starts_mid_tensor = self.create_transform_tensor(starts_mid)
slice_inputs = [
op.inputs[1],
self.create_attr_tensor(np.array([old_dim], dtype='int32')),
self.create_attr_tensor(np.array([1], dtype='int32')),
]
self.graph.add_operator(tfl.SliceOperator(slice_inputs, [starts_mid_tensor]))
starts_left = new_start[:new_dim]
starts_right = new_start[new_dim + 1 :]
starts_tensors = []
if len(starts_left) > 0:
starts_tensors.append(self.create_attr_tensor(starts_left))
starts_tensors.append(starts_mid_tensor)
if len(starts_right) > 0:
starts_tensors.append(self.create_attr_tensor(starts_right))
if len(starts_tensors) > 1:
self.graph.add_operator(tfl.ConcatenationOperator(starts_tensors, [new_start_t], 0))
else:
new_start_t = starts_tensors[0]
else:
new_start_t = self.create_attr_tensor(new_start)
new_end = np.array(prev_shape, dtype='int32')
new_end[new_dim] = op.inputs[2].tensor[old_dim]
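# Same treatment for the ends tensor when it is dynamic.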
if op.inputs[2].buffer is None:
new_end_t = self.create_transform_tensor(new_end)
ends_mid = new_end[new_dim : new_dim + 1]
ends_mid_tensor = self.create_transform_tensor(ends_mid)
slice_inputs = [
op.inputs[2],
self.create_attr_tensor(np.array([old_dim], dtype='int32')),
self.create_attr_tensor(np.array([1], dtype='int32')),
]
self.graph.add_operator(tfl.SliceOperator(slice_inputs, [ends_mid_tensor]))
ends_left = new_end[:new_dim]
ends_right = new_end[new_dim + 1 :]
ends_tensors = []
if len(ends_left) > 0:
ends_tensors.append(self.create_attr_tensor(ends_left))
ends_tensors.append(ends_mid_tensor)
if len(ends_right) > 0:
ends_tensors.append(self.create_attr_tensor(ends_right))
if len(ends_tensors) > 1:
self.graph.add_operator(tfl.ConcatenationOperator(ends_tensors, [new_end_t], 0))
else:
new_end_t = ends_tensors[0]
else:
new_end_t = self.create_attr_tensor(new_end)
new_stride = np.ones(len(prev_shape), dtype='int32')
new_stride[new_dim] = op.inputs[3].tensor[old_dim]
new_stride_t = self.create_attr_tensor(new_stride)
actions.append((self.graph.replace_operator_input, (node, 1, new_start_t, True)))
actions.append((self.graph.replace_operator_input, (node, 2, new_end_t, True)))
actions.append((self.graph.replace_operator_input, (node, 3, new_stride_t, True)))
elif node['node_type'] == ExtendedOperator.TILE:
old_shape = op.inputs[1].tensor
new_dim = prev_shape.index(-1)
old_dim = next_shape.index(-1)
new_shape = np.ones(len(prev_shape), dtype='int32')
new_shape[new_dim] = old_shape[old_dim]
new_shape_tensor = self.create_attr_tensor(new_shape)
actions.append((self.graph.replace_operator_input, (node, 1, new_shape_tensor, True)))
elif node['node_type'] in (
ExtendedOperator.SUM,
ExtendedOperator.ARG_MIN,
ExtendedOperator.ARG_MAX,
ExtendedOperator.REDUCE_MIN,
ExtendedOperator.REDUCE_MAX,
ExtendedOperator.REDUCE_PROD,
ExtendedOperator.MEAN,
):
new_axis = prev_shape.index(-1)
axis_arr = np.array([new_axis], dtype='int32')
axis_tensor = self.create_attr_tensor(axis_arr)
actions.append((self.graph.replace_operator_input, (node, 1, axis_tensor, True)))
elif dim_indice is not None:
raise NotImplementedError(f'{node["node_type"]} has the property `dims` but is not handled')
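# Reconnect the downstream consumers to the nodes that now produce their
# tensors (the newly added 'down' Reshape ops).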
for edge in next_edges:
source = tensor_node_dict[edge['name']]
self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])
# Process actions
ids = []
for func, args in actions:
node = args[0]
res = func(*args)
if res is not None:
ids.extend(res)
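# Finally, drop the edges and vertices that were detached by the rewrites.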
remove_edges = list(set(remove_edges + ids))
self.graph.graph.delete_edges(remove_edges)
self.graph.graph.delete_vertices(remove_vertices)
return num_actions