in lucid/scratch/pretty_graphs/graph.py [0:0]
def find_groups(graph):
    """Group each multi-input merge node with the nodes private to its branches.

    A "root" is any node whose ``op`` is ``"Concat"``, ``"ConcatV2"`` or
    ``"Add"`` and which has more than one input. For every root, one branch
    set is built per input: the input node plus all of its transitive
    ancestors. Nodes that occur in at least one branch set but not in all of
    them are considered private to that root and form its group, together
    with the root itself. Any root that shows up as a private member of some
    group is removed from the result.

    Args:
      graph: object with a ``nodes`` sequence; each node exposes ``name``,
        ``op`` and ``inputs`` (a sequence of node objects). Every input must
        appear in ``graph.nodes`` before the node that consumes it (e.g. a
        topological order); otherwise a ``KeyError`` is raised while
        collecting ancestors.

    Returns:
      dict mapping each surviving root node to the set of nodes in its group
      (the root included), in the order the roots appear in ``graph.nodes``.
    """
    # Transitive inputs of every node, keyed by node name. The set is stored
    # before being extended so a node that lists itself as an input reads
    # back its own (partial) entry instead of failing the lookup.
    ancestors = {}
    for node in graph.nodes:
        upstream = set(node.inputs)
        ancestors[node.name] = upstream
        for parent in node.inputs:
            upstream |= ancestors[parent.name]

    merge_ops = ("Concat", "ConcatV2", "Add")
    roots = [
        candidate
        for candidate in graph.nodes
        if candidate.op in merge_ops and len(candidate.inputs) > 1
    ]

    groups = {}
    claimed = set()  # every node that is private to some root's branches
    for root in roots:
        branches = [{head} | ancestors[head.name] for head in root.inputs]
        common = set.intersection(*branches)
        private = set.union(*branches) - common
        groups[root] = {root} | private
        claimed |= private

    # Drop roots that are themselves nested inside another group.
    return {
        root: members
        for root, members in groups.items()
        if root not in claimed
    }