evals/elsuite/solver_tools_convo.py (181 lines of code) (raw):
import copy
import logging
import re
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Optional

from evals.elsuite.bugged_tools.tools import Tool, ToolTaskState
from evals.solvers.solver import Solver, SolverResult
from evals.task_state import Message, TaskState
# Module-level logger; used to warn when the solver calls a tool that doesn't exist.
logger = logging.getLogger(__name__)
@dataclass
class ToolCall:
    """A single tool invocation parsed out of a solver's output."""

    # Name of the tool the solver asked for
    tool_name: str
    # Raw text the solver passed as the tool's input
    input: str
    # Result of running the tool; the runner creates calls with None here and fills it in later
    output: Any
@dataclass
class ParsedSolverResult:
    """Structured view of one solver output: the tool calls it made and any final answer."""

    # Tool calls found in the output, in order of appearance (possibly empty)
    tool_calls: list[ToolCall]
    # Text of the final answer, or None if the solver did not give one
    final_answer: Optional[str]
@dataclass
class RunnerResult:
    """Outcome of a single Runner.run() conversation."""

    # Last task state of the conversation. Runner builds plain TaskState objects
    # (not ToolTaskState), so the annotation reflects what is actually stored.
    final_task_state: TaskState
    # The solver's result from the last turn that was run
    final_solver_result: SolverResult
    # Metrics for the run, e.g. "is_correct" and "num_turns"
    metrics: dict[str, Any]
class Runner:
    """
    Drives a multi-turn conversation between a solver and a set of tools for one sample.

    On each turn the solver is called on the current conversation. Its output is scanned
    for tool calls of the form ``(@ToolName: input)`` and for a final answer of the form
    ``(@Answer: output)``. The runner then appends either the tool outputs or a reminder
    message (when the solver produced nothing usable) as a user message. This repeats
    until the solver answers or ``max_turns`` solver calls have been made.
    """

    # Tool calls look like "(@NAME: INPUT)". The lookahead rejects exactly the reserved
    # names "Answer" and "Bugged"; tools whose names merely start with them still match.
    # INPUT may span lines (DOTALL) but stops at the first ")".
    _TOOL_CALL_PATTERN = re.compile(r"\(@(?!(?:Answer|Bugged):)(\w+): (.+?)\)", re.DOTALL)
    # Final answers look like "(@Answer: OUTPUT)"; only the first one is used.
    _FINAL_ANSWER_PATTERN = re.compile(r"\(@Answer: (.*?)\)", re.DOTALL)
    # Characters treated as wrapping quotes around tool inputs and final answers.
    _QUOTE_CHARS = ("'", '"')

    def __init__(
        self,
        solver: Solver,
        sample: Any,
        name_to_tool: dict,
        max_turns: int,
        default_task_description: str,
        default_reminder_message: str,
    ):
        """
        Args:
            solver: Solver queried once per turn.
            sample: Sample indexed with "task" (the initial user message) and "answer"
                (the expected final answer).
            name_to_tool: Mapping from tool name to tool; only these names can be called.
            max_turns: Maximum number of solver calls; must be at least 1 for ``run``.
            default_task_description: Format string with a ``{tool_names_and_descriptions}``
                placeholder, used as the solver's task description.
            default_reminder_message: User message sent when the solver's output has
                neither a usable tool call nor a final answer, or when a tool call fails.
        """
        self.solver = solver
        self.sample = sample
        self.name_to_tool = name_to_tool
        self.max_turns = max_turns
        self.default_task_description = default_task_description
        self.default_reminder_message = default_reminder_message

    def run(self) -> RunnerResult:
        """
        Runs the conversation until the solver gives a final answer or the turn limit is hit.

        Returns:
            RunnerResult holding the final task state, the last solver result, and metrics
            "is_correct" and "num_turns" (number of solver calls made).

        Raises:
            ValueError: if ``max_turns`` is less than 1 (no solver result would exist).
        """
        if self.max_turns < 1:
            raise ValueError(f"max_turns must be at least 1, got {self.max_turns}")

        # Prepare initial task state
        tool_names_and_descriptions = self._get_tool_names_and_descriptions(
            self.name_to_tool.values()
        )
        task_description = self.default_task_description.format(
            tool_names_and_descriptions=tool_names_and_descriptions
        )
        task_state = TaskState(
            task_description=task_description,
            messages=[Message(role="user", content=self.sample["task"])],
            current_state=None,
        )

        # `turn` is zero-indexed; _finish_run converts it into a count of solver calls
        for turn in range(self.max_turns):
            solver_result = self.solver(task_state)
            parsed_solver_result = self._parse_solver_result(solver_result)
            final_answer = parsed_solver_result.final_answer

            # A final answer ends the run, even if tool calls were emitted alongside it
            if final_answer is not None:
                return self._finish_run(task_state, solver_result, final_answer, turn)

            # Solver neither called a tool nor answered: prompt it to try again
            if not parsed_solver_result.tool_calls:
                task_state = self._add_eval_message(
                    task_state, solver_result, content=self.default_reminder_message
                )
                continue

            # Run tools. If any call failed (e.g. bad input), prompt the solver to try again
            tool_outputs = [self._run_tool_call(call) for call in parsed_solver_result.tool_calls]
            if any(output is None for output in tool_outputs):
                task_state = self._add_eval_message(
                    task_state, solver_result, content=self.default_reminder_message
                )
                continue

            # Add user message containing tool outputs
            task_state = self._add_tool_outputs(task_state, solver_result, tool_outputs)

        # Turn limit reached without an answer; `turn` is the last zero-indexed turn
        # (max_turns - 1), so num_turns == max_turns rather than max_turns + 1.
        return self._finish_run(task_state, solver_result, None, turn)

    def _get_tool_names_and_descriptions(self, tools: Iterable[Tool]) -> str:
        """
        Builds one "name: description" line per tool, each terminated by a newline.
        """
        return "".join(f"{tool._name}: {tool._desc}\n" for tool in tools)

    def _parse_solver_result(self, solver_result: SolverResult) -> ParsedSolverResult:
        """Extracts tool calls and any final answer from the solver's output text."""
        output = solver_result.output
        return ParsedSolverResult(
            tool_calls=self._parse_tool_calls(output),
            final_answer=self._parse_final_answer(output),
        )

    def _parse_tool_calls(self, output: str) -> list[ToolCall]:
        """
        Returns the solver's calls to known tools, in order of appearance.

        Calls naming a tool not in ``name_to_tool`` are logged and skipped. Returned calls
        have ``output=None``; they are filled in by ``_run_tool_call``.
        """
        tool_calls = []
        for tool_name, tool_message in self._find_tool_messages(output):
            if tool_name not in self.name_to_tool:
                logger.warning("Solver tried to call '%s' tool which doesn't exist!", tool_name)
                continue
            tool_calls.append(ToolCall(tool_name=tool_name, input=tool_message, output=None))
        return tool_calls

    def _find_tool_messages(self, text: str) -> list[tuple[str, str]]:
        """
        Finds all tool calls in ``text``, formatted ``(@NAME: INPUT)``, where NAME is not
        exactly "Answer" or "Bugged".

        Returns ``(NAME, INPUT)`` tuples in order of appearance. INPUT may span lines but
        ends at the first ")", so inputs that themselves contain ")" are truncated.
        """
        return self._TOOL_CALL_PATTERN.findall(text)

    def _parse_final_answer(self, output: str) -> Optional[str]:
        """
        If a final answer of the form ``(@Answer: OUTPUT)`` exists, returns OUTPUT (from
        the first such match), otherwise returns None.
        """
        match = self._FINAL_ANSWER_PATTERN.search(output)
        return match.group(1) if match else None

    @classmethod
    def _strip_wrapping_quotes(cls, text: str) -> str:
        """
        Removes the first and last characters when both are quote characters.

        Requires at least two characters, so a lone quote is left unchanged instead of
        becoming an empty string.
        """
        if len(text) >= 2 and text.startswith(cls._QUOTE_CHARS) and text.endswith(cls._QUOTE_CHARS):
            return text[1:-1]
        return text

    def _run_tool_call(self, tool_call: ToolCall) -> Optional[ToolCall]:
        """
        Runs a single tool call and records the tool's output on it.

        The tool receives its own task state: a description built from its name and
        ``_desc``, plus one user message holding the input with wrapping quotes removed.

        Args:
            tool_call: Parsed call whose ``tool_name`` is a key of ``name_to_tool``.

        Returns:
            ``tool_call`` itself, with ``output`` set, or None if the tool raised
            TypeError/ValueError/IndexError or returned None.
        """
        tool_name = tool_call.tool_name
        tool = self.name_to_tool[tool_name]
        # Remove quotes if solver wrapped input
        tool_input = self._strip_wrapping_quotes(tool_call.input)

        task_description = (
            f"Your name is {tool_name}. A description of your purpose is shown below:\n{tool._desc}"
        )
        task_state = ToolTaskState(
            task_description=task_description,
            messages=[Message(role="user", content=tool_input)],
            current_state=None,
        )
        try:
            out = tool(task_state)
        except (TypeError, ValueError, IndexError):
            # Treated as bad solver input; run() responds with the reminder message
            return None
        if out is None:
            return None
        tool_call.output = out.output
        return tool_call

    def _add_eval_message(
        self,
        task_state: TaskState,
        solver_output: SolverResult,
        content: str,
    ) -> TaskState:
        """
        Returns a new TaskState extending ``task_state`` with the solver's output (as an
        assistant message) followed by ``content`` (as a user message).

        The original message list is deep-copied, so ``task_state`` is not mutated.
        """
        messages = copy.deepcopy(task_state.messages)
        messages.append(Message(role="assistant", content=solver_output.output))
        messages.append(Message(role="user", content=content))
        return TaskState(
            task_description=task_state.task_description,
            messages=messages,
            current_state=None,
        )

    def _add_tool_outputs(
        self,
        task_state: TaskState,
        solver_output: SolverResult,
        tool_outputs: list[ToolCall],
    ) -> TaskState:
        """
        Appends the solver's output and a user message reporting each tool's output,
        one line per call.
        """
        # NOTE: we assume that the order of tool_outputs is the same as the order of tool_calls
        content = "".join(
            f"{call.tool_name} output on input {call.input}: {call.output}\n"
            for call in tool_outputs
        )
        return self._add_eval_message(task_state, solver_output, content)

    def _finish_run(
        self,
        final_task_state: TaskState,
        solver_result: SolverResult,
        final_answer: Optional[str],
        turn: int,
    ) -> RunnerResult:
        """
        Grades ``final_answer`` against ``sample["answer"]`` and packages the result.

        The comparison is case-insensitive and ignores surrounding whitespace. It also
        ignores quotes wrapping the solver's answer. A missing answer is graded incorrect.

        Args:
            final_task_state: Task state to report as the final conversation state.
            solver_result: The solver's result from the last turn.
            final_answer: Parsed final answer, or None if the solver never answered.
            turn: Zero-indexed turn on which the run ended.
        """
        expected_answer = self.sample["answer"]
        is_correct = False
        if final_answer is not None:
            normalized_answer = self._strip_wrapping_quotes(final_answer.lower().strip())
            is_correct = normalized_answer == expected_answer.lower().strip()
        metrics = {
            "is_correct": is_correct,
            "num_turns": turn + 1,  # `turn` is zero-indexed
        }
        return RunnerResult(final_task_state, solver_result, metrics)