in tinynn/converter/operators/optimize.py [0:0]
def output_transpose_pass(self):
    """Insert Transpose operators so that selected graph outputs are emitted permuted by ``[0, 2, 3, 1]``.

    ``self.graph.output_transpose`` is either a single value applied to every
    output, or a list/tuple holding one value per entry of ``self.graph.outputs``.
    For each output the value is resolved as follows:

    * the output name is missing from ``self.graph.tensor_map``: never transposed;
    * the value is ``None``: transposed iff the tensor's shape has 4 dimensions;
    * otherwise the value is used as given.

    The resolved values are written back to ``self.graph.output_transpose``
    (a tuple is converted to a list first, a list is updated in place).

    Outputs to be transposed are handled in one of two ways, depending on the
    node that produces them:

    * ``DEQUANTIZE`` producer: the transpose is inserted on the node's first
      input, the node's output tensor is permuted in place, and a reverse
      transpose is added for consumers other than output/unused nodes.
    * any other producer: the operator gets a fresh output tensor and a
      transpose from that tensor to the (permuted) original output tensor is
      added; existing consumers are re-wired to the fresh tensor.

    Raises:
        AssertionError: if ``output_transpose`` is a list/tuple whose length
            differs from the number of outputs, or if the same output name is
            listed twice with different ``output_transpose`` values.
    """
    nhwc2nchw_perm = np.array([0, 3, 1, 2], dtype='int32')
    nchw2nhwc_perm = np.array([0, 2, 3, 1], dtype='int32')

    # Normalise `output_transpose` to a mutable list with one entry per output.
    # Explicit raises (instead of `assert`) keep the checks alive under `python -O`.
    if isinstance(self.graph.output_transpose, (list, tuple)):
        if len(self.graph.output_transpose) != len(self.graph.outputs):
            raise AssertionError(
                f"`output_transpose` has {len(self.graph.output_transpose)} entries, "
                f"but the graph has {len(self.graph.outputs)} outputs"
            )
        if isinstance(self.graph.output_transpose, tuple):
            # Entries are assigned by index below, which a tuple does not support.
            self.graph.output_transpose = list(self.graph.output_transpose)
    else:
        self.graph.output_transpose = [self.graph.output_transpose] * len(self.graph.outputs)

    # Collapse duplicated output names; duplicates must agree on the setting.
    filtered_dict = {}
    for i, (name, transpose) in enumerate(zip(self.graph.outputs, self.graph.output_transpose)):
        if name in filtered_dict:
            old_transpose = filtered_dict[name]
            if transpose != old_transpose:
                raise AssertionError(
                    f"outputs {i} points to an exising tensor {name}, but their property `output_transpose` is different"
                )
        else:
            filtered_dict[name] = transpose

    # node -> set of input indices that get a transpose in front (DEQUANTIZE producers)
    prev_modify_node_indices = {}
    # node -> set of output indices whose tensors are permuted (DEQUANTIZE producers)
    prev_modify_next_indices = {}
    # node -> set of output indices that get a transpose appended (all other producers)
    next_modify_node_indices = {}
    for name, transpose in filtered_dict.items():
        if name in self.graph.tensor_map:
            tensor = self.graph.tensor_map[name]
            if transpose is None:
                transpose = len(tensor.shape) == 4
        else:
            transpose = False

        # Write the resolved flag back for every output slot that uses this name.
        for i, n in enumerate(self.graph.outputs):
            if name == n:
                self.graph.output_transpose[i] = transpose

        if not transpose:
            continue

        node_name = self.graph.tensor_node_map[name]
        node = self.graph.graph.vs.find(name=node_name)
        tensor_idx = node['outputs'].index(name)

        prev_node = None
        if node['node_type'] == ExtendedOperator.DEQUANTIZE:
            prev_node_name = self.graph.tensor_node_map[node['op'].inputs[0].name]
            prev_node = self.graph.graph.vs.find(name=prev_node_name)

        if prev_node is None:
            next_modify_node_indices.setdefault(node, set()).add(tensor_idx)
        else:
            prev_modify_node_indices.setdefault(node, set()).add(0)
            prev_modify_next_indices.setdefault(node, set()).add(tensor_idx)

    remove_edges = []
    remove_vertices = []
    # Deferred `(callable, args)` pairs, executed after both loops below.
    actions = []

    # Case 1: transpose in front of the producing node.
    for node, index in prev_modify_node_indices.items():
        next_indices = prev_modify_next_indices[node]
        op = node['op']
        tensor_names = [node['outputs'][i] for i in index]

        # Consumers of the selected tensors, excluding output/unused marker nodes.
        next_nodes = {}
        for edge in node.out_edges():
            if edge['label'] not in tensor_names:
                continue
            if edge.index in remove_edges:
                continue
            tensor_idx = tensor_names.index(edge['label'])
            next_node = self.graph.graph.vs[edge.target]
            if next_node['node_type'] not in (ExtendedOperator.OUTPUT_NODE, ExtendedOperator.UNUSED_NODE):
                next_nodes.setdefault(tensor_idx, []).append(next_node)

        # Producers of the selected inputs and the output slot each input comes from.
        prev_nodes = []
        prev_output_indices = []
        for i in index:
            prev_node_name = op.inputs[i].name
            prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
            prev_nodes.append(prev_node)
            prev_output_indices.append(prev_node['outputs'].index(prev_node_name))

        # input tensor name -> (transposed tensor, number of times it has been used so far)
        tensor_node_dict = {}
        for prev_node, prev_idx, next_idx in zip(prev_nodes, index, prev_output_indices):
            if prev_node['op'] is None:
                prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
            else:
                prev_out = prev_node['op'].outputs[next_idx]

            if prev_out.name in tensor_node_dict:
                # Reuse the transpose already created for this tensor.
                # NOTE(review): the meaning of the trailing `True, skip` arguments is defined by
                # `replace_operator_input`, which is not visible here.
                prev_new_out, skip = tensor_node_dict[prev_out.name]
                actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True, skip)))
                skip += 1
                tensor_node_dict[prev_out.name] = (prev_new_out, skip)
            else:
                perm_tensor = self.create_attr_tensor(nchw2nhwc_perm)
                prev_new_out = self.create_transform_tensor(
                    np.transpose(prev_out.tensor, nchw2nhwc_perm), quantization=prev_out.quantization
                )
                tensor_node_dict[prev_out.name] = (prev_new_out, 1)
                prev_transpose_op = tfl.TransposeOperator([prev_out, perm_tensor], [prev_new_out])
                # NOTE(review): `direction` is a hint read elsewhere; its consumers are not visible here.
                prev_transpose_op.extra_hints['direction'] = 'up'
                self.graph.add_operator(prev_transpose_op)
                actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))

        # Permute the node's own outputs in place; consumers get a reverse transpose.
        tensor_mapping = {}
        for i in next_indices:
            t = op.outputs[i]
            t.tensor = np.transpose(t.tensor, nchw2nhwc_perm)
            t.shape = t.tensor.shape

            if i in next_nodes:
                new_t = self.create_transform_tensor(np.transpose(t.tensor, nhwc2nchw_perm))
                perm_t = self.create_attr_tensor(nhwc2nchw_perm)
                next_transpose_op = tfl.TransposeOperator([t, perm_t], [new_t])
                next_transpose_op.extra_hints['direction'] = 'down'
                self.graph.add_operator(next_transpose_op)

                tensor_mapping[t.name] = new_t

        for nodes in next_nodes.values():
            for n in nodes:
                next_op = n['op']
                for i, t in enumerate(next_op.inputs):
                    if t.name in tensor_mapping:
                        actions.append((self.graph.replace_operator_input, (n, i, tensor_mapping[t.name])))

    # Case 2: transpose appended after the producing node.
    for node, index in next_modify_node_indices.items():
        op = node['op']
        tensor_names = [node['outputs'][i] for i in index]
        out_nodes = []
        next_edges = []
        for edge in node.out_edges():
            if edge['label'] not in tensor_names:
                continue
            if edge.index in remove_edges:
                continue
            next_node = self.graph.graph.vs[edge.target]
            if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                out_nodes.append(next_node)
            elif next_node['node_type'] != ExtendedOperator.UNUSED_NODE:
                next_edges.append(edge)

        # Old output vertices and old consumer edges are deleted at the very end.
        remove_vertices.extend([x.index for x in out_nodes])
        remove_edges.extend([x.index for x in next_edges])

        # Drop the bookkeeping of the old output tensors.
        # NOTE(review): presumably so that `add_operator` below can register them again as
        # outputs of the new transpose - confirm against `add_operator`.
        for n in out_nodes:
            del self.graph.tensor_map[n['outputs'][0]]
            del self.graph.tensor_node_map[n['outputs'][0]]

        # original output tensor name -> (producing vertex, name of the fresh tensor)
        tensor_node_dict = {}
        for i, op_out in enumerate(op.outputs):
            if i not in index:
                continue

            # The original tensor object now carries the permuted data ...
            op_out.tensor = np.transpose(op_out.tensor, nchw2nhwc_perm)
            op_out.shape = op_out.tensor.shape

            # ... while a fresh tensor carries the data in its original layout.
            perm_tensor = self.create_attr_tensor(nchw2nhwc_perm)
            new_out = self.create_transform_tensor(
                np.transpose(op_out.tensor, nhwc2nchw_perm), quantization=op_out.quantization
            )

            # Update relations
            if op_out.name in self.graph.tensor_node_map:
                del self.graph.tensor_node_map[op_out.name]
            self.graph.tensor_node_map[new_out.name] = node['name']
            self.graph.tensor_map[new_out.name] = new_out
            node['outputs'][i] = new_out.name
            op.outputs[i] = new_out

            next_transpose_op = tfl.TransposeOperator([new_out, perm_tensor], [op_out])
            next_transpose_op.extra_hints['direction'] = 'up'
            self.graph.add_operator(next_transpose_op)

            tensor_node_dict[op_out.name] = (
                self.graph.graph.vs.find(name=self.graph.tensor_node_map[new_out.name]),
                new_out.name,
            )

        # Connect next edges and replace next tensors
        for edge in next_edges:
            old_name = edge['name']
            source, new_name = tensor_node_dict[old_name]
            target = edge.target_vertex
            self.graph.graph.add_edge(source, target, name=new_name, label=new_name)

            op = target['op']
            for i, op_input in enumerate(op.inputs):
                if op_input.name == old_name:
                    op.inputs[i] = self.graph.tensor_map[new_name]
                    break

    # Process actions; a non-None result is treated as further edge ids to delete.
    ids = []
    for func, args in actions:
        res = func(*args)
        if res is not None:
            ids.extend(res)

    remove_edges = list(set(remove_edges + ids))

    # Edges are deleted before vertices.
    # NOTE(review): presumably because deleting vertices would invalidate the collected edge ids.
    self.graph.graph.delete_edges(remove_edges)
    self.graph.graph.delete_vertices(remove_vertices)