in tinynn/converter/operators/optimize.py
def elementwise_reshape_transpose_passthrough_pass(self) -> int:
    """Moves TRANSPOSE ops across neighboring RESHAPE ops when doing so does
    not increase the number of transposes. Returns the number of rewrites."""
    edges = self.graph.graph.es.select(
        functools.partial(is_transpose_reshape_op_edge, graph_converter=self.graph.graph)
    )
    pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
    # For every transpose-reshape edge, keep the non-transpose endpoint (the reshape)
    filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.TRANSPOSE else k[1] for k in pairs)
    unique_nodes = list(set(filtered_nodes))
    actions = []
    remove_edges = []
    remove_vertices = []
    processed_nodes = set()
    num_actions = 0
    for node in unique_nodes:
        pending_processed_nodes = set()
        op = node['op']

        input_indices = op_input_indices(op)

        # Shapes on the input (left) and output (right) side of the reshape
        l_shape = op.inputs[0].shape
        r_shape = op.outputs[0].shape

        if len(l_shape) == 0 or len(r_shape) == 0:
            continue

        l_map, r_map, _, _ = reshape_mapping(l_shape, r_shape)
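        # Assumed semantics of `reshape_mapping`: it pairs up input and output
        # dims that cover the same elements. E.g. for l_shape=(2, 3, 8) and
        # r_shape=(2, 3, 2, 4), one would expect
        #     l_map = [[0], [1], [2]], r_map = [[0], [1], [2, 3]]
        # i.e. input dim 2 is split into output dims 2 and 3.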
        mode = None
        need_chain = False
        for l_val, r_val in zip(l_map, r_map):
            if len(l_val) > 1 and len(r_val) == 1:
                # Several input dims collapse into one output dim
                if mode in (None, 'up'):
                    mode = 'up'
                else:
                    mode = '?'
                    break
            elif len(r_val) > 1 and len(l_val) == 1:
                # One input dim is split into several output dims
                if mode in (None, 'down'):
                    mode = 'down'
                else:
                    mode = '?'
                    break
            elif len(r_val) > 1 and len(l_val) > 1:
                if r_val != l_val:
                    # TODO: Support this case
                    mode = '?'
                    break
                else:
                    need_chain = True

        if mode is None:
            mode = 'down'
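        # In 'down' mode, transposes on the input side are pushed below the
        # reshape; in 'up' mode, transposes on the output side are pulled above
        # it. Mixed split-and-collapse mappings ('?') are not supported.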
        # TODO: Support multi-multi mappings
        if mode == '?':
            # Passthrough is not possible; reset direction hints on the neighbors
            for i in input_indices:
                prev_node_name = op.inputs[i].name
                prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
                if prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
                    if 'direction' in prev_node['op'].extra_hints:
                        prev_node['op'].extra_hints.pop('direction')
            for edge in node.out_edges():
                if edge.index in remove_edges:
                    continue
                next_node = self.graph.graph.vs[edge.target]
                if 'direction' in next_node['op'].extra_hints:
                    next_node['op'].extra_hints.pop('direction')
            continue
        check_consecutive_indices = []
        if need_chain:
            # Split identical multi-to-multi entries into singleton pairs, and
            # record them so that the elected perm can be checked to keep these
            # dims consecutive and in order
            new_l_map = []
            new_r_map = []
            for l_val, r_val in zip(l_map, r_map):
                if len(l_val) > 1 and len(r_val) > 1:
                    if mode == 'down':
                        check_consecutive_indices.append(l_val)
                    else:
                        check_consecutive_indices.append(r_val)
                    for l_item in l_val:
                        new_l_map.append([l_item])
                    for r_item in r_val:
                        new_r_map.append([r_item])
                else:
                    new_l_map.append(l_val)
                    new_r_map.append(r_val)
            l_map = new_l_map
            r_map = new_r_map
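        # Example (assumed): an entry l_val = r_val = [1, 2] becomes the
        # singleton pairs [1] -> [1] and [2] -> [2], while [1, 2] is kept in
        # check_consecutive_indices for the consecutiveness check below.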
        prev_nodes = []
        cand_perms = dict()
        cand_rev_perms = dict()
        prev_output_indices = []
        num_constant_nodes = 0
        prev_hints = set()
        skip = False
        for i in input_indices:
            prev_node_name = op.inputs[i].name
            prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
            prev_nodes.append(prev_node)
            prev_output_indices.append(prev_node['outputs'].index(prev_node_name))
            if prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
                if prev_node['name'] in processed_nodes:
                    skip = True
                    break
                pending_processed_nodes.add(prev_node['name'])
                # Collect candidate perms from the input side: forward perms in
                # 'down' mode, inverted perms otherwise
                if mode == 'down':
                    perm = tuple(prev_node['op'].inputs[1].tensor.tolist())
                    cand_perms.setdefault(perm, 0)
                    cand_perms[perm] += 1
                elif mode == 'up':
                    perm = tuple(np.argsort(prev_node['op'].inputs[1].tensor).tolist())
                    cand_rev_perms.setdefault(perm, 0)
                    cand_rev_perms[perm] += 1
                if 'direction' in prev_node['op'].extra_hints:
                    prev_hints.add(prev_node['op'].extra_hints['direction'])
            if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
                num_constant_nodes += 1
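        # Example (assumed): in 'down' mode with two input TRANSPOSEs that both
        # carry perm (0, 2, 3, 1), this yields cand_perms = {(0, 2, 3, 1): 2}.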
        if skip or (self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints):
            continue
        next_nodes = []
        next_edges = []
        out_nodes = []
        next_hints = set()
        for edge in node.out_edges():
            if edge.index in remove_edges:
                continue
            next_node = self.graph.graph.vs[edge.target]
            if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                out_nodes.append(next_node)
            else:
                if next_node['name'] in processed_nodes:
                    skip = True
                    break
                pending_processed_nodes.add(next_node['name'])
                next_nodes.append(next_node)
                next_edges.append(edge)
                if next_node['node_type'] == ExtendedOperator.TRANSPOSE:
                    # Output-side mirror of the collection above, with the
                    # forward/inverted roles swapped
                    if mode == 'down':
                        perm = tuple(np.argsort(next_node['op'].inputs[1].tensor).tolist())
                        cand_rev_perms.setdefault(perm, 0)
                        cand_rev_perms[perm] += 1
                    elif mode == 'up':
                        perm = tuple(next_node['op'].inputs[1].tensor.tolist())
                        cand_perms.setdefault(perm, 0)
                        cand_perms[perm] += 1
                    if 'direction' in next_node['op'].extra_hints:
                        next_hints.add(next_node['op'].extra_hints['direction'])
        if skip or (self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints):
            continue
        cur_transpose_size = sum(cand_perms.values()) + sum(cand_rev_perms.values())
        new_transpose_size = len(prev_nodes) + len(next_nodes) - sum(cand_perms.values()) - num_constant_nodes

        # Skip if the number of transpose nodes would not decrease
        if len(cand_perms) == 0 or len(next_nodes) == 0 or new_transpose_size > cur_transpose_size:
            continue
        elif new_transpose_size == cur_transpose_size:
            skip = True
            if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED:
                # Breaking even is still acceptable when neighbor hints suggest
                # further passthrough opportunities
                if 'down' in prev_hints or 'up' in next_hints:
                    skip = False
            if skip:
                continue
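        # Worked count (assumed): two inputs, both TRANSPOSE with the elected
        # perm, and one non-output consumer give cur = 2 + 0 = 2 and
        # new = 2 + 1 - 2 - 0 = 1, so the rewrite proceeds.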
        # Elect the most frequent candidate perm
        perm = max(cand_perms.items(), key=lambda x: x[1])[0]
        perm_arr = np.array(perm, dtype='int32')

        # Dims that were merged in the mapping must remain consecutive and
        # in order under the elected perm
        skip = False
        for check_idx in check_consecutive_indices:
            if mode == 'down':
                target_idx = perm_arr[check_idx]
            elif mode == 'up':
                # Look up the positions of the values `check_idx` in `perm_arr`,
                # i.e. apply the inverse permutation to `check_idx`
                perm_sorter = perm_arr.argsort()
                target_idx = perm_sorter[np.searchsorted(perm_arr, check_idx, sorter=perm_sorter)]
            normalized_src = [x - check_idx[0] for x in check_idx]
            normalized_tgt = [x - target_idx[0] for x in target_idx]
            if normalized_src != normalized_tgt:
                skip = True
                break

        if skip:
            continue
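        # Example for the check above (assumed): perm_arr = [0, 2, 3, 1] and
        # check_idx = [2, 3] give perm_sorter = [0, 3, 1, 2] and
        # target_idx = [1, 2]; both runs normalize to [0, 1], so the merged
        # dims survive the perm.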
        num_actions += 1

        remove_edges.extend([x.index for x in next_edges])
        remove_vertices.extend([x.index for x in out_nodes])

        processed_nodes.update(pending_processed_nodes)

        for n in out_nodes:
            del self.graph.tensor_map[n['outputs'][0]]
            del self.graph.tensor_node_map[n['outputs'][0]]
        # Remap the elected perm across the reshape: translate each dim through
        # l_map/r_map to obtain the perm (and its inverse) on the other side
        if mode == 'down':
            inv_perm_arr = np.argsort(perm_arr).astype('int32')
            l_dict = dict(zip([x[0] for x in l_map], r_map))
            indices = map(lambda x: l_dict[x], inv_perm_arr.tolist())
            inv_post_perm = list(itertools.chain.from_iterable(indices))
            inv_post_perm_arr = np.array(inv_post_perm, dtype='int32')
            post_perm_arr = np.argsort(inv_post_perm_arr).astype('int32')
        elif mode == 'up':
            r_dict = dict(zip([x[0] for x in r_map], l_map))
            indices = map(lambda x: r_dict[x], perm)
            inv_perm = list(itertools.chain.from_iterable(indices))
            inv_perm_arr = np.array(inv_perm, dtype='int32')
            post_perm_arr = np.argsort(perm_arr).astype('int32')
            inv_post_perm_arr = np.argsort(post_perm_arr).astype('int32')
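        # Worked example (assumed, mode 'down'): l_shape=(2, 3, 8) ->
        # r_shape=(2, 3, 2, 4) with input perm (2, 0, 1) gives
        # inv_perm_arr = [1, 2, 0]; expanding through l_dict = {0: [0], 1: [1],
        # 2: [2, 3]} yields inv_post_perm_arr = [1, 2, 3, 0] and
        # post_perm_arr = [3, 0, 1, 2], i.e. transpose(2, 0, 1) + reshape(2, 3, 2, 4)
        # is traded for reshape(3, 2, 4, 2) + transpose(3, 0, 1, 2).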
        # Insert an inverse transpose in front of every input; inputs that
        # already carry the elected transpose now compose to the identity
        for prev_node, prev_idx, next_idx in zip(prev_nodes, input_indices, prev_output_indices):
            if prev_node['op'] is None:
                prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
            else:
                prev_out = prev_node['op'].outputs[next_idx]
            perm_tensor = self.create_attr_tensor(inv_perm_arr)
            prev_new_out = self.create_transform_tensor(
                np.transpose(prev_out.tensor, inv_perm_arr), quantization=prev_out.quantization
            )
            transpose_op = tfl.TransposeOperator([prev_out, perm_tensor], [prev_new_out])
            transpose_op.extra_hints['direction'] = 'up'
            self.graph.add_operator(transpose_op)
            actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))
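        # Net effect per input path: TRANSPOSE(perm) followed by
        # TRANSPOSE(inv_perm) is the identity, which a separate
        # transpose-folding pass is presumably expected to eliminate.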
        tensor_node_dict = {}
        for i, op_out in enumerate(op.outputs):
            # The op now produces the transposed tensor; a new TRANSPOSE
            # restores the original output for downstream consumers
            perm_tensor = self.create_attr_tensor(post_perm_arr)
            new_out = self.create_transform_tensor(
                np.transpose(op_out.tensor, inv_post_perm_arr), quantization=op_out.quantization
            )

            # Update relations
            if op_out.name in self.graph.tensor_node_map:
                del self.graph.tensor_node_map[op_out.name]
            self.graph.tensor_node_map[new_out.name] = node['name']
            self.graph.tensor_map[new_out.name] = new_out
            node['outputs'][i] = new_out.name
            op.outputs[i] = new_out

            transpose_op = tfl.TransposeOperator([new_out, perm_tensor], [op_out])
            transpose_op.extra_hints['direction'] = 'down'
            self.graph.add_operator(transpose_op)

            tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])
        # OP specific dim handling logic: permute the reshape's target shape
        old_shape = op.inputs[1].tensor
        new_shape = self.create_attr_tensor(old_shape[inv_post_perm_arr])
        actions.append((self.graph.replace_operator_input, (node, 1, new_shape, True)))
        op.newShape = new_shape.tensor

        # Reattach downstream consumers to the new TRANSPOSE nodes
        for edge in next_edges:
            source = tensor_node_dict[edge['name']]
            self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])
    # Process the deferred actions
    ids = []
    for func, args in actions:
        res = func(*args)
        if res is not None:
            ids.extend(res)

    remove_edges = list(set(remove_edges + ids))

    self.graph.graph.delete_edges(remove_edges)
    self.graph.graph.delete_vertices(remove_vertices)

    return num_actions
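
# Illustrative standalone sketch (not part of this file) of the identity the
# rewrite relies on, matching the worked example above; shapes and perms are
# assumed values chosen for demonstration:
#
#     import numpy as np
#     x = np.arange(3 * 8 * 2).reshape(3, 8, 2)
#     a = x.transpose(2, 0, 1).reshape(2, 3, 2, 4)     # transpose, then reshape
#     b = x.reshape(3, 2, 4, 2).transpose(3, 0, 1, 2)  # reshape, then transpose
#     assert np.array_equal(a, b)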