def filter_graph()

in lucid/scratch/pretty_graphs/graph.py [0:0]


def filter_graph(graph, keep_nodes, pass_through=True):
  """Build a new Graph containing only the nodes whose names are in keep_nodes.

  Every kept node is copied (via `node.copy()`) into a fresh `Graph`. Each
  copy gets its `graph` attribute pointed at the new graph and an empty
  `subsumed` list. Edges are then rebuilt between kept nodes:

  * pass_through=True: a kept node receives an edge from every kept node
    reachable by walking backwards through its inputs *only via removed
    nodes*. The names of the removed nodes walked through are appended, in
    depth-first visit order, to that kept node's `subsumed` list.
  * pass_through=False: only direct inputs that are themselves kept produce
    edges, and `subsumed` stays empty.

  Args:
    graph: source graph. Its `nodes` are iterated and each node's `inputs`
      are followed; this function's own code only reads from it.
    keep_nodes: iterable of node names to keep. It is materialized into a
      set once, so any iterable (including a one-shot generator) works and
      membership tests are O(1).
    pass_through: whether to connect kept nodes through removed ones.

  Returns:
    The newly constructed Graph.
  """
  keep = set(keep_nodes)
  new_graph = Graph()

  # Pass 1: create all kept nodes first. Edges are added in a second pass,
  # presumably because add_edge needs both endpoints present and an input
  # can appear later than its consumer in graph.nodes.
  for node in graph.nodes:
    if node.name in keep:
      new_node = node.copy()
      new_node.graph = new_graph
      new_node.subsumed = []
      new_graph.add_node(new_node)

  def kept_inputs(node):
    """Return the kept nodes feeding `node`, in depth-first pre-order.

    When pass_through is set, the names of removed nodes walked through are
    recorded on the new-graph counterpart of `node` (looked up via
    `new_graph[node]`; assumes lookup by an original-graph node resolves to
    its copy — TODO confirm against Graph.__getitem__).
    """
    found = []
    seen = set()  # keyed by node name, consistent with keep/subsumed
    subsumed = new_graph[node].subsumed if pass_through else None
    # Explicit stack instead of recursion so long chains of removed nodes
    # cannot exceed Python's recursion limit. Children are pushed in reverse
    # so they pop in their original order, giving the same pre-order a
    # recursive DFS (with the visited check on entry) would produce.
    stack = list(node.inputs)[::-1]
    while stack:
      inp = stack.pop()
      if inp.name in seen:
        continue
      seen.add(inp.name)
      if inp.name in keep:
        found.append(inp)
      elif pass_through:
        subsumed.append(inp.name)
        stack.extend(reversed(list(inp.inputs)))
    return found

  # Pass 2: wire up edges between kept nodes.
  for node in graph.nodes:
    if node.name in keep:
      for inp in kept_inputs(node):
        new_graph.add_edge(inp, node)

  return new_graph