def add_filter()

in experimental/piranha_playground/rule_inference/piranha_agent.py [0:0]


    def add_filter(self, desc: str, rule: RawRule, chat: PiranhaGPTChat) -> Tuple[RawRule, str]:
        """
        Adds a filter to the rule that encloses the nodes of the rule.

        The rule's query is run against ``self.source_code``. The captured nodes are reduced to
        the smallest non-overlapping set. Each remaining node and every one of its ancestors
        (up to the tree root) is rendered as a candidate ``enclosing_node`` query. Each distinct
        node is listed once. The candidates are sent to the model together with ``desc`` and the
        current rule. The model's reply is expected to contain a fenced ```toml block with the
        updated rule and a fenced ```md block with an explanation.

        :param desc: Description of what you would like to do
        :param rule: Rule to add a filter to
        :param chat: Chat interactions with information necessary for AI model
        :return: A tuple containing the rule with the added filter and its explanation
        :rtype: Tuple[RawRule, str]
        :raises IndexError: If the model response contains no ```toml block or no ```md block
        """
        source_tree = self.get_tree_from_code(self.source_code)
        tree_sitter_q = self.tree_sitter_language.query(rule.query)
        captures = tree_sitter_q.captures(source_tree.root_node)
        captures = NodeUtils.get_smallest_nonoverlapping_set([c[0] for c in captures])

        # Gather the nodes that can be used as enclosing node for the rule: every captured node
        # plus all of its ancestors. Different captures typically share ancestors (at least the
        # root), so de-duplicate while preserving first-seen order. Otherwise the model would be
        # shown the same candidate several times under different option numbers. A positional
        # key is used because tree-sitter hands out a new Python object on every `.parent` access.
        enclosing_nodes = []
        seen = set()
        for node in captures:
            while node:
                key = (node.type, node.start_byte, node.end_byte)
                if key not in seen:
                    seen.add(key)
                    enclosing_nodes.append(node)
                node = node.parent

        # Render each candidate as a simplified query the model can place in `enclosing_node`.
        option_chunks = []
        for i, node in enumerate(enclosing_nodes):
            option_query = QueryWriter([node]).write(simplify=True)
            option_chunks.append(f"\n\n=== Option {i} ===\n\n")
            option_chunks.append(f'enclosing_node = """{option_query}"""\n')
        enclosing_options = "".join(option_chunks)

        chat.append_improve_request(
            desc,
            rule.to_toml(),
            enclosing_options,
        )
        completion = chat.get_model_response()

        # The first ```toml fenced block in the reply holds the updated rule.
        toml_blocks = re.findall(r"```toml(?!md)(.*?)```", completion, re.DOTALL)
        toml_dict = toml.loads(toml_blocks[0])

        # The first ```md fenced block in the reply holds the explanation of the change.
        explanation = re.findall(r"```md(.*?)```", completion, re.DOTALL)

        return RawRule.from_toml(toml_dict["rules"]), explanation[0]