def parse_graph()

in lucid/scratch/pretty_graphs/format_graph.py [0:0]


def parse_graph(graph):
  """Parse a DAG into a nested layout of sequences and parallel branches.

  The graph is decomposed by walking backwards from its last node to its first.
  Consecutive nodes become a `LayoutSeq`. Wherever a node has several inputs,
  the paths between it and the latest node shared by all of its input branches
  become a `LayoutBranch` of parallel children.

  Args:
    graph: object with a `nodes` list. Each node must be hashable and expose
      `name` (used as a dict key, so presumably unique) and `inputs` (a sized
      collection of predecessor nodes). `graph.nodes` must be in topological
      order (every input listed before its consumer); otherwise collecting
      ancestors raises KeyError.

  Returns:
    The `LayoutSeq` spanning `graph.nodes[0]` to `graph.nodes[-1]`.

  Raises:
    IndexError: if `graph.nodes` is empty.
    ValueError: if the backward walk needs the common ancestor of a node with
      no inputs, or of a node whose input branches share no ancestor (e.g. the
      graph has more than one source node).
  """
  nodes = graph.nodes

  # Position of each node in the topologically ordered list. This replaces
  # repeated `nodes.index(n)` scans, which made every sort/min/max key O(n).
  # setdefault keeps the first occurrence, matching list.index semantics.
  order = {}
  for position, node in enumerate(nodes):
    order.setdefault(node, position)

  def index_of(node):
    return order[node]

  # node_prevs[name] holds every transitive ancestor of the node, excluding
  # the node itself. Because of topological order, each input's entry already
  # exists by the time its consumer is visited.
  node_prevs = {}
  for node in nodes:
    ancestors = node_prevs[node.name] = set(node.inputs)
    for inp in node.inputs:
      ancestors |= node_prevs[inp.name]

  def GCA(node):
    """Return the latest node shared by the ancestry of every input of `node`.

    Each input contributes itself plus its ancestors. Among the nodes common to
    all inputs, the one with the highest position in `graph.nodes` is where the
    input branches diverge (presumably "greatest common ancestor").
    """
    if not node.inputs:
      raise ValueError("node %r has no inputs, so it has no common ancestor"
                       % (node.name,))
    branch_nodes = [{inp} | node_prevs[inp.name] for inp in node.inputs]
    branch_shared = set.intersection(*branch_nodes)
    if not branch_shared:
      raise ValueError("inputs of node %r share no common ancestor"
                       % (node.name,))
    return max(branch_shared, key=index_of)

  def sorted_nodes(group):
    """Return `group` ordered by position in `graph.nodes`."""
    return sorted(group, key=index_of)

  def parse_node(node):
    return LayoutNode(node)

  def fake_node():
    # Placeholder for a branch that contains no nodes of its own.
    return LayoutNode(None)

  def parse_list(a, b):
    """Lay out the chain from `a` to `b` (both inclusive) as a LayoutSeq.

    Walks backwards from `b`, stepping to GCA of the current node, until it
    reaches `a` (compared by identity). If `a` is not on that walk, the walk
    continues until GCA raises. Whenever the step is not a single direct edge,
    a LayoutBranch for the parallel paths is placed before the node.
    """
    seq = []
    pres = b
    while pres is not a:
      prev = GCA(pres)
      seq.append(parse_node(pres))
      # A branch is needed unless `pres` hangs off `prev` by one direct edge.
      if prev not in pres.inputs or len(pres.inputs) > 1:
        seq.append(parse_branch(prev, pres))
      pres = prev
    seq.append(parse_node(a))
    # The list was built back-to-front; restore first-to-last order.
    seq.reverse()
    return LayoutSeq(seq)

  def parse_branch(start, stop):
    """Lay out the parallel input paths of `stop` that diverge at `start`.

    Each input of `stop` yields the nodes on its path strictly after `start`:
    the input's ancestry (itself included) minus `start`'s ancestry (itself
    included). These paths must be pairwise disjoint. A path becomes:
      * a LayoutSeq if it has several nodes,
      * a LayoutNode if it has exactly one,
      * a placeholder LayoutNode(None) if it is empty (e.g. the input is
        `start` itself).
    """
    assert GCA(stop) == start

    start_reach = {start} | node_prevs[start.name]
    branch_nodes = [({inp} | node_prevs[inp.name]) - start_reach
                    for inp in stop.inputs]

    # Branches must not share nodes. `!=` compares sets by value, so branches
    # with identical node sets (e.g. the same input used twice) are not
    # checked against each other.
    assert all(len(b1 & b2) == 0
               for b1 in branch_nodes
               for b2 in branch_nodes if b1 != b2)

    ret = []
    for branch in branch_nodes:
      ordered = sorted_nodes(branch)
      if len(ordered) > 1:
        ret.append(parse_list(ordered[0], ordered[-1]))
      elif len(ordered) == 1:
        ret.append(parse_node(ordered[0]))
      else:
        ret.append(fake_node())
    return LayoutBranch(ret)

  return parse_list(nodes[0], nodes[-1])