in experimental/piranha_playground/rule_inference/piranha_agent.py [0:0]
def add_filter(self, desc: str, rule: RawRule, chat: PiranhaGPTChat) -> Tuple[RawRule, str]:
    """
    Ask the model to add a filter to ``rule`` that restricts it to an enclosing node.

    The rule's query is run against ``self.source_code``. Each node it matches is
    offered to the model as a candidate enclosing node, together with every ancestor
    of that node up to the root. Each candidate is rendered as a simplified query.
    Ancestors shared by several matches are listed only once.

    :param desc: Description of what you would like to do
    :param rule: Rule to add a filter to
    :param chat: Chat interactions with information necessary for AI model
    :return: A tuple containing the rule with the added filter and its explanation
    :rtype: Tuple[RawRule, str]
    :raises IndexError: If the model response contains no ```toml block, or no ```md block.
    """
    source_tree = self.get_tree_from_code(self.source_code)
    tree_sitter_q = self.tree_sitter_language.query(rule.query)
    captures = tree_sitter_q.captures(source_tree.root_node)
    # Each capture is indexed with [0] to obtain its node.
    matched_nodes = NodeUtils.get_smallest_nonoverlapping_set([c[0] for c in captures])

    # Collect each matched node plus all of its ancestors. Different matches share
    # ancestors (at least the root), so stop climbing at the first node that is already
    # collected: all of its ancestors were collected together with it.
    enclosing_nodes = []
    for node in matched_nodes:
        while node and node not in enclosing_nodes:
            enclosing_nodes.append(node)
            node = node.parent

    # Render every candidate enclosing node as a simplified query for the prompt.
    option_chunks = []
    for i, node in enumerate(enclosing_nodes):
        node_query = QueryWriter([node]).write(simplify=True)
        option_chunks.append(f"\n\n=== Option {i} ===\n\n")
        option_chunks.append(f'enclosing_node = """{node_query}"""\n')
    enclosing_options = "".join(option_chunks)

    chat.append_improve_request(
        desc,
        rule.to_toml(),
        enclosing_options,
    )
    completion = chat.get_model_response()

    # Only the first ```toml block (the updated rule) and the first ```md block
    # (its explanation) of the reply are used.
    toml_blocks = re.findall(r"```toml(?!md)(.*?)```", completion, re.DOTALL)
    toml_dict = toml.loads(toml_blocks[0])
    explanations = re.findall(r"```md(.*?)```", completion, re.DOTALL)
    return RawRule.from_toml(toml_dict["rules"]), explanations[0]