experimental/piranha_playground/rule_inference/controller.py (44 lines of code) (raw):
# Copyright (c) 2023 Uber Technologies, Inc.
#
# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0
#
# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, Dict
import attr
from piranha_playground.rule_inference.piranha_chat import PiranhaGPTChat
# Answers accepted when the model is asked a yes/no question.
ANSWER_OPTIONS = [
    "yes",
    "no",
]

# Actions the model may choose from when it is asked how to improve a rule.
IMPROVEMENT_OPTIONS = [
    "add filter",
    "change query",
    "change replacement",
    "give up",
]

# Template for the JSON reply requested from the model. It is meant to be
# passed through str.format(options=...): the doubled braces become literal
# braces and {options} is replaced by the allowed answers.
JSON_FILE_FORMAT = (
    '{{"reasoning": "<your reasoning>", '
    '"answer": "{options}"}}'
)
class ControllerError(Exception):
    """Raised when the controller cannot obtain a usable decision from the model."""
@attr.s
class Controller:
    """A controller that decides what the AI model should do next, given the user request.

    :param PiranhaGPTChat chat: an instance of the PiranhaGPTChat model.
    """

    chat = attr.ib(type=PiranhaGPTChat)

    def get_model_selection(self, task_description: str, options: list) -> str:
        """Send the task description to the chat model and return the option it selects.

        The model is expected to reply with a JSON object whose ``"answer"`` field is
        one of ``options``. A reply that is not valid JSON, is not a JSON object, or
        carries an answer outside ``options`` is reported back to the model as a
        follow-up message and the request is retried, up to three attempts in total.

        :param str task_description: Description of the task to be sent to the model.
        :param list options: List of valid values (strings) for the ``"answer"`` field.
        :return: The option selected by the model.
        :rtype: str
        :raises ControllerError: If no valid answer is obtained within the allowed tries.
        """
        self.chat.append_user_followup(task_description)
        n_tries = 3
        last_error = None
        for _ in range(n_tries):
            try:
                completion = json.loads(self.chat.get_model_response())
                # Valid JSON is not necessarily an object (e.g. a list or a bare
                # string); only a dict can carry the "answer" field.
                if not isinstance(completion, dict):
                    raise ControllerError(
                        f"Expected a JSON object, got {type(completion).__name__}"
                    )
                answer = completion.get("answer")
                if answer not in options:
                    raise ControllerError(
                        f"Invalid answer: {answer}. Expected one of: {', '.join(options)}"
                    )
                return answer
            except (json.JSONDecodeError, ControllerError) as e:
                # Every kind of malformed reply is retried, not just JSON syntax
                # errors. The format template is filled in before being shown to
                # the model so that it sees real braces and the actual options.
                last_error = e
                expected_format = JSON_FILE_FORMAT.format(options="/".join(options))
                self.chat.append_user_followup(
                    f"Error: {e}. Please respond in format: {expected_format}\n"
                    f"Valid 'answer' options: {options}"
                )
        raise ControllerError(
            f"Failed to get valid answer after {n_tries} tries."
        ) from last_error

    def should_improve_rule(self, task: str, rule: str) -> bool:
        """Ask the model whether a rule needs changes for the requested task.

        :param str task: Task requested by the user.
        :param str rule: Rule to be checked for improvements.
        :return: True if the model answers "yes", False if it answers "no".
        :rtype: bool
        :raises ControllerError: If the model does not give a valid answer.
        """
        task_description = (
            f"User requests improvements for: '{task}'.\n"
            f"Does the rule below need changes?\n\n{rule}\n"
            f"Response format: {JSON_FILE_FORMAT.format(options='/'.join(ANSWER_OPTIONS))}. "
        )
        return self.get_model_selection(task_description, ANSWER_OPTIONS) == "yes"

    def get_option_for_improvement(self, rule: str) -> str:
        """Ask the model to choose an improvement option for the given rule.

        :param str rule: Rule to be improved.
        :return: The option chosen by the model, one of ``IMPROVEMENT_OPTIONS``.
        :rtype: str
        :raises ControllerError: If the model does not give a valid answer.
        """
        task_description = (
            f"You opted to improve this rule:\n{rule}\n"
            f"Choose an improvement option and state your reasoning and selection. "
            f"Response format: {JSON_FILE_FORMAT.format(options='/'.join(IMPROVEMENT_OPTIONS))}."
        )
        return self.get_model_selection(task_description, IMPROVEMENT_OPTIONS)