def _build_graph()

in nni/common/graph_utils.py [0:0]


    def _build_graph(self):
        """
        Build a graph in our own format from the jit trace.
        There are basically three steps: first, construct the necessary information
        (data structures); second, extract all the modules and convert them to nodes;
        third, extract all the functions and convert them to nodes.

        Returns
        -------
        dict
            use name to index nodes, key: node name, value: node
        dict
            use input (its name) to index nodes,
            key: input, value: list of nodes that take this input
        dict
            use output (its name) to index nodes,
            key: output, value: node that generates this output
        """
        omit_useless_nodes = True
        graph = self.trace.graph
        _logger.debug(graph)
        # build input/output mapping, from input/output debugName to its node
        input_to_node = defaultdict(list)
        output_to_node = defaultdict(list)
        for node in graph.nodes():
            if node.kind() == CONSTANT_KIND:
                continue
            for x in node.outputs():
                if x.node().kind() == CONSTANT_KIND:
                    continue
                output_to_node[x.debugName()].append(node)
                assert len(output_to_node[x.debugName()]) <= 1, "One output cannot be generated by multiple nodes %s" % x.debugName()
            for x in node.inputs():
                if x.node().kind() == CONSTANT_KIND:
                    continue
                input_to_node[x.debugName()].append(node)

        # build module mapping, from module name to all nodes (as list) under this module scope
        module_to_nodes = defaultdict(list)
        # the mapping of function (non-module in forward) to nodes, key is scope name
        func_to_nodes = defaultdict(list)

        nodes_py = GraphPy()
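        # add graph inputs as NodePyIO nodes; the module object itself ('self') is a
        # ClassType value and is skipped below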
        for node in graph.inputs():
            if omit_useless_nodes:
                if not node.uses():  # skip graph inputs that no node consumes (zero fanout)
                    continue

            if node.type().kind() != 'ClassType':
                nodes_py.append(NodePyIO(node, 'input'))

        self.leaf_modules = self._extract_leaf_modules()
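        # map each traced module's name to its original class name (parsed from the traced name)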
        module_to_type = {name: parse_traced_name(module._name)
                          for name, module in self.trace.named_modules()}

        # associate module name with their trace graph nodes
        for node in graph.nodes():
            if node.kind() == CONSTANT_KIND:
                continue
            module_name = self._get_module_name(node.scopeName())
            if module_name in self.leaf_modules:
                module_to_nodes[module_name].append(node)
            else:
                func_to_nodes[node.scopeName()].append(node)
        # build node group for module
        for module_name, node_cpps in module_to_nodes.items():
            use_count = 0
            merged = set()
            for node in node_cpps:
                if node not in merged:
                    # modules that have the same scope name may appear at different
                    # locations in the graph. Furthermore, node_cpps also contains lots of
                    # prim:: nodes, so we need to call _expand_module_node as well.
                    unique_name = module_name
                    if use_count > 0:
                        unique_name = module_name + '.%d' % use_count
                        self.reused_module.add(unique_name)
                        self.reused_module.add(module_name)
                    node_group = self._expand_module_node(
                        node, module_name, unique_name, module_to_type[module_name],
                        node_cpps, input_to_node, output_to_node, 'module')
                    nodes_py.nodes_op.append(node_group)
                    use_count += 1
                    merged.update(node_group.node_cpps)

        # each scope_name may contain multiple functions; split them and create a node for each
        # build node group for torch.nn.functional
        for _, nodes in func_to_nodes.items():
            # extract non-prim:: nodes
            key_func_nodes = list()
            for node in nodes:
                if self._is_key_func(node):
                    # find the key function nodes
                    key_func_nodes.append(node)
            # for each non-prim:: (key function) node, expand it
            for node in key_func_nodes:
                node_group = self._expand_key_func_node(
                    node, nodes, input_to_node, output_to_node, 'func')
                nodes_py.nodes_op.append(node_group)
                # get shape info for view (aten::view) func
                # if node_group.op_type in ['aten::view', 'aten::flatten']:
                #     node_group.auxiliary = self._extract_shape_info(node)

        for node in graph.outputs():  # Create sink nodes for output ops
            node_py = NodePyIO(node, 'output')
            nodes_py.append(node_py)

        self.nodes_py = nodes_py
        # build index
        return self._build_index(self.nodes_py.nodes_op)
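
For context, a minimal usage sketch, assuming the surrounding class is TorchModuleGraph and that it
stores the three returned indexes as name_to_node, input_to_node and output_to_node (the model,
attribute and op_type names below are illustrative assumptions, not part of this snippet):

    import torch
    import torch.nn as nn
    from nni.common.graph_utils import TorchModuleGraph

    # a small model and a dummy input so it can be jit-traced
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
    dummy_input = torch.randn(1, 3, 32, 32)

    # _build_graph() runs during construction and fills the three indexes described above
    graph = TorchModuleGraph(model, dummy_input)
    for name, node_group in graph.name_to_node.items():
        print(name, node_group.op_type)  # op_type assumed from NodePyGroup in the same module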