def _rebuild_scope_edges_and_get_ret_vars()

in mmdnn/conversion/rewriter/folder.py [0:0]


    def _rebuild_scope_edges_and_get_ret_vars(self, scope_node):
        """Rewire the edges crossing a folded scope and collect its return variables.

        Every edge connecting a node listed in ``scope_node.topology_list`` with
        a node outside that list is redirected so the outside node talks to
        ``scope_node`` instead of the inner node:

        * Incoming edges are recorded in ``scope_node.in_edges`` (no duplicates).
          In the producer's ``out_edges``, the inner node's name is replaced by
          the scope node's name, which is kept at most once.
        * Outgoing edges are appended to ``scope_node.out_edges``. The consumer's
          matching ``in_edges`` entry is renamed to ``<scope>`` when the scope has
          a single return variable, or to ``<scope>:<k>`` for the k-th of
          several. If the consumer is itself a ``'Scope'`` node, the same rename
          is applied to the in-edges of the nodes in its ``topology_list``.

        When no inner node feeds an outside node, the last entry of
        ``topology_list`` is used as the single return node.

        Args:
            scope_node: scope node whose ``in_edges`` / ``out_edges`` are
                extended in place.

        Returns:
            list: the distinct return variable names, in discovery order. The
            k-th name corresponds to the ``<scope>:<k>`` output edge.
        """
        graph = self._graph
        inner_names = scope_node.topology_list

        def _position_of_producer(consumer, producer_name):
            # Index of the first in-edge of ``consumer`` that comes from
            # ``producer_name`` (ignoring any ":k" suffix); None when absent.
            return next(
                (pos for pos, edge in enumerate(consumer.in_edges)
                 if edge.split(':')[0] == producer_name),
                None)

        ret_nodes = []
        ret_vars = []

        for inner_name in inner_names:
            inner = graph.get_node(inner_name)

            # Edges entering the scope from outside.
            for edge in inner.in_edges:
                if edge.split(':')[0] in inner_names:
                    continue
                if edge not in scope_node.in_edges:
                    scope_node.in_edges.append(edge)
                producer_outs = graph.get_node(edge).out_edges
                if inner_name in producer_outs:
                    pos = producer_outs.index(inner_name)
                    del producer_outs[pos]
                    # Several inner nodes may share this producer; the scope
                    # node must appear only once in its out-edges.
                    if scope_node.name not in producer_outs:
                        producer_outs.insert(pos, scope_node.name)

            # Edges leaving the scope towards outside consumers.
            for consumer_name in inner.out_edges:
                if consumer_name in inner_names:
                    continue
                consumer = graph.get_node(consumer_name)
                var_name = graph.get_parent_variable_name(
                    consumer_name.split(':')[0],
                    [_position_of_producer(consumer, inner_name)])
                if var_name not in ret_vars:
                    ret_nodes.append(inner)
                    ret_vars.append(var_name)
                scope_node.out_edges.append(consumer_name)

        if not ret_nodes:
            # Nothing inside feeds the outside: the last node is the result.
            last = graph.get_node(inner_names[-1])
            ret_nodes.append(last)
            ret_vars.append(last.real_variable_name)

        several_outputs = len(ret_vars) > 1
        for out_idx, (ret_node, var_name) in enumerate(zip(ret_nodes, ret_vars)):
            # "name[2]" -> ":2"; a plain variable name carries no edge suffix.
            pieces = var_name.split('[')
            suffix = ':' + pieces[1].split(']')[0] if len(pieces) > 1 else ''
            old_edge = ret_node.name + suffix
            new_edge = ('{}:{}'.format(scope_node.name, out_idx)
                        if several_outputs else scope_node.name)

            for consumer_name in ret_node.out_edges:
                if consumer_name in inner_names:
                    continue
                consumer = graph.get_node(consumer_name)
                if old_edge not in consumer.in_edges:
                    continue
                consumer.in_edges[consumer.in_edges.index(old_edge)] = new_edge

                if consumer.type == 'Scope':
                    # Nested scope consumer: rename the edge inside its body too.
                    for nested_name in consumer.topology_list:
                        nested = graph.get_node(nested_name)
                        if old_edge in nested.in_edges:
                            nested.in_edges[nested.in_edges.index(old_edge)] = new_edge

        return ret_vars