in onnxconverter_common/optimizer.py [0:0]
def apply(self, node_list):
    """Push the transpose held by ``self.begin_n`` past the nodes that follow it.

    The graph is walked breadth-first from the successors of ``self.begin_n``.
    Every reached node is classified with ``_transpose_pass``: nodes for which
    it returns true are collected as "pass" nodes and the walk continues
    through their successors; all other nodes are collected as "no-pass" nodes
    and the walk stops there.  After a series of feasibility checks the pass
    nodes are rewritten by ``_process_transpose_pass_node``, a new ``Transpose``
    node is inserted in front of each no-pass node, and ``self.begin_n`` is
    removed through ``Solution.delete_node_1ton``.

    ``PushTransposeSolution.transpose_number`` is a class-level counter used to
    build the names of the inserted nodes and tensors; it is advanced every
    time a name is consumed.

    :param node_list: the current list of linked nodes of the graph.
    :return: ``(node_list, True)`` with the updated list when the rewrite was
        applied, or ``(None, False)`` when one of the checks rejected it.
    """
    # Function-scope import so the block does not depend on the module's import list.
    from collections import deque

    begin_n = self.begin_n
    if begin_n.is_reserved:
        return None, False

    begin_name = begin_n.unique_name
    cur_perm = Solution.get_perm(begin_n.origin)
    cur_perm_map = {begin_name: cur_perm}

    # Breadth-first walk; each queue entry is (node, node it was reached from).
    # A deque gives O(1) pops from the front, unlike list.pop(0).
    candidate_queue = deque((successor_, begin_n) for successor_ in begin_n.successor)
    visited = set()
    node_transpose_no_pass = []
    node_transpose_pass = []
    node_transpose_pass_name = {begin_name}
    while candidate_queue:
        node, prev = candidate_queue.popleft()
        if node.unique_name in visited:
            continue
        visited.add(node.unique_name)
        if _transpose_pass(node):
            node_transpose_pass_name.add(node.unique_name)
            node_transpose_pass.append((node, prev))
            candidate_queue.extend((successor_, node) for successor_ in node.successor)
        else:
            node_transpose_no_pass.append((node, prev))

    # Reject when a broadcast node fails its check, or when an Unsqueeze
    # carries more than one axis.
    for node, _ in node_transpose_pass:
        op_type = node.origin.op_type
        if op_type in _broadcast_types:
            if not _check_transpose_pass_broadcast(node, node_transpose_pass_name, cur_perm_map):
                return None, False
        elif op_type == 'Unsqueeze':
            unsqueeze_axes = _get_axes_from_Squeeze_Unsqueeze(node)
            if unsqueeze_axes and len(unsqueeze_axes) > 1:
                return None, False

    # add transpose
    # With a single successor, reject if a no-pass node was reached directly
    # from begin_n.
    if len(begin_n.successor) == 1:
        if any(prev.unique_name == begin_name for _, prev in node_transpose_no_pass):
            return None, False

    # Reject when a no-pass node has more than one predecessor that either has
    # an origin or holds no tensors.
    for node, _ in node_transpose_no_pass:
        if len(node.precedence) > 1:
            pred_count = sum(
                1 for pred_ in node.precedence
                # `.tensors` is only consulted when there is no origin;
                # an empty `.tensors` means "not an initializer".
                if pred_.origin is not None or len(pred_.tensors) == 0)
            if pred_count > 1:
                return None, False

    for node, _ in node_transpose_pass:
        node_list, cur_perm_map = _process_transpose_pass_node(
            node, node_list, node_transpose_pass_name, cur_perm_map)

    for node, reached_from in node_transpose_no_pass:
        if node.origin is None:
            # The no-pass node has no origin: redirect the output of the node
            # it was reached from to a fresh name, so the new Transpose can
            # produce the original output name instead.
            successor_list = list(reached_from.successor)
            output_name = ''
            if any(suc.origin is None for suc in successor_list):
                output_name = list(reached_from.output.values())[0]
            redirected_output_name = 'push_transpose_out_' + str(PushTransposeSolution.transpose_number)
            reached_from.out_redirect(output_name, redirected_output_name)
            PushTransposeSolution.transpose_number += 1
            for suc_2 in successor_list:
                suc_2.in_redirect(output_name, redirected_output_name)
            transpose_output_name = [output_name]
        else:
            transpose_output_name = ['push_transpose_out_' + str(PushTransposeSolution.transpose_number)]

        # Iterate node.precedence itself (not a copy), exactly as before:
        # add_siso_node may rewire it while the Transpose is being inserted.
        for pred in node.precedence:
            if pred.origin is not None and pred.unique_name in cur_perm_map:
                cur_perm = cur_perm_map[pred.unique_name]
                nnode = LinkedNode(
                    helper.make_node(
                        'Transpose',
                        ['push_transpose_in_' + str(PushTransposeSolution.transpose_number)],
                        transpose_output_name,
                        perm=cur_perm,
                        name='PushTranspose_' + str(PushTransposeSolution.transpose_number)))
                PushTransposeSolution.transpose_number += 1
                node_list = Solution.add_siso_node(node_list, pred, node, list(pred.output.values())[0], nnode)

    node_list = Solution.delete_node_1ton(node_list, self.begin, begin_n, self.end_p)
    return node_list, True