in onnxconverter_common/optimizer.py [0:0]
def find_local(node):
    """Look for a local Transpose-related rewrite anchored at ``node``.

    Checks run in a fixed priority order and the first match is returned:

    1. ``node`` is a Transpose and ``node.in_single_path``:
       a. ``Solution.is_useless_transpose`` accepts its perm -> ``Solution``;
       b. walking ``successor[0]`` through single-path element-wise/broadcast
          nodes reaches another Transpose -> ``MergeSolution``;
       c. it can be moved past one or more ``is_transpose_switchable_single_path``
          successors -> ``MoveForwardSolution``;
       d. its successor is ``is_transpose_switchable_simo`` and the branch
          cost estimate below is <= 0 -> ``FanOutSolution``.
    2. ``node`` is a Transpose that is not ``in_single_path`` and every one of
       its successors is a Transpose (vacuously true with no successors)
       -> ``FanOutSolution``.
    3. ``node`` is ``is_transpose_switchable_mi``:
       a. every input is a Transpose with ``in_single_path_and_inner`` and all
          inputs share the same perm -> ``FanInSolution``;
       b. ``node.is_eligible_concat_and_inner[0]`` is truthy -> ``FanInSolution``.
          As a side effect, ``node.origin`` is replaced by a freshly built
          Concat node whose axis is ``node.is_eligible_concat_and_inner[1]``.

    :param node: the LinkedNode to inspect.
    :return: the first applicable solution object, or None when no rule matches.
    """
    if node.is_transpose:
        perm = Solution.get_perm(node.origin)
        if node.in_single_path:  # node.in_single_path_and_inner:
            if Solution.is_useless_transpose(perm):
                return Solution(node.get_precedence_by_idx(0), node, node, node.successor[0])

            # Walk downstream through single-path element-wise/broadcast nodes,
            # stopping at the first Transpose (a merge candidate) or at any other node.
            succ = node.successor[0]  # type: LinkedNode
            while succ.in_single_path and not succ.is_transpose:
                if not (succ.element_wise or succ.broadcast):
                    break
                succ = succ.successor[0]
            if succ.is_transpose:
                # NOTE: the whole successor list of the second Transpose is passed as
                # the end point here, unlike the single node used by other solutions.
                return MergeSolution(node.get_precedence_by_idx(0), node, succ, succ.successor)

            # Find the furthest node this Transpose can be moved past. Stop before a
            # successor flagged ``in_or_out`` (presumably a graph input/output).
            last_switchable = node
            test_node = node.successor[0]
            moved = False
            while test_node.is_transpose_switchable_single_path and not test_node.successor[0].in_or_out:
                moved = True
                last_switchable = test_node
                test_node = test_node.successor[0]
            if moved:
                return MoveForwardSolution(node.get_precedence_by_idx(0), node, last_switchable,
                                           last_switchable.successor[0])

            next_node = node.successor[0]
            if next_node.is_transpose_switchable_simo:
                # Cost estimate for pushing this Transpose into every output branch of
                # next_node. It starts at -1 (this node goes away). Each branch then
                # contributes by what it ends in after skipping switchable nodes:
                #   -1 for a Transpose whose composed perm is useless (they cancel),
                #    0 for any other Transpose,
                #   +1 for a non-Transpose.
                # Presumably this is the net change in Transpose count -- TODO confirm.
                delta_node = -1
                for branch in next_node.successor:
                    tail = branch
                    while tail.is_transpose_switchable_single_path:
                        tail = tail.successor[0]
                    if tail.is_transpose:
                        branch_perm = Solution.get_perm(tail.origin)
                        # Only compose perms of equal length; the index lookup below
                        # would otherwise be meaningless (or raise IndexError).
                        if len(perm) == len(branch_perm):
                            perm_f = [perm[idx] for idx in branch_perm]
                            if Solution.is_useless_transpose(perm_f):
                                delta_node -= 1
                    else:
                        delta_node += 1
                if delta_node <= 0:
                    return FanOutSolution(node.get_precedence_by_idx(0), node, next_node, next_node)
        else:  # simo Transpose op
            if all(succ_.is_transpose for succ_ in node.successor):
                return FanOutSolution(node.get_precedence_by_idx(0), node, node, node.successor)
    elif node.is_transpose_switchable_mi:
        # Accept only if every input is a single-path, inner Transpose and all of
        # them carry the same perm.
        # NOTE(review): with no inputs at all this still reports a match with an
        # empty perm; presumably is_transpose_switchable_mi excludes that -- confirm.
        branch_perm = []
        all_inputs_match = True
        for idx, branch in enumerate(node.precedence):
            if not (branch.is_transpose and branch.in_single_path_and_inner):
                all_inputs_match = False
                break
            cur_perm = Solution.get_perm(branch.origin)
            if idx == 0:
                branch_perm = cur_perm
            elif cur_perm != branch_perm:
                all_inputs_match = False
                break
        if all_inputs_match:
            return FanInSolution(node, node.successor[0], None, None, branch_perm)

        # is_eligible_concat_and_inner is indexed as (flag, axis).
        eligible_concat = node.is_eligible_concat_and_inner
        if eligible_concat[0]:
            perm = Solution.get_perm(node.get_precedence_by_idx(0).origin)
            # Build the solution first, then swap in the rebuilt Concat node,
            # preserving the original ordering of these two steps.
            solution = FanInSolution(node, node.successor[0], None, None, perm)
            node.origin = helper.make_node('Concat', node.origin.input, node.origin.output,
                                           node.origin.name, axis=eligible_concat[1])
            return solution

    return None