in ppuda/deepnets1m/graph.py
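# Note: this excerpt assumes the file's module-level imports: torch,
# torch.nn as nn, numpy as np, networkx as nx, and the NormLayers and
# PosEnc definitions from ppuda.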
def _build_graph(self):
r"""
Constructs a graph of a neural network in the automatic way.
This function is written based on Sergey Zagoruyko's https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py (MIT License)
PyTorch 1.9+ is required to run this script correctly for some architectures.
Currently, the function is not written very clearly and may be improved.
"""
param_map = {id(weight): (name, module) for name, (weight, module) in self._named_modules().items()}
nodes, edges, seen = [], [], {}
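    # Collect the `_saved_*` attributes that PyTorch 1.9+ exposes on an autograd
    # node (e.g. stride or padding for convolutions); saved tensors are replaced
    # with short placeholders since only their presence matters here.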
def get_attr(fn):
attrs = dict()
for attr in dir(fn):
if not attr.startswith('_saved_'):
continue
val = getattr(fn, attr)
attr = attr[len('_saved_'):]
if torch.is_tensor(val):
attrs[attr] = "[saved tensor]"
elif isinstance(val, tuple) and any(torch.is_tensor(t) for t in val):
attrs[attr] = "[saved tensors]"
else:
attrs[attr] = str(val)
return attrs
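    # Walk the autograd graph depth-first starting from the output's grad_fn,
    # appending one entry to `nodes` per operation/parameter and one
    # (source, target) pair to `edges`; `seen` caches already visited fns.
    # Returns (node_link, fn_name) so the caller can attach the connecting edge.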
def traverse_graph(fn):
assert not torch.is_tensor(fn)
if fn in seen:
return seen[fn]
        fn_name = type(fn).__name__
node_link, link_start = None, None
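        # AccumulateGrad fns wrap leaf tensors (parameters) and are not turned
        # into standalone nodes here; their variables are captured below when a
        # parent fn inspects its next_functions.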
if fn_name.find('AccumulateGrad') < 0:
leaf_nodes = []
for u in fn.next_functions:
if u[0] is not None:
if hasattr(u[0], 'variable'):
var = u[0].variable
name, module = param_map[id(var)]
if type(module) in NormLayers and name.find('.bias') >= 0:
continue # do not add biases of NormLayers as nodes
leaf_nodes.append({'id': u[0],
'param_name': name,
'attrs': {'size': var.size()},
'module': module})
assert len(u[0].next_functions) == 0
if len(leaf_nodes) == 0:
leaf_nodes.append({'id': fn,
'param_name': fn_name,
'attrs': get_attr(fn),
'module': None})
assert not hasattr(fn, 'variable'), fn.variable
for leaf in leaf_nodes:
node_link = str(id(leaf['id']))
if link_start is None:
link_start = node_link
seen[leaf['id']] = (node_link, leaf['param_name'])
nodes.append({'id': node_link,
'param_name': leaf['param_name'],
'attrs': leaf['attrs'],
'module': leaf['module']})
seen[fn] = (node_link, fn_name)
# recurse
if hasattr(fn, 'next_functions'):
for u in fn.next_functions:
if u[0] is not None:
link_, name_ = traverse_graph(u[0])
if link_ is not None and link_start != link_:
edges.append((link_start, link_) if name_.find('bias') >= 0 else (link_, link_start))
return node_link, fn_name
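    # Run a dummy forward pass on a random image-shaped input so that the
    # autograd graph (the chain of grad_fn objects) is created and can be traversed.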
var = self.model(torch.randn(1, 3, *self.expected_image_sz))
    # Take only the first output; multiple outputs (e.g. from auxiliary classifiers) could in principle be handled.
traverse_graph((var[0] if isinstance(var, (tuple, list)) else var).grad_fn) # populate nodes and edges
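    # Build a dense adjacency matrix where A[src, dst] = 1 denotes a directed
    # edge from node src to node dst, following the forward propagation order.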
nodes_lookup = { node['id']: i for i, node in enumerate(nodes) }
A = np.zeros((len(nodes) + 1, len(nodes) + 1)) # +1 for the input node added below
for out_node_id, in_node_id in edges:
A[nodes_lookup[out_node_id], nodes_lookup[in_node_id]] = 1
    # Fix fc (nn.Linear) layer nodes and edge directions
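    # In the autograd graph, edges feeding a Linear layer can attach to its bias
    # node rather than its weight node; rewire them so that inputs connect to the
    # weight node and the weight node has the only edge into the bias node.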
for i, node in enumerate(nodes):
if isinstance(node['module'], nn.Linear) and node['param_name'].find('.weight') >= 0:
# assert node['module'].bias is not None, ('this rewiring may not work in case of no biases', node)
for out_neigh in np.where(A[i, :])[0]: # all nodes where there is an edge from i
A[np.where(A[:, out_neigh])[0], i] = 1 # rewire edges coming to out_neigh (bias) to node i (weight)
# A[i, i] = 0 # remove loop
A[:, out_neigh] = 0 # remove all edges to out_neigh except for the edge from i to out_neigh
A[i, out_neigh] = 1
# Add input node
nodes.append({'id': 'input', 'param_name': 'input', 'attrs': None, 'module': None})
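    # The input node connects to the (single) node that has no incoming edges;
    # column sums of A (excluding the just-added input column) locate it.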
ind = np.where(A[:, :-1].sum(0) == 0)[0]
assert len(ind) == 1, ind
A[-1, ind] = 1
# Sort nodes in a topological order consistent with forward propagation
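    # Self-loops would make the graph cyclic, so zero out the diagonal before
    # running networkx's topological sort.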
A[np.diag_indices_from(A)] = 0
ind = np.array(list(nx.topological_sort(nx.DiGraph(A))))
nodes = [nodes[i] for i in ind]
A = A[ind, :][:, ind]
# Adjust graph for Transformers to be consistent with our original code
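    # PosEnc adds positional encodings to the input; an explicit summation node
    # is inserted right after it so that the graph matches the one produced by
    # the original (manual) construction code.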
for i, node in enumerate(nodes):
if isinstance(node['module'], PosEnc):
nodes.insert(i + 1, { 'id': 'sum_pos_enc', 'param_name': 'AddBackward0', 'attrs': None, 'module': None })
A = np.insert(A, i, 0, axis=0)
A = np.insert(A, i, 0, axis=1)
A[i, i + 1] = 1 # pos_enc to sum
self._Adj = A
self._nodes = nodes
return
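
# Minimal usage sketch (hypothetical; assumes the enclosing Graph class sets
# self.model and self.expected_image_sz and invokes _build_graph(), as in ppuda):
#
#   import torchvision.models as models
#   from ppuda.deepnets1m.graph import Graph
#   g = Graph(models.resnet18())        # _build_graph() populates g._Adj, g._nodes
#   print(g._Adj.shape, len(g._nodes))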