def find_lowest_common_ancestor()

in experimental/piranha_playground/rule_inference/utils/node_utils.py [0:0]


    def find_lowest_common_ancestor(nodes: List[Node]) -> Node:
        """
        Find the lowest (deepest) common ancestor of the provided nodes.

        A node counts as its own ancestor, so the result for a single node is that
        node, and the result for a node and one of its descendants is the node itself.

        :param nodes: list of nodes for which to find the lowest common ancestor.
        :return: Node which is the lowest common ancestor.
        :raises ValueError: if ``nodes`` is empty or the nodes share no ancestor.
        """
        # Explicit check rather than ``assert`` so it still fires under ``python -O``.
        if not nodes:
            raise ValueError("Cannot find a common ancestor of an empty list of nodes")

        def _ancestors(node: Node) -> List[Node]:
            """Return ``node`` followed by its parents, ordered from deepest to root."""
            chain = []
            while node is not None:
                chain.append(node)
                node = node.parent
            return chain

        # Ancestor id sets for every node except the first (nodes are compared by id).
        other_ancestor_ids = [
            {ancestor.id for ancestor in _ancestors(node)} for node in nodes[1:]
        ]

        # Walk the first node's chain bottom-up: the first ancestor shared by all the
        # other nodes is the deepest common one. Ordering by start_byte is avoided
        # because a parent and its first child can start at the same byte, which made
        # the choice between them ambiguous.
        for candidate in _ancestors(nodes[0]):
            if all(candidate.id in ids for ids in other_ancestor_ids):
                return candidate

        # No shared ancestor along the parent chains (e.g. nodes from different trees).
        raise ValueError("Nodes have no common ancestor")