in fairscale/experimental/nn/distributed_pipeline/trace.py [0:0]
def create_graph(self) -> PipelineModulesGraph:
    """Convert the traced ``torch.fx`` graph into a ``PipelineModulesGraph``.

    Walks the fx nodes and classifies each one as either a remote-module
    call, a tuple-indexing of a module's output, or a model input
    (placeholder). It then infers, for every module, whether its output is
    consumed as a single tensor or as a tuple (and of what arity), and
    finally registers each module as a layer of the pipeline graph.

    Returns:
        A ``PipelineModulesGraph`` with one layer per ``call_module`` node,
        in graph iteration order.

    Raises:
        AssertionError: if the traced graph contains a node kind this
            conversion does not support, or if a module's output is used
            inconsistently (both as a plain tensor and as a tuple).
    """
    # node_to_data maps nodes to the data they represent:
    #   RemoteModule          -> the whole output of that module,
    #   (RemoteModule, int)   -> one element of a module's tuple output,
    #   int                   -> the index of a model (forward) input.
    node_to_data: Dict[torch.fx.Node, PipelineModulesGraph.DataSourceSpec] = {}
    remote_module_nodes: List[Tuple[torch.fx.Node, RemoteModule]] = []
    # Hoisted out of the loop: forward()'s signature cannot change while we
    # iterate, so compute the argument-name list once instead of once per
    # placeholder node.
    arg_names = list(inspect.signature(self.tracer.root.forward).parameters)
    for node in self.tracer.graph.nodes:
        if node.op == "call_module":
            module = self.get_module(node)
            assert isinstance(module, RemoteModule)
            node_to_data[node] = module
            remote_module_nodes.append((node, module))
        elif node.target == operator.__getitem__ and node.op == "call_function":
            # Indexing into a previously seen module's tuple output; record
            # it as (producing module, element index).
            assert node.args[0] in node_to_data
            d = node_to_data[node.args[0]]
            assert isinstance(d, RemoteModule)
            node_to_data[node] = (d, node.args[1])
        elif node.op == "placeholder":
            node_to_data[node] = arg_names.index(node.target)
        elif node.op == "output":
            pass
        else:
            # Explicit raise (not `assert False`) so the check is not
            # stripped when running under `python -O`.
            raise AssertionError("Invalid node %s" % node)
    # Following dict stores cardinality of the output for modules: for each module, it stores None if the output
    # is a simple tensor, and it stores the number of tensors in the output if it is a tuple.
    module_to_num_outputs: Dict[nn.Module, Optional[int]] = {}
    for node, _ in remote_module_nodes:
        # iterate over inputs to the module.
        for arg in node.args:
            data = node_to_data[arg]
            if isinstance(data, int):
                # Model input (placeholder): no producing module to record.
                continue
            if isinstance(data, RemoteModule):
                # The whole output is consumed as a single tensor; it must not
                # also be consumed as a tuple elsewhere.
                assert module_to_num_outputs.get(data, None) is None
                module_to_num_outputs[data] = None
            else:
                module, output_num = data
                # Here we discovered that the output number "output_num" is used,
                # so the number of outputs should be at least "output_num+1".
                if module in module_to_num_outputs:
                    prev_value = module_to_num_outputs[module]
                    assert prev_value is not None
                    module_to_num_outputs[module] = max(prev_value, output_num + 1)
                else:
                    module_to_num_outputs[module] = output_num + 1
    graph = PipelineModulesGraph()
    for node, module in remote_module_nodes:
        inputs = [node_to_data[arg] for arg in node.args]
        graph.add_layer(module, inputs, module_to_num_outputs.get(module))
    return graph