in tensorwatch/model_graph/hiddenlayer/ge.py [0:0]
def match(self, graph, nodes):
    """Match this parallel pattern against a set of sibling branches.

    Every sub-pattern in ``self.patterns`` must match a distinct node of
    ``nodes`` (exactly one node per pattern), and all matched branches must
    lead into the same following node.

    Args:
        graph: Graph object providing ``siblings(node)`` and
            ``incoming(node)``.
        nodes: A single node or a list of nodes at which the parallel
            branches start. A single node (or one-element list) is expanded
            to its siblings via ``graph.siblings``.

    Returns:
        ``(matched, end_node)`` where ``matched`` is the concatenation of the
        sub-patterns' matches in pattern order and ``end_node`` is the node
        every branch leads into. Returns ``([], None)`` when no consistent
        assignment of patterns to nodes exists.
    """
    if not nodes:
        return [], None
    nodes = nodes if isinstance(nodes, list) else [nodes]
    if len(nodes) == 1:
        # A single node: assume it has to be matched together with its siblings.
        nodes = graph.siblings(nodes[0])
    else:
        # All nodes must share the same parents (or all have no parent).
        parent_sets = [set(graph.incoming(n)) for n in nodes]
        if any(ps != parent_sets[0] for ps in parent_sets[1:]):
            return [], None

    # list() instead of .copy() so tuple-valued pattern/node sequences work too.
    patterns = list(self.patterns)
    nodes = list(nodes)
    # TODO: If there are more nodes than patterns, consider every subset of
    #       the nodes instead of rejecting outright.
    if len(patterns) != len(nodes):
        return [], None

    unset = object()  # distinguishes "no end node yet" from an end node of None
    cache = {}
    used = [False] * len(nodes)

    def sub_match(pi, ni):
        # Memoize so backtracking never re-runs the same (pattern, node) match.
        key = (pi, ni)
        if key not in cache:
            cache[key] = patterns[pi].match(graph, nodes[ni])
        return cache[key]

    def assign(pi, end_node):
        # Depth-first search assigning patterns[pi:] to unused nodes. Nodes are
        # tried in order, so the first greedy assignment is found first when it
        # works. Worst case is factorial in the branch count, which is small
        # for parallel branches. Returns (matches, end_node) or None.
        if pi == len(patterns):
            return [], end_node
        for ni in range(len(nodes)):
            if used[ni]:
                continue
            branch, following = sub_match(pi, ni)
            if not branch:
                continue
            # All branches must end in the same node. Compare explicitly so a
            # branch ending in None is checked as well.
            if end_node is not unset and end_node != following:
                continue
            used[ni] = True
            rest = assign(pi + 1, following)
            used[ni] = False
            if rest is not None:
                return list(branch) + rest[0], rest[1]
        return None

    result = assign(0, unset)
    if result is None:
        return [], None
    all_matches, end_node = result
    return all_matches, (None if end_node is unset else end_node)