in ppuda/deepnets1m/graph.py [0:0]
def _filter_graph(self):
    r"""
    Remove redundant/unsupported nodes from the automatically constructed graph.

    Nodes are removed in rounds, one round per name pattern. In each round,
    a node whose ``param_name`` contains the pattern is dropped unless one of
    the op-specific checks below decides to keep it. Before a node is dropped,
    each of its predecessors is connected to each of its successors, so paths
    through the removed node are preserved.

    Round order:

    1. ``'Mul'`` and ``'Clone'``: kept only when a node at a fixed earlier
       position in ``self._nodes`` matches (CSE / MSA ops respectively).
    2. Every ``param_name`` whose module type and name prefix (the part
       before ``'Backward'``) are both absent from ``MODULES``. These are
       processed in sorted order so the result does not depend on string
       hash randomization.
    3. ``'Mean'``, ``'Add'``, ``'Cat'``: kept only when structurally needed
       (see the inline comments).

    The method reads and updates ``self._nodes`` (dicts with at least
    ``'param_name'``, ``'module'`` and ``'attrs'`` keys) and ``self._Adj``
    (square adjacency matrix; ``_Adj[src, dst]`` non-zero means an edge
    ``src -> dst``).

    :return: None; the graph is modified in place.
    """
    # Collect ops that will not be added to the graph.
    unsupported = set()
    for node in self._nodes:
        param_name = node['param_name']
        ind = param_name.find('Backward')
        name = param_name if ind == -1 else param_name[:ind]
        if type(node['module']) not in MODULES and name not in MODULES:
            unsupported.add(param_name)

    # Ops requiring extra checks before removing are placed first/last.
    # Sorting the set makes the round order reproducible across runs
    # (iteration order of a str set depends on PYTHONHASHSEED).
    patterns = ['Mul', 'Clone'] + sorted(unsupported) + ['Mean', 'Add', 'Cat']

    for pattern in patterns:
        # Nodes/adjacency are only filtered at the end of a round, so indices
        # (and the node count) are stable while scanning this round.
        n_nodes = len(self._nodes)
        ind_keep = []

        for i, node in enumerate(self._nodes):
            op_name, attrs = node['param_name'], node['attrs']

            if op_name.find(pattern) < 0:
                ind_keep.append(i)  # not affected by this pattern
                continue

            keep = False
            if op_name.startswith('Mean'):
                # Avoid adding mean operations (in CSE)
                if isinstance(attrs, dict) and 'keepdim' in attrs:
                    keep = attrs['keepdim'] == 'True'
                else:
                    # In pytorch <1.9 the computational graph may be inaccurate
                    keep = (i < n_nodes - 1 and
                            not self._nodes[i + 1]['param_name'].startswith('cells.'))
            elif op_name.startswith('Mul'):
                # CSE op: expects a 'Hard*' node two positions earlier. The
                # i >= 2 guard stops a negative index wrapping to the list end.
                keep = i >= 2 and self._nodes[i - 2]['param_name'].startswith('Hard')
            elif op_name.startswith('Clone'):
                # MSA op: expects a 'Softmax' node 11 positions earlier. The
                # i >= 11 guard stops a negative index wrapping to the list end.
                keep = i >= 11 and self._nodes[i - 11]['param_name'].startswith('Softmax')
            elif op_name.startswith(('Cat', 'Add')):
                # Concat and Residual (Sum) ops: keep only if > 1 edges are incoming
                keep = np.count_nonzero(self._Adj[:, i]) > 1

            if keep:
                ind_keep.append(i)
            else:
                # Rewire edges from/to the to-be-removed node to its neighbors:
                # every predecessor gets an edge to every successor.
                successors = np.flatnonzero(self._Adj[i, :])
                predecessors = np.flatnonzero(self._Adj[:, i])
                for n1 in successors:
                    for n2 in predecessors:
                        if n1 != n2:
                            self._Adj[n2, n1] = 1

        if len(ind_keep) < self._Adj.shape[0]:
            # dtype=int keeps an empty selection a valid integer index
            # (np.array([]) would be float64 and raise IndexError).
            ind_keep = np.asarray(ind_keep, dtype=int)
            self._Adj = self._Adj[:, ind_keep][ind_keep, :]
            self._nodes = [self._nodes[i] for i in ind_keep]