def _filter_graph()

in ppuda/deepnets1m/graph.py [0:0]


    def _filter_graph(self):
        r"""
        Remove redundant/unsupported nodes from the automatically constructed graph.

        A node is treated as unsupported when neither its module type nor its op
        name (the part of ``param_name`` before ``'Backward'``, if present) is a key
        of ``MODULES``. Such nodes are always removed.

        The ops ``Mul``, ``Clone``, ``Mean``, ``Add`` and ``Cat`` are also removal
        candidates. Each is kept only when an op-specific check passes:

        - ``Mul``: the node two positions earlier is a ``Hard*`` op (CSE).
        - ``Clone``: the node eleven positions earlier is a ``Softmax`` op (MSA).
        - ``Mean``: ``keepdim`` is ``'True'``. Without that attribute, the node is
          kept when it is followed by a node whose name does not start with
          ``'cells.'``.
        - ``Add`` / ``Cat``: the node has more than one incoming edge.

        When a node is removed, every predecessor is connected to every successor.
        This preserves the paths that went through it. ``self._Adj[src, dst]``
        is non-zero for an edge ``src -> dst``.

        ``self._nodes`` and ``self._Adj`` are replaced with the filtered versions.
        :return: None
        """

        # Op names with no counterpart in MODULES; these will not be added to the graph.
        unsupported_modules = set()
        for node in self._nodes:
            param_name = node['param_name']
            ind = param_name.find('Backward')
            name = param_name if ind == -1 else param_name[:ind]
            if type(node['module']) not in MODULES and name not in MODULES:
                unsupported_modules.add(param_name)

        # Add ops requiring extra checks before removing.
        # 'Mul' and 'Clone' are processed first, so their fixed-offset look-backs
        # (i - 2, i - 11) see the node order before any other removal.
        # 'Mean', 'Add' and 'Cat' are processed last, after all unsupported ops
        # have been removed and their edges rewired.
        # The set is sorted so the processing order does not depend on per-process
        # string-hash randomization.
        unsupported_modules = ['Mul', 'Clone'] + sorted(unsupported_modules) + \
                              ['Mean', 'Add', 'Cat']

        for pattern in unsupported_modules:

            ind_keep = []

            for i, node in enumerate(self._nodes):
                op_name, attrs = node['param_name'], node['attrs']

                if pattern not in op_name:
                    # Node is unaffected by the current pattern.
                    keep = True
                else:
                    keep = False
                    if op_name.startswith('Mean'):
                        # Avoid adding mean operations (in CSE)
                        if isinstance(attrs, dict) and 'keepdim' in attrs:
                            keep = attrs['keepdim'] == 'True'
                        else:
                            # In pytorch <1.9 the computational graph may be inaccurate
                            keep = i < len(self._nodes) - 1 and \
                                not self._nodes[i + 1]['param_name'].startswith('cells.')

                    elif op_name.startswith('Mul'):
                        # CSE op. The i >= 2 guard prevents a negative index from
                        # wrapping around to the end of the node list.
                        keep = i >= 2 and self._nodes[i - 2]['param_name'].startswith('Hard')

                    elif op_name.startswith('Clone'):
                        # MSA op. The i >= 11 guard prevents a negative index from
                        # wrapping around to the end of the node list.
                        keep = i >= 11 and self._nodes[i - 11]['param_name'].startswith('Softmax')

                    elif op_name.startswith('Cat') or op_name.startswith('Add'):
                        # Concat and Residual (Sum) ops: keep only if > 1 edges are incoming.
                        # NOTE(review): edges of nodes removed earlier in this same pass are
                        # not cleared until the pass ends, so they still count here; confirm
                        # this is intended.
                        keep = len(np.where(self._Adj[:, i])[0]) > 1

                    if not keep:
                        # Rewire edges from/to the to-be-removed node to its neighbors.
                        # Column i (the predecessors) cannot change inside these loops:
                        # writing Adj[n2, i] only happens for an n2 that is already a
                        # predecessor. So the lookup is computed once.
                        successors = np.where(self._Adj[i, :])[0]
                        predecessors = np.where(self._Adj[:, i])[0]
                        for n1 in successors:
                            for n2 in predecessors:
                                if n1 != n2:
                                    self._Adj[n2, n1] = 1

                if keep:
                    ind_keep.append(i)

            # An explicit integer dtype keeps indexing valid when ind_keep is empty.
            # (np.array([]) would be float64 and raise IndexError.)
            ind_keep = np.array(ind_keep, dtype=int)

            if len(ind_keep) < self._Adj.shape[0]:
                self._Adj = self._Adj[:, ind_keep][ind_keep, :]
                self._nodes = [self._nodes[i] for i in ind_keep]