def match()

in tensorwatch/model_graph/hiddenlayer/ge.py [0:0]


    def match(self, graph, nodes):
        """Match the parallel sub-patterns against a group of sibling nodes.

        Each pattern in ``self.patterns`` must match a distinct node of the
        group (one pattern per node). All matched branches must report the
        same "following" value, so the group converges on one end node.

        Args:
            graph: Graph being searched; must provide ``siblings(node)`` and
                ``incoming(node)``.
            nodes: A single node or a list of nodes. A single node (or a
                one-element list) is expanded to ``graph.siblings(node)``.
                A list of several nodes must all share the same set of
                incoming nodes.

        Returns:
            ``(matches, end_node)``: the concatenated matches of the
            sub-patterns in pattern order, and the common value that follows
            every branch. ``([], None)`` when no complete match exists.
        """
        if not nodes:
            return [], None
        nodes = nodes if isinstance(nodes, list) else [nodes]
        # If a single node, assume we need to match with its siblings
        if len(nodes) == 1:
            nodes = graph.siblings(nodes[0])
        else:
            # Verify all nodes have the same parent or all have no parent
            first_parents = set(graph.incoming(nodes[0]))
            if any(set(graph.incoming(n)) != first_parents for n in nodes[1:]):
                return [], None

        # TODO: If more nodes than patterns, we should consider
        #       all permutations of the nodes
        patterns = list(self.patterns)
        nodes = list(nodes)
        if len(patterns) != len(nodes):
            return [], None

        # Cache each (pattern index, node index) result so that backtracking
        # never re-runs the same sub-match.
        cache = {}

        def sub_match(pi, ni):
            key = (pi, ni)
            if key not in cache:
                cache[key] = patterns[pi].match(graph, nodes[ni])
            return cache[key]

        # Sentinel meaning "no branch matched yet". A branch's following
        # value may itself be falsy (e.g. an empty list for a node with no
        # outgoing edges), so a truthiness test cannot distinguish "unset"
        # from a real result.
        unset = object()

        def assign(pi, used, end_node):
            """Assign patterns[pi:] to unused nodes; None if impossible."""
            if pi == len(patterns):
                return [], (None if end_node is unset else end_node)
            # Try candidate nodes in their original order so that, whenever
            # a first-fit assignment works, it is the one returned.
            for ni in range(len(nodes)):
                if ni in used:
                    continue
                matches, following = sub_match(pi, ni)
                if not matches:
                    continue
                # Verify all branches end in the same node
                if end_node is not unset and end_node != following:
                    continue
                rest = assign(pi + 1, used | {ni}, following)
                if rest is not None:
                    return list(matches) + rest[0], rest[1]
            # No node fits this pattern given earlier choices: backtrack.
            return None

        result = assign(0, frozenset(), unset)
        if result is None:
            return [], None
        return result