# experimental/piranha_playground/rule_inference/piranha_agent.py
# Copyright (c) 2023 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.
import difflib
import logging
import multiprocessing
import re
from typing import List, Tuple

import attr
import toml

from piranha_playground.rule_inference.controller import Controller
from piranha_playground.rule_inference.graph_parser import GraphParser
from piranha_playground.rule_inference.piranha_chat import (
PiranhaChatException, PiranhaGPTChat)
from piranha_playground.rule_inference.rule_application import \
run_piranha_with_timeout
from piranha_playground.rule_inference.static_inference import (Inference,
QueryWriter)
from piranha_playground.rule_inference.template_parser import TemplateParser
from piranha_playground.rule_inference.utils.logger_formatter import \
CustomFormatter
from piranha_playground.rule_inference.utils.node_utils import NodeUtils
from piranha_playground.rule_inference.utils.pretty_toml import PrettyTOML
from piranha_playground.rule_inference.utils.rule_utils import (RawRule,
RawRuleGraph)

logger = logging.getLogger("PiranhaChat")
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
ch.setFormatter(CustomFormatter())
logger.addHandler(ch)


class PiranhaAgentError(Exception):
    pass


@attr.s
class PiranhaAgent:
"""
An agent that uses OpenAI's chat models for inferring Piranha rules.
The agent takes pairs of source and target codes, finds a transformation rule between them,
and validates the rule's effectiveness by testing if the rule can refactor the source code into the target code.
:ivar source_code: The source code to refactor
:ivar target_code: The target code we want to achieve after refactoring
:ivar language: The programming language of the source code (default is "java")
:ivar hints: Any hints or information that might help in inferring the rules
:ivar explanation: Explanation of the rule generated by GPT
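
    Example (illustrative sketch; the example code strings are placeholders)::

        agent = PiranhaAgent(
            source_code="// code before ...",
            target_code="// code after ...",
            language="java",
        )
        toml_rules = agent.infer_rules_statically()
        chats = agent.create_chats(toml_rules)
        improved_rule = agent.general_improvement(chats)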
"""
source_code = attr.ib(type=str)
target_code = attr.ib(type=str)
language = attr.ib(default="java")
hints = attr.ib(default="")
explanation = attr.ib(default=None)

    def infer_rules_statically(self) -> str:
        """Perform the first pass of rule inference: statically infer rules from
        the example code pair and return a TOML representation of the rule graph.
:return: str, string containing the rule in TOML format
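
        The exact output depends on the inferred rules; a hypothetical result
        might look like (illustrative sketch only)::

            [[rules]]
            name = "delete_example_call"
            query = "(method_invocation) @invocation"
            replace_node = "invocation"
            replace = ""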
"""
parser = TemplateParser(self.language)
source_tree = parser.get_tree_from_code(self.source_code)
target_tree = parser.get_tree_from_code(self.target_code)
rules = {}
finder = GraphParser(source_tree, target_tree)
pairs = finder.parse_templates()
for comment_name, (nodes_before, nodes_after) in pairs.items():
inference_engine = Inference(nodes_before, nodes_after, parser.template_holes)
rule = inference_engine.static_infer()
rules[comment_name] = rule
        # Re-key finder.edges so that both endpoints use the generated rule names
edges = {
rules[from_name].name: [rules[to_name].name for to_name in to_names]
for from_name, to_names in finder.edges.items()
}
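        # e.g. {"rule_a": ["rule_b"]} becomes
        # [{"from": "rule_a", "to": ["rule_b"], "scope": "File"}]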
edges = [
{"from": k, "to": v, "scope": "File"} for k, v in edges.items() if v != []
]
graph = RawRuleGraph(list(rules.values()), edges)
return graph.to_toml()

    def create_chats(self, rules: str) -> List[PiranhaGPTChat]:
        """
        Prepare the data for interaction with the AI model.

        :param rules: Statically inferred rules in TOML format
        :type rules: str
        :return: List of chat interactions with the information necessary for the AI model.
        :rtype: List[PiranhaGPTChat]
        """
parser = TemplateParser(self.language)
source_tree = parser.get_tree_from_code(self.source_code, remove_comments=True)
target_tree = parser.get_tree_from_code(self.target_code, remove_comments=True)
source_tree_sexpr = NodeUtils.generate_sexpr(source_tree.root_node, 0)
target_tree_sexpr = NodeUtils.generate_sexpr(target_tree.root_node, 0)
# Create diff between source and target code using difflib
diff = list(
difflib.unified_diff(
self.source_code.splitlines(), self.target_code.splitlines()
)
)
diff = "\n".join(diff)
prompt_holes = {
"source_code": self.source_code,
"source_tree": source_tree_sexpr,
"target_tree": target_tree_sexpr,
"diff": diff,
"rules": rules,
"hints": self.hints,
}
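        # These keys are assumed to correspond to the placeholder names in the
        # PiranhaGPTChat prompt template.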
# Number of Chat interactions to have with the model
n_samples = 5
chat_interactions = [
PiranhaGPTChat(holes=prompt_holes) for _ in range(n_samples)
]
try:
first_round = chat_interactions[0].get_completion(n_samples=n_samples)
except PiranhaChatException as e:
logger.debug(
f"Chat completion failed with {e}. Trying again with a new chat...\n"
)
return []
for i, response in enumerate(first_round):
            # The prompt is identical for every sample, so rather than running it
            # once per chat, it is cheaper to request n samples from the OpenAI API
            # in one call and seed each chat with one of the responses.
chat_interactions[i].append_system_message(response)
return chat_interactions

    def get_explanation(self):
        return self.explanation

    def general_improvement(self, chat_interactions: List[PiranhaGPTChat]) -> str:
        """
        Find a rule generated by the chat models that complies with the user-specified example.
:param chat_interactions: List of chat sessions for the inference engine to use
:type chat_interactions: List[PiranhaGPTChat]
:return: The rule inferred from GPT
:rtype: str
:raises PiranhaAgentError: If the agent fails to generate a rule after 10 rounds of interaction with GPT-4.
"""
max_rounds = 10
for i in range(max_rounds):
for chat in chat_interactions:
try:
completion = chat.get_model_response()
_, toml_block, explanation = self.validate_rule(completion)
self.explanation = explanation
return toml_block
except PiranhaAgentError as e:
                    logger.debug(
                        f"GPT-4 failed to generate a rule. Following up in the next round with {e}. Trying again...\n"
                    )
chat.append_user_followup(str(e))
except PiranhaChatException as e:
logger.debug(
f"Chat completion failed with {e}. Trying again with a new chat...\n"
)
raise PiranhaAgentError(
f"Failed to generate a rule after {max_rounds} rounds of interaction with GPT-4.",
)

    def validate_rule(self, completion) -> Tuple[str, str, str]:
"""
Tests if the inferred rule can transform the source code into the target code.
:param completion: Inferred rule from the model
:type completion: str
:return: A tuple containing the file name, TOML block, and the explanation
:rtype: Tuple[str, str, str]
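
        The completion is expected to embed the rule and explanation as::

            <file_name_start>rule.toml<file_name_end>
            ```toml
            ... rule ...
            ```
            ```md
            ... explanation ...
            ```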
"""
pattern = r"```toml(?!md)(.*?)```"
logger.debug(f"Completion\n: {completion}")
# Extract all toml block contents
toml_blocks = re.findall(pattern, completion, re.DOTALL)
if not toml_blocks:
raise PiranhaAgentError(
"No TOML block provided in the expected output format. "
"Please provide a TOML block with the rule. ```toml ... ```"
)
pattern = r"```md(.*?)```"
explanation = re.findall(pattern, completion, re.DOTALL)
if not explanation:
raise PiranhaAgentError(
"No explanation provided in the expected output format. "
"Please provide an explanation as a markdown block. ```md ... ```"
)
try:
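            # Assumption: this replacement maps the model's use of the
            # `parenthesized_expression` node kind to the `condition` kind
            # expected by the grammar.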
toml_block = (
toml_blocks[0].replace("parenthesized_expression", "condition").strip()
)
logger.debug(f"Generated rule: {toml_block}")
toml_dict = toml.loads(toml_block)
except Exception as e:
raise PiranhaAgentError(
f"Could not create Piranha rule. The TOML block is not valid: {e}. "
)
refactored_code = self.run_piranha(toml_dict)
if not refactored_code:
raise PiranhaAgentError(
"Piranha did not generate any refactored code. Either the query or the filters are incorrect. "
)
if NodeUtils.normalize_code(refactored_code) != NodeUtils.normalize_code(
self.target_code
):
raise PiranhaAgentError(
f"The rule produced wrong code!!! "
f"Expected:\n{self.target_code}\n\n but got:\n{refactored_code}\n\n"
)
pattern = r"<file_name_start>(.*?)<file_name_end>"
file_names = re.findall(pattern, completion, re.DOTALL)
file_name = file_names[0] if file_names else "rule.toml"
return file_name, toml_block, explanation[0]

    def run_piranha(self, toml_dict) -> str:
        """
        Runs the inferred rule graph by applying it to the source code using Piranha.

        :param toml_dict: Inferred rules, parsed from TOML into a dict
        :type toml_dict: dict
:return: Refactored code as a result of the rule application
:rtype: str
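
        A minimal sketch of the expected input shape (rule content is hypothetical)::

            {"rules": [{"name": "delete_call",
                        "query": "(method_invocation) @i",
                        "replace_node": "i",
                        "replace": ""}]}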
"""
rules = toml_dict.get("rules", [])
if not rules:
raise PiranhaAgentError("TOML does not include any rule specifications.")
try:
raw_graph = RawRuleGraph.from_toml(toml_dict)
logger.debug(f"Raw graph: {raw_graph.to_toml()}")
res, success = run_piranha_with_timeout(
self.source_code, self.language, raw_graph, timeout=5
)
if not success:
if "QueryError" in res:
                    raise PiranhaAgentError(
                        f"One of the provided queries is not valid: {res}. "
                        f"Do not use nodes you cannot see in the tree representation. "
                        f"Make sure your parentheses are balanced."
                    )
raise PiranhaAgentError(f"Piranha failed to execute: {res}.")
return res
except multiprocessing.context.TimeoutError:
            raise PiranhaAgentError(
                "Piranha appears to be in an infinite loop. Please add a filter or constrain the query. "
                "Remember you can only constrain queries with #eq, #not-eq, #match. "
                "Otherwise you need to use a [[rules.filters]] with contains or not_contains."
            )

    def improve_rules(
        self, rules: List[RawRule], task: str, chat: PiranhaGPTChat
    ) -> Tuple[List[RawRule], List[str]]:
"""
Processes each rule and decides whether it should be improved.
:param rules: Rules to be processed
:param task: Task for which rules are being processed
:param chat: Chat interaction object
:return: A tuple containing a list of updated rules and a list of explanations
"""
updated_rules = []
explanations = []
for rule in rules:
updated_rule, explanation = self.improve_single_rule(rule, task, chat)
updated_rules.append(updated_rule)
explanations.append(explanation)
return updated_rules, explanations

    def improve_single_rule(
        self, rule: RawRule, task: str, chat: PiranhaGPTChat
    ) -> Tuple[RawRule, str]:
"""
Processes a single rule and decides whether it should be improved.
:param rule: Rule to be processed
:param task: Task for which the rule is being processed
:param chat: Chat interaction object
:return: A tuple containing the updated rule and an explanation
"""
rule_str = rule.to_toml()
controller = Controller(chat)
if controller.should_improve_rule(task, rule_str):
option = controller.get_option_for_improvement(rule_str)
if option == "add filter":
return self.add_filter(task, rule, chat)
return rule, ''

    def validate_improved_rules(
        self, updated_rules: List[RawRule], explanations: List[str]
    ) -> str:
        """
        Validates the updated rules and records their combined explanation.

        :param updated_rules: A list of updated rules
        :param explanations: A list of explanations for the rules
        :return: The validated TOML rule block
        """
rule_block = "\n".join(
[rule.to_toml() for rule in updated_rules]
)
explanation_block = "\n".join(explanations)
validation = self.validate_rule(
f"<file_name_start>rules.toml<file_name_end> ```toml\n{rule_block}\n``` ```md\n{explanation_block}\n```"
)
self.explanation = "\n".join(explanations)
return validation[1]

    def improve_rule_graph(self, task: str, rules: str, option: str) -> str:
        """
        Improves the given rule graph, either through a general refinement round
        or by applying task-driven improvements to each rule.

        :param task: Description of what you would like to do
        :param rules: Rules to improve, in TOML format
        :param option: "general" for a general improvement round; any other value
            triggers per-rule, task-driven improvement
        :return: The improved rule in TOML format
:raises PiranhaAgentError: If unable to improve the rule
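
        Example (hypothetical invocation; `rules_toml` is a placeholder)::

            improved = agent.improve_rule_graph(
                "delete the enclosing statement", rules_toml, "general"
            )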
"""
try:
chat_interactions = self.create_chats(rules)
        except Exception as e:
logger.debug(f"Failed to create chat: {e}")
raise PiranhaAgentError(str(e)) from e
if option == "general":
return self.general_improvement(chat_interactions)
graph = RawRuleGraph.from_toml(toml.loads(rules))
for chat in chat_interactions:
try:
updated_rules, explanations = self.improve_rules(graph.rules, task, chat)
return self.validate_improved_rules(updated_rules, explanations)
except Exception as e:
                logger.debug(
                    f"GPT-4 failed to improve the rule. Following up in the next round with {e}. Trying again...\n"
                )
chat.append_user_followup(str(e))
logger.debug(f"Unable to improve rule {rules}.")
raise PiranhaAgentError("Unable to improve rule.")

    def add_filter(
        self, desc: str, rule: RawRule, chat: PiranhaGPTChat
    ) -> Tuple[RawRule, str]:
"""
Adds a filter to the rule that encloses the nodes of the rule.
:param desc: Description of what you would like to do
:param rule: Rule to add a filter to
:param chat: Chat interactions with information necessary for AI model
:return: A tuple containing the rule with the added filter and its explanation
        :rtype: Tuple[RawRule, str]
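
        Each candidate enclosing node is offered to the model in the form
        (the query shown is illustrative)::

            === Option 0 ===

            enclosing_node = """(method_declaration) @md"""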
"""
query = rule.query
        # TemplateParser is how this class parses code elsewhere; we assume it
        # also exposes the underlying tree-sitter Language object.
        parser = TemplateParser(self.language)
        source_tree = parser.get_tree_from_code(self.source_code)
        tree_sitter_q = parser.tree_sitter_language.query(query)
        captures = tree_sitter_q.captures(source_tree.root_node)
captures = NodeUtils.get_smallest_nonoverlapping_set([c[0] for c in captures])
        # Walk up from each capture, collecting the node and all of its ancestors
        # as candidate enclosing nodes.
        enclosing_nodes = []
        for node in captures:
            while node:
                enclosing_nodes.append(node)
                node = node.parent
enclosing_options = ""
for i, node in enumerate(enclosing_nodes):
qw = QueryWriter([node])
query = qw.write(simplify=True)
enclosing_options += f"\n\n=== Option {i} ===\n\n"
enclosing_options += f'enclosing_node = """{query}"""\n'
        # Ask the model to choose a suitable enclosing node for the rule
chat.append_improve_request(
desc,
rule.to_toml(),
enclosing_options,
)
        completion = chat.get_model_response()
        pattern = r"```toml(?!md)(.*?)```"
        # Extract all toml block contents
        toml_blocks = re.findall(pattern, completion, re.DOTALL)
        if not toml_blocks:
            raise PiranhaAgentError(
                "No TOML block provided in the expected output format. ```toml ... ```"
            )
        toml_dict = toml.loads(toml_blocks[0])
        pattern = r"```md(.*?)```"
        explanations = re.findall(pattern, completion, re.DOTALL)
        explanation = explanations[0] if explanations else ""
        return RawRule.from_toml(toml_dict["rules"]), explanation