in mmdnn/conversion/rewriter/folder.py [0:0]
def _rebuild_scope_edges_and_get_ret_vars(self, scope_node):
    """Reconnect the graph around ``scope_node`` and collect its return variables.

    ``scope_node.topology_list`` names the nodes folded into the scope. This
    method makes the scope node stand in for them at the scope boundary:

    * For every in-edge of an inner node that comes from outside the scope,
      the edge is appended to ``scope_node.in_edges`` (only once). In the
      producer's ``out_edges``, the first occurrence of the inner node's name
      is replaced by ``scope_node.name``, or simply dropped when the scope
      name is already listed there.
    * For every out-edge of an inner node that leaves the scope, the edge is
      appended to ``scope_node.out_edges`` (duplicates are kept). The
      variable name the consumer reads, as reported by
      ``self._graph.get_parent_variable_name``, is recorded as a return
      variable (once per distinct name).
    * Each consumer outside the scope then has the in-edge that referenced a
      returning inner node renamed:
      - to ``scope_node.name`` when there is a single return variable;
      - to ``"<scope name>:<i>"`` for the ``i``-th of several.

      When that consumer is itself a ``'Scope'`` node, the same renaming is
      applied to the in-edges of the nodes in its ``topology_list``.

    Args:
        scope_node: The scope node to rebuild. Its ``in_edges`` and
            ``out_edges`` lists are appended to in place.

    Returns:
        list: The distinct return variable names in discovery order. If no
        inner node has an edge leaving the scope, this is a single-element
        list holding the ``real_variable_name`` of the last node in
        ``scope_node.topology_list``.

    Raises:
        IndexError: If ``scope_node.topology_list`` is empty, because there
            is then no last node to fall back on.
    """

    def _get_index(node, name):
        # Index of the first in-edge of ``node`` produced by node ``name``
        # (an optional ":k" output suffix is ignored); None when there is none.
        for idx, in_edge in enumerate(node.in_edges):
            if in_edge.split(':')[0] == name:
                return idx
        return None

    graph = self._graph
    # Membership is tested once per edge below; the topology list is not
    # modified here, so a set built once keeps each test O(1).
    inner_names = set(scope_node.topology_list)
    return_nodes = []
    return_variable_names = []

    for n_name in scope_node.topology_list:
        node = graph.get_node(n_name)

        for in_edge in node.in_edges:
            if in_edge.split(':')[0] in inner_names:
                continue
            if in_edge not in scope_node.in_edges:
                scope_node.in_edges.append(in_edge)
            # The external producer now feeds the scope instead of the inner
            # node: swap the first occurrence of the inner name for the scope
            # name, without listing the scope name twice.
            in_node = graph.get_node(in_edge)
            if n_name in in_node.out_edges:
                idx = in_node.out_edges.index(n_name)
                del in_node.out_edges[idx]
                if scope_node.name not in in_node.out_edges:
                    in_node.out_edges.insert(idx, scope_node.name)

        for out_edge in node.out_edges:
            if out_edge in inner_names:
                continue
            out_node = graph.get_node(out_edge)
            parent_node_variable_name = graph.get_parent_variable_name(
                out_edge.split(':')[0], [_get_index(out_node, n_name)])
            if parent_node_variable_name not in return_variable_names:
                return_nodes.append(node)
                return_variable_names.append(parent_node_variable_name)
            scope_node.out_edges.append(out_edge)

    # No edge leaves the scope: the last node in the scope is its result.
    if not return_nodes:
        last_node = graph.get_node(scope_node.topology_list[-1])
        return_nodes.append(last_node)
        return_variable_names.append(last_node.real_variable_name)

    multiple_returns = len(return_variable_names) > 1
    for ret_idx, (ret_node, ret_variable_name) in enumerate(
            zip(return_nodes, return_variable_names)):
        # A return variable written as "v[k]" is looked up in consumers'
        # in-edges as "<ret_node.name>:k"; a plain name as "<ret_node.name>".
        if '[' in ret_variable_name:
            subscript = ':' + ret_variable_name.split('[')[1].split(']')[0]
        else:
            subscript = ''
        ret_name = ret_node.name + subscript
        if multiple_returns:
            insert_name = scope_node.name + ':{}'.format(ret_idx)
        else:
            insert_name = scope_node.name

        for out_name in ret_node.out_edges:
            if out_name in inner_names:
                continue
            out_node = graph.get_node(out_name)
            if ret_name not in out_node.in_edges:
                continue
            out_node.in_edges[out_node.in_edges.index(ret_name)] = insert_name

            # A consumer that is itself a scope also holds the old edge name
            # inside its own inner nodes; rename it there as well.
            if out_node.type == 'Scope':
                for consumer_inner_name in out_node.topology_list:
                    consumer_inner = graph.get_node(consumer_inner_name)
                    if ret_name in consumer_inner.in_edges:
                        pos = consumer_inner.in_edges.index(ret_name)
                        consumer_inner.in_edges[pos] = insert_name

    return return_variable_names