in lucid/scratch/pretty_graphs/graph.py [0:0]
def filter_graph(graph, keep_nodes, pass_through=True):
    """Build a new graph containing only the nodes named in ``keep_nodes``.

    Every node of ``graph`` whose ``name`` is in ``keep_nodes`` is copied with
    ``node.copy()`` and added to a fresh ``Graph``; the copy's ``graph``
    attribute is pointed at the new graph and its ``subsumed`` list is reset
    to empty. Edges are then re-created with ``new_graph.add_edge`` from the
    kept nodes found upstream of each kept node.

    Args:
        graph: Source graph. Must expose ``nodes``, an iterable of node
            objects providing ``name``, ``inputs`` and ``copy()``.
        keep_nodes: Container of node names to keep (only ``in`` is used).
        pass_through: If true, a dropped node is transparent: the search
            continues through its own inputs, and its name is appended to
            ``new_graph[node].subsumed`` for the kept node being processed.
            If false, the search stops at dropped nodes, so only directly
            connected kept inputs yield edges.

    Returns:
        The newly constructed ``Graph``.
    """
    new_graph = Graph()

    # First pass: add every kept node, so that all edge endpoints exist
    # before any edge is created.
    for node in graph.nodes:
        if node.name in keep_nodes:
            new_node = node.copy()
            new_node.graph = new_graph
            new_node.subsumed = []
            new_graph.add_node(new_node)

    def kept_inputs(node):
        """Return the kept nodes reachable upstream of ``node``.

        The search stops at each kept node it finds. Results are in
        depth-first, input-order discovery order.
        """
        found = []
        # A list (not a set) so membership keeps plain ``==``/identity
        # semantics and does not require node objects to be hashable.
        visited = []
        # Explicit stack instead of recursion: a long chain of dropped nodes
        # would otherwise exhaust Python's recursion limit. Inputs are pushed
        # reversed so they pop in their original order, which reproduces the
        # visiting order of a recursive depth-first walk.
        stack = list(node.inputs)[::-1]
        while stack:
            inp = stack.pop()
            if inp in visited:
                continue
            visited.append(inp)
            if inp.name in keep_nodes:
                found.append(inp)
            elif pass_through:
                new_graph[node].subsumed.append(inp.name)
                stack.extend(list(inp.inputs)[::-1])
        return found

    # Second pass: connect each kept node to its kept upstream nodes.
    for node in graph.nodes:
        if node.name in keep_nodes:
            for inp in kept_inputs(node):
                new_graph.add_edge(inp, node)

    return new_graph