in lucid/scratch/pretty_graphs/format_graph.py [0:0]
def parse_graph(graph):
    """Convert a topologically ordered graph into a nested layout tree.

    The graph is walked backwards from its last node to its first one.
    Each visited node becomes a ``LayoutNode``. Whenever a node has more
    than one input, the paths between that node and the closest ancestor
    shared by all of its inputs are parsed recursively and grouped into a
    ``LayoutBranch``; the walk then continues from that shared ancestor.

    Args:
      graph: object exposing a ``nodes`` sequence. Every node must be
        hashable and provide ``name`` and ``inputs`` (a sized collection of
        other nodes). ``graph.nodes`` must list each node after all of its
        inputs, because ancestor sets are accumulated in a single forward
        pass. Ancestor sets are keyed by ``node.name``, so names are
        presumably unique within the graph (not checked here).

    Returns:
      A ``LayoutSeq`` covering ``graph.nodes[0]`` through ``graph.nodes[-1]``.

    Raises:
      IndexError: if ``graph.nodes`` is empty.
      KeyError: if a node is listed before one of its inputs.
      TypeError: if the backwards walk reaches a node without inputs before
        it reaches the node it is looking for (``set.intersection`` is then
        called without arguments).
      AssertionError: if the paths of a branch are not pairwise disjoint.
    """
    nodes = graph.nodes

    # Position of the first occurrence of each node. Looking positions up
    # here is O(1), whereas ``list.index`` rescans the node list for every
    # key evaluated by ``max``/``sorted``.
    position = {}
    for idx, node in enumerate(nodes):
        position.setdefault(node, idx)

    # node name -> set of all transitive ancestors of that node.
    ancestors = {}
    for node in nodes:
        closure = ancestors[node.name] = set(node.inputs)
        for inp in node.inputs:
            closure |= ancestors[inp.name]

    def with_ancestors(node):
        """Return ``node`` together with all of its transitive ancestors."""
        return {node} | ancestors[node.name]

    def common_ancestor(node):
        """Return the latest node lying on every input path of ``node``."""
        shared = set.intersection(
            *[with_ancestors(inp) for inp in node.inputs])
        return max(shared, key=position.__getitem__)

    def parse_list(a, b):
        """Parse the stretch from ``a`` to ``b`` (inclusive) into a LayoutSeq."""
        items = []
        current = b
        # Collected back to front; reversed once the start node is reached.
        while current is not a:
            ancestor = common_ancestor(current)
            items.append(LayoutNode(current))
            if ancestor not in current.inputs or len(current.inputs) > 1:
                items.append(parse_branch(ancestor, current))
            current = ancestor
        items.append(LayoutNode(a))
        items.reverse()
        return LayoutSeq(items)

    def parse_branch(start, stop):
        """Parse the parallel paths between ``start`` and ``stop``."""
        assert common_ancestor(stop) == start, (
            "start must be the closest common ancestor of stop's inputs")

        # Everything at or before ``start`` lies outside the branch.
        excluded = with_ancestors(start)
        branch_nodes = [with_ancestors(inp) - excluded for inp in stop.inputs]

        # Paths that compare equal are skipped on purpose: this also skips
        # comparing a path with itself.
        assert all(b1.isdisjoint(b2)
                   for b1 in branch_nodes
                   for b2 in branch_nodes if b1 != b2), (
            "branch paths must not share nodes")

        children = []
        for members in branch_nodes:
            ordered = sorted(members, key=position.__getitem__)
            if len(ordered) > 1:
                children.append(parse_list(ordered[0], ordered[-1]))
            elif ordered:
                children.append(LayoutNode(ordered[0]))
            else:
                # The input is ``start`` itself or one of its ancestors, so
                # the path holds no node of its own: emit a placeholder.
                children.append(LayoutNode(None))
        return LayoutBranch(children)

    return parse_list(nodes[0], nodes[-1])