in tinynn/converter/operators/optimize.py
def elementwise_op_transpose_passthrough_pass(self, quantizable_ops_only: bool = False) -> int:
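    """Pushes TRANSPOSE ops through elementwise ops so that adjacent transposes
    can cancel out in later passes.

    For each elementwise node with transposes among its producers or consumers,
    the dominant permutation is moved to the other side of the node: an inverse
    transpose is inserted before every input and a forward transpose after
    every output, so consumers keep seeing the original layout. Axis- and
    dim-like attributes (e.g. for PACK, SPLIT, PAD) are remapped accordingly.
    With `quantizable_ops_only`, only transposes around requantizable ops are
    considered.

    Returns the number of nodes rewritten.
    """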
edges = self.graph.graph.es.select(
functools.partial(
is_transpose_elementwise_op_edge,
graph_converter=self.graph.graph,
quantizable_ops_only=quantizable_ops_only,
)
)
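    # Each selected edge connects a TRANSPOSE to an elementwise op (in either
    # direction), as decided by `is_transpose_elementwise_op_edge`.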
pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
if quantizable_ops_only:
all_edges = self.graph.graph.es.select(
functools.partial(
is_transpose_elementwise_op_edge,
graph_converter=self.graph.graph,
quantizable_ops_only=False,
)
)
all_pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in all_edges)
forward_d = dict(all_pairs)
backward_d = {v: k for k, v in forward_d.items()}
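        # Over the unfiltered edge set, map each edge's source vertex to its
        # target and vice versa. For a transpose touching a requantizable op,
        # these maps locate the elementwise op on the far side of the
        # transpose, which is the node the transpose is pushed through.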
filtered_nodes = []
for s, e in pairs:
if s['node_type'] == ExtendedOperator.TRANSPOSE:
pn = backward_d.get(s, None)
if pn is not None:
filtered_nodes.append(pn)
else:
log.warning('Cannot passthrough transpose upward around requantizable ops')
else:
pn = forward_d.get(e, None)
if pn is not None:
filtered_nodes.append(pn)
else:
log.warning('Cannot passthrough transpose downward around requantizable ops')
else:
filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.TRANSPOSE else k[1] for k in pairs)
unique_nodes = list(set(filtered_nodes))
actions = []
remove_edges = []
remove_vertices = []
num_actions = 0
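    # For each candidate node: tally the permutations of the surrounding
    # transposes, estimate whether pushing the dominant permutation through the
    # node reduces the total number of transposes, and if so rewrite the node's
    # inputs and outputs.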
for node in unique_nodes:
op = node['op']
input_indices = op_input_indices(op)
prev_nodes = []
cand_perms = dict()
prev_output_indices = []
num_constant_nodes = 0
num_reshape_transpose = 0
prev_hints = set()
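        # Scan the producers: collect candidate permutations from input-side
        # transposes, count constant inputs (transposing a constant can be
        # folded away), detect RESHAPE(TRANSPOSE(...)) chains, and gather the
        # direction hints recorded by earlier passthrough rewrites.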
for i in input_indices:
prev_node_name = op.inputs[i].name
prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
prev_nodes.append(prev_node)
prev_output_indices.append(prev_node['outputs'].index(prev_node_name))
if prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
perm = tuple(prev_node['op'].inputs[1].tensor.tolist())
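                # PACK inserts a new dimension at `op.axis`, so lift the
                # input-side perm to the packed rank with the new axis kept in
                # place. E.g. (hypothetical) with op.axis == 1, perm (1, 0)
                # becomes (2, 1, 0).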
if node['node_type'] == ExtendedOperator.PACK:
perm = [i if i < op.axis else i + 1 for i in perm]
perm.insert(op.axis, op.axis)
perm = tuple(perm)
cand_perms.setdefault(perm, 0)
cand_perms[perm] += 1
if 'direction' in prev_node['op'].extra_hints:
prev_hints.add(prev_node['op'].extra_hints['direction'])
if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
num_constant_nodes += 1
if prev_node['node_type'] == ExtendedOperator.RESHAPE:
prev_prev_node_name = self.graph.tensor_node_map[prev_node['op'].inputs[0].name]
prev_prev_node = self.graph.graph.vs.find(name=prev_prev_node_name)
if prev_prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
num_reshape_transpose += 1
if 'direction' in prev_prev_node['op'].extra_hints:
prev_hints.add(prev_prev_node['op'].extra_hints['direction'])
if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints:
continue
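        # Scan the consumers: graph outputs and unused tensors are tracked
        # separately, and the permutations of output-side transposes are
        # tallied into the same candidate pool as the input-side ones.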
next_nodes = []
next_edges = []
out_nodes = []
skip_names = []
next_hints = set()
for edge in node.out_edges():
if edge.index in remove_edges:
continue
next_node = self.graph.graph.vs[edge.target]
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
out_nodes.append(next_node)
elif next_node['node_type'] == ExtendedOperator.UNUSED_NODE:
skip_names.append(edge['label'])
else:
next_nodes.append(next_node)
next_edges.append(edge)
if next_node['node_type'] == ExtendedOperator.TRANSPOSE:
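                    # Tally the inverse permutation (argsort) for output-side
                    # transposes so that they are comparable with the
                    # input-side ones above.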
perm = tuple(np.argsort(next_node['op'].inputs[1].tensor).tolist())
if node['node_type'] == ExtendedOperator.UNPACK:
perm = [i if i < op.axis else i + 1 for i in perm]
perm.insert(op.axis, op.axis)
perm = tuple(perm)
cand_perms.setdefault(perm, 0)
cand_perms[perm] += 1
if 'direction' in next_node['op'].extra_hints:
next_hints.add(next_node['op'].extra_hints['direction'])
if next_node['node_type'] == ExtendedOperator.RESHAPE:
o_nodes = [e.target_vertex for e in next_node.out_edges()]
if len(o_nodes) == 1 and o_nodes[0]['node_type'] == ExtendedOperator.TRANSPOSE:
num_reshape_transpose += 1
if 'direction' in o_nodes[0]['op'].extra_hints:
next_hints.add(o_nodes[0]['op'].extra_hints['direction'])
if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints:
continue
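        # Heuristic cost model: `cur_transpose_size` counts the transposes that
        # currently surround the node (including ones hidden behind reshapes),
        # while `new_transpose_size` estimates the count after the passthrough:
        # one new transpose per input and per consumer, minus constants (whose
        # transposes fold away) and minus existing transposes that cancel
        # against the newly inserted ones.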
cur_transpose_size = sum(cand_perms.values()) + num_reshape_transpose
new_transpose_size = (
len(prev_nodes) + len(next_nodes) - num_constant_nodes - cur_transpose_size + num_reshape_transpose
)
        # Skip the rewrite when both of the following hold:
        # a. the number of transpose ops would not decrease (without
        #    `bypass_elementwise_passthrough_constraint`, an increasing count
        #    is skipped unconditionally; with it, any non-decreasing count may
        #    still be rescued by b.)
        # b. no direction hint rescues the node; hints are only consulted when
        #    the optimize level is at least `BRANCH_OPTIMIZE_EXTENDED`
is_increasing = new_transpose_size > cur_transpose_size
is_not_decreasing = new_transpose_size >= cur_transpose_size
is_same = new_transpose_size == cur_transpose_size
if len(next_nodes) == 0:
continue
else:
if self.bypass_elementwise_passthrough_constraint:
condition = is_not_decreasing
else:
if is_increasing:
continue
condition = is_same
if condition:
skip = True
if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED:
if 'down' in prev_hints or 'up' in next_hints:
skip = False
if skip:
continue
num_actions += 1
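        # The rewrite is committed from this point on: schedule the node's
        # current outgoing edges for removal and drop OUTPUT_NODE placeholders
        # together with their tensor-map entries.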
remove_edges.extend([x.index for x in next_edges])
remove_vertices.extend([x.index for x in out_nodes])
for n in out_nodes:
del self.graph.tensor_map[n['outputs'][0]]
del self.graph.tensor_node_map[n['outputs'][0]]
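        # Adopt the most frequent candidate permutation: its inverse is
        # inserted before the node (cancelling existing input-side transposes)
        # and the forward perm after it (cancelling output-side ones). E.g.
        # (hypothetical) perm (0, 2, 3, 1) has inverse (0, 3, 1, 2).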
perm = max(cand_perms.items(), key=lambda x: x[1])[0]
perm_arr = np.array(perm, dtype='int32')
inv_perm_arr = np.argsort(perm_arr).astype('int32')
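        # PACK/UNPACK change the tensor rank, so the permutation used on the
        # packed side (the `*_post` variants) must insert or drop the packed
        # axis accordingly.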
if node['node_type'] == ExtendedOperator.UNPACK:
inv_perm_arr_post = inv_perm_arr[inv_perm_arr != op.axis]
inv_perm_arr_post[inv_perm_arr_post > op.axis] -= 1
perm_arr_post = np.argsort(inv_perm_arr_post).astype('int32')
elif node['node_type'] == ExtendedOperator.PACK:
perm_arr_post = perm_arr
inv_perm_arr_post = inv_perm_arr
perm_arr = perm_arr_post[perm_arr_post != op.axis]
perm_arr[perm_arr > op.axis] -= 1
inv_perm_arr = np.argsort(perm_arr).astype('int32')
else:
perm_arr_post = perm_arr
inv_perm_arr_post = inv_perm_arr
tensor_node_dict = {}
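        # Rewrite the inputs: insert one inverse transpose per distinct input
        # tensor; `tensor_node_dict` de-duplicates when the same tensor feeds
        # several input slots.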
        for prev_node, prev_idx, prev_out_idx in zip(prev_nodes, input_indices, prev_output_indices):
if prev_node['op'] is None:
prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
else:
                prev_out = prev_node['op'].outputs[prev_out_idx]
if prev_out.name in tensor_node_dict:
prev_new_out, skip = tensor_node_dict[prev_out.name]
actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True, skip)))
skip += 1
tensor_node_dict[prev_out.name] = (prev_new_out, skip)
else:
perm_tensor = self.create_attr_tensor(inv_perm_arr)
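                # Inputs of lower rank (broadcasting) are first reshaped with
                # leading 1s so that the permutation applies to the full rank.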
if len(prev_out.shape) != perm_tensor.tensor.size:
new_shape = [1] * (perm_tensor.tensor.size - len(prev_out.shape)) + list(prev_out.shape)
prev_out_reshaped = self.create_transform_tensor(
np.reshape(prev_out.tensor, new_shape), quantization=prev_out.quantization
)
new_shape_tensor = self.create_attr_tensor(np.array(new_shape, dtype='int32'))
self.graph.add_operator(
tfl.ReshapeOperator([prev_out, new_shape_tensor], [prev_out_reshaped], new_shape)
)
prev_out = prev_out_reshaped
prev_new_out = self.create_transform_tensor(
np.transpose(prev_out.tensor, inv_perm_arr), quantization=prev_out.quantization
)
tensor_node_dict[prev_out.name] = (prev_new_out, 1)
transpose_op = tfl.TransposeOperator([prev_out, perm_tensor], [prev_new_out])
transpose_op.extra_hints['direction'] = 'up'
self.graph.add_operator(transpose_op)
actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))
tensor_node_dict = {}
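        # Rewrite the outputs: the op now produces permuted tensors, and a
        # forward transpose restores the original layout under the original
        # tensor name, so downstream consumers are unaffected.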
for i, op_out in enumerate(op.outputs):
# For unused tensors, we perform inplace shape updates
if op_out.name in skip_names:
orig_shape = np.array(op_out.shape, dtype='int32')
new_shape = orig_shape[inv_perm_arr]
op_out.shape = tuple(new_shape.tolist())
continue
perm_tensor = self.create_attr_tensor(perm_arr_post)
new_out = self.create_transform_tensor(
np.transpose(op_out.tensor, inv_perm_arr_post), quantization=op_out.quantization
)
# Update relations
if op_out.name in self.graph.tensor_node_map:
del self.graph.tensor_node_map[op_out.name]
self.graph.tensor_node_map[new_out.name] = node['name']
self.graph.tensor_map[new_out.name] = new_out
node['outputs'][i] = new_out.name
op.outputs[i] = new_out
transpose_op = tfl.TransposeOperator([new_out, perm_tensor], [op_out])
transpose_op.extra_hints['direction'] = 'down'
self.graph.add_operator(transpose_op)
tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])
# OP specific dim handling logic
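        # Axis- and dim-like attributes refer to the old layout and must be
        # remapped through the permutation. E.g. (hypothetical) with
        # inv_perm_arr == [0, 3, 1, 2], an old axis of 1 maps to
        # np.where(inv_perm_arr == 1)[0][0] == 2.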
if node['node_type'] in (ExtendedOperator.CONCATENATION, ExtendedOperator.GATHER, ExtendedOperator.UNPACK):
old_axis = op.axis
new_axis = np.where(inv_perm_arr == old_axis)[0][0]
op.axis = new_axis
elif node['node_type'] == ExtendedOperator.PACK:
old_axis = op.axis
new_axis = np.where(inv_perm_arr_post == old_axis)[0][0]
op.axis = new_axis
elif node['node_type'] == ExtendedOperator.SPLIT_V:
old_dim = op.inputs[2].tensor
new_dim = np.where(inv_perm_arr == old_dim)[0][0]
new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
actions.append((self.graph.replace_operator_input, (node, 2, new_dim_tensor, True)))
elif node['node_type'] == ExtendedOperator.SPLIT:
old_dim = op.inputs[0].tensor
new_dim = np.where(inv_perm_arr == old_dim)[0][0]
new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
actions.append((self.graph.replace_operator_input, (node, 0, new_dim_tensor, True)))
elif node['node_type'] in (
ExtendedOperator.PAD,
ExtendedOperator.PADV2,
ExtendedOperator.MIRROR_PAD,
ExtendedOperator.TILE,
):
old_pad = op.inputs[1].tensor
new_pad = self.create_attr_tensor(old_pad[inv_perm_arr])
actions.append((self.graph.replace_operator_input, (node, 1, new_pad, True)))
elif node['node_type'] == ExtendedOperator.PRELU:
old_weight = op.inputs[1].tensor
if old_weight.ndim != 1:
assert old_weight.ndim + 1 == len(inv_perm_arr)
new_perm = np.argsort(np.argsort(inv_perm_arr[1:]))
new_perm_t = self.create_attr_tensor(np.array(new_perm, dtype='int32'))
new_weight = self.create_transform_tensor(np.transpose(old_weight, new_perm))
self.graph.add_operator(tfl.TransposeOperator([op.inputs[1], new_perm_t], [new_weight]))
actions.append((self.graph.replace_operator_input, (node, 1, new_weight, True)))
elif node['node_type'] in (ExtendedOperator.SLICE, ExtendedOperator.STRIDED_SLICE):
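            # Permute each index vector (begin/size or begin/end/strides) into
            # the new layout; constant vectors are permuted directly, dynamic
            # ones via an inserted op.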
for i, t in enumerate(op.inputs[1:]):
if t.buffer is None:
new_perm_t = self.create_attr_tensor(np.array(inv_perm_arr, dtype='int32'))
new_t = self.create_transform_tensor(t.tensor[inv_perm_arr])
self.graph.add_operator(tfl.TransposeOperator([t, new_perm_t], [new_t]))
else:
new_t = self.create_attr_tensor(t.tensor[inv_perm_arr])
actions.append((self.graph.replace_operator_input, (node, i + 1, new_t, True)))
elif node['node_type'] in (
ExtendedOperator.SUM,
ExtendedOperator.ARG_MIN,
ExtendedOperator.ARG_MAX,
ExtendedOperator.REDUCE_MIN,
ExtendedOperator.REDUCE_MAX,
ExtendedOperator.REDUCE_PROD,
ExtendedOperator.MEAN,
):
old_axis = op.inputs[1].tensor.tolist()
new_axis = []
for t in old_axis:
new_t = np.where(inv_perm_arr == t)[0][0]
new_axis.append(new_t)
axis_arr = np.array(new_axis, dtype='int32')
axis_tensor = self.create_attr_tensor(axis_arr)
actions.append((self.graph.replace_operator_input, (node, 1, axis_tensor, True)))
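        # Reattach the node's original outgoing edges so that consumers now
        # read from the newly inserted output transposes.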
for edge in next_edges:
source = tensor_node_dict[edge['name']]
self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])
    # Execute the deferred `replace_operator_input` actions; any edge ids they
    # return are added to the removal set
ids = []
for func, args in actions:
res = func(*args)
if res is not None:
ids.extend(res)
remove_edges = list(set(remove_edges + ids))
self.graph.graph.delete_edges(remove_edges)
self.graph.graph.delete_vertices(remove_vertices)
return num_actions
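
# Usage sketch (hypothetical; this pass is normally scheduled by the
# optimizer's pass pipeline rather than called directly). Assuming an
# initialized `GraphOptimizer` instance named `optimizer`:
#
#     rewritten = optimizer.elementwise_op_transpose_passthrough_pass()
#     # Re-running until a fixed point may expose new candidates:
#     while rewritten:
#         rewritten = optimizer.elementwise_op_transpose_passthrough_pass()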