in nni/common/graph_utils.py [0:0]
def _build_graph(self):
    """
    Convert the jit trace graph into our own node representation.

    The work happens in three stages: first the bookkeeping structures
    (value-name to node mappings and the graph input nodes) are prepared,
    then the nodes that live under a leaf-module scope are grouped into
    module nodes, and finally the remaining nodes are grouped into function
    nodes. The assembled ``GraphPy`` is stored in ``self.nodes_py``.

    Returns
    -------
    dict
        index by node name, key: node name, value: node
    dict
        index by input name,
        key: input, value: list of nodes that take this input
    dict
        index by output name,
        key: output, value: node that generates this output
    """
    skip_unused_inputs = True
    trace_graph = self.trace.graph
    _logger.debug(trace_graph)

    # Stage 1a: map each non-constant value (by debugName) to the nodes that
    # consume it and to the node that produces it.
    input_to_node = defaultdict(list)
    output_to_node = defaultdict(list)
    for cpp_node in trace_graph.nodes():
        if cpp_node.kind() == CONSTANT_KIND:
            continue
        for out_value in cpp_node.outputs():
            if out_value.node().kind() == CONSTANT_KIND:
                continue
            producers = output_to_node[out_value.debugName()]
            producers.append(cpp_node)
            assert len(producers) <= 1, (
                "One output cannot be generated by multiple nodes %s" % out_value.debugName())
        for in_value in cpp_node.inputs():
            if in_value.node().kind() != CONSTANT_KIND:
                input_to_node[in_value.debugName()].append(cpp_node)

    # Stage 1b: create source nodes for the graph inputs.
    nodes_py = GraphPy()
    for graph_input in trace_graph.inputs():
        # an input without any user (zero fan-out) is dropped
        if skip_unused_inputs and not graph_input.uses():
            continue
        # NOTE(review): 'ClassType' inputs are presumably the traced module
        # itself rather than real data inputs -- confirm.
        if graph_input.type().kind() != 'ClassType':
            nodes_py.append(NodePyIO(graph_input, 'input'))

    self.leaf_modules = self._extract_leaf_modules()
    module_to_type = {}
    for name, module in self.trace.named_modules():
        module_to_type[name] = parse_traced_name(module._name)

    # Split the trace nodes by scope: nodes under a leaf module are keyed by
    # module name, everything else (functions called in forward) is keyed by
    # the raw scope name.
    module_to_nodes = defaultdict(list)
    func_to_nodes = defaultdict(list)
    for cpp_node in trace_graph.nodes():
        if cpp_node.kind() == CONSTANT_KIND:
            continue
        scope_name = cpp_node.scopeName()
        module_name = self._get_module_name(scope_name)
        if module_name in self.leaf_modules:
            module_to_nodes[module_name].append(cpp_node)
        else:
            func_to_nodes[scope_name].append(cpp_node)

    # Stage 2: build one node group per use of each leaf module.
    for module_name, node_cpps in module_to_nodes.items():
        merged = set()
        use_count = 0
        for cpp_node in node_cpps:
            if cpp_node in merged:
                continue
            # Nodes sharing one scope name may sit at different locations in
            # the graph (a reused module), and node_cpps also holds many
            # prim:: nodes, so every not-yet-merged node is expanded into its
            # own group.
            if use_count > 0:
                unique_name = module_name + '.%d' % use_count
                self.reused_module.add(unique_name)
                self.reused_module.add(module_name)
            else:
                unique_name = module_name
            node_group = self._expand_module_node(
                cpp_node, module_name, unique_name, module_to_type[module_name],
                node_cpps, input_to_node, output_to_node, 'module')
            nodes_py.nodes_op.append(node_group)
            use_count += 1
            merged.update(node_group.node_cpps)

    # Stage 3: build node groups for functions (e.g. torch.nn.functional).
    # A single scope may contain several functions, so each key function node
    # gets a group of its own.
    for scope_nodes in func_to_nodes.values():
        key_func_nodes = [
            candidate for candidate in scope_nodes if self._is_key_func(candidate)]
        for key_node in key_func_nodes:
            node_group = self._expand_key_func_node(
                key_node, scope_nodes, input_to_node, output_to_node, 'func')
            nodes_py.nodes_op.append(node_group)
            # Shape-info extraction for aten::view / aten::flatten groups is
            # currently disabled.

    # Create sink nodes for the graph outputs.
    for graph_output in trace_graph.outputs():
        nodes_py.append(NodePyIO(graph_output, 'output'))

    self.nodes_py = nodes_py
    # build index
    return self._build_index(self.nodes_py.nodes_op)