def find_groups()

in lucid/scratch/pretty_graphs/graph.py [0:0]


def find_groups(graph, group_ops=("Concat", "ConcatV2", "Add")):
  """Find subgraphs that fan out and later merge at a multi-input op.

  A group is rooted at a node whose `op` is in `group_ops` and which has more
  than one input. The group holds the root plus every node that lies on at
  least one, but not all, of the root's input branches (a branch being an
  input node together with all of its transitive inputs). Nodes shared by
  every branch are excluded. Groups whose root is itself a member of another
  group are dropped, so only the outermost groups are returned.

  Args:
    graph: object with a `nodes` iterable. Each node exposes `name`, `op` and
      `inputs` (a sequence of nodes), and nodes must be hashable. Every
      node's inputs must appear earlier in `graph.nodes`; otherwise a
      KeyError is raised while computing ancestors.
    group_ops: op type names that may root a group. A single string is
      treated as one op name. Defaults to the concat/add ops.

  Returns:
    dict mapping each outermost root node to the set of nodes in its group,
    including the root itself.
  """
  if isinstance(group_ops, str):
    # Avoid frozenset("Add") -> {"A", "d"}.
    group_ops = (group_ops,)
  group_ops = frozenset(group_ops)

  # Transitive inputs of every node, keyed by node name. Because inputs are
  # looked up by name here, each input must already have been visited.
  ancestors = {}
  for node in graph.nodes:
    node_ancestors = ancestors[node.name] = set(node.inputs)
    for inp in node.inputs:
      node_ancestors |= ancestors[inp.name]

  roots = [node for node in graph.nodes
           if node.op in group_ops and len(node.inputs) > 1]

  groups = {}
  group_children = set()
  for root in roots:
    # One branch per input: the input node plus everything upstream of it.
    branches = [{head} | ancestors[head.name] for head in root.inputs]
    shared = set.intersection(*branches)
    unique = set.union(*branches) - shared
    groups[root] = {root} | unique
    group_children |= unique

  # Keep only outermost groups: drop any group whose root sits inside another.
  return {root: members for root, members in groups.items()
          if root not in group_children}