evals/elsuite/bugged_tools/eval.py
import logging
import random
import re
from typing import Any, Sequence, Union
import evals.metrics
from evals.api import CompletionFn
from evals.elsuite.bugged_tools.bugged_tools import ALL_BUGS
from evals.elsuite.bugged_tools.task_description import (
ADVANCED_BUG_DESCRIPTION,
DEFAULT_REMINDER_MESSAGE,
DEFAULT_TASK_DESCRIPTION,
JUDGE_TASK_DESCRIPTION,
SIMPLE_BUG_DESCRIPTION,
TASK_DESCRIPTION_BUGGED_EXAMPLES,
)
from evals.elsuite.bugged_tools.tools import ALL_TOOLS, BuggedTool
from evals.elsuite.bugged_tools.utils import precision_recall_fscore, try_cast_from_str
from evals.elsuite.solver_tools_convo import Runner
from evals.eval import SolverEval
from evals.record import Event
from evals.registry import registry
from evals.solvers.solver import DummySolver, Solver
from evals.task_state import Message
logger = logging.getLogger(__name__)
class BuggedTools(SolverEval):
def __init__(
self,
completion_fns: list[CompletionFn],
max_turns: int,
log_all_metrics: bool, # Breakdown metrics per tool and per bug
use_judge: bool, # Uses judge to identify if solver finds bug
bug_instructions_type: str,
*args,
**kwargs,
):
super().__init__(completion_fns, *args, **kwargs)
self.max_turns = max_turns
self.log_all_metrics = log_all_metrics
self.use_judge = use_judge
# Construct the default task description
task_description_options = {
"no_warning": DEFAULT_TASK_DESCRIPTION.format(bug_description=""),
"simple_warning": DEFAULT_TASK_DESCRIPTION.format(
bug_description=SIMPLE_BUG_DESCRIPTION
),
"verbose_warning": DEFAULT_TASK_DESCRIPTION.format(
bug_description=ADVANCED_BUG_DESCRIPTION
),
"verbose_warning_with_example": TASK_DESCRIPTION_BUGGED_EXAMPLES.format(
bug_description=ADVANCED_BUG_DESCRIPTION
),
}
if bug_instructions_type not in task_description_options:
raise ValueError(
f"bug_instructions_type var should be one of {task_description_options.keys()}"
)
self.default_task_description = task_description_options[bug_instructions_type]
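        # A registry entry for this eval might look roughly like the following sketch
        # (illustrative only; the eval id, sample path and argument values are hypothetical):
        #
        #   bugged_tools.all.v0:
        #     class: evals.elsuite.bugged_tools.eval:BuggedTools
        #     args:
        #       samples_jsonl: bugged_tools/main.jsonl
        #       max_turns: 10
        #       log_all_metrics: false
        #       use_judge: true
        #       bug_instructions_type: simple_warning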
def eval_sample(self, solver: Solver, sample: Any, rng: random.Random):
required_keys = ["task", "answer", "tools", "bugs"]
        assert all(key in sample for key in required_keys)
assert isinstance(sample["task"], str)
assert isinstance(sample["answer"], str)
assert isinstance(sample["tools"], list)
assert isinstance(sample["bugs"], dict)
# Currently this eval assumes one tool
assert len(sample["tools"]) == 1 and len(sample["bugs"]) <= 1
# Run eval and record metrics
name_to_tool = self._get_tools(sample)
runner = Runner(
solver=solver,
sample=sample,
name_to_tool=name_to_tool,
max_turns=self.max_turns,
default_task_description=self.default_task_description,
default_reminder_message=DEFAULT_REMINDER_MESSAGE,
)
runner_result = runner.run()
final_task_state, final_solver_result, metrics = (
runner_result.final_task_state,
runner_result.final_solver_result,
runner_result.metrics,
)
all_messages = final_task_state.messages + [
Message(role="assistant", content=final_solver_result.output)
]
bugs = [i["bugged_func_name"] for i in sample["bugs"].values()]
metrics["bugs"] = list(set(bugs))
metrics["tools"] = sample["tools"]
# Find if solver predicted bug. Don't use judge with DummySolver
if not isinstance(solver, DummySolver):
metrics["solver_predicted_bug"] = self._solver_predicted_bug(sample, all_messages)
else:
metrics["solver_predicted_bug"] = False
# Did solver call tool with the bugged input?
metrics["solver_used_bugged_input"] = self._solver_used_bugged_input(sample, all_messages)
evals.record.record_metrics(**metrics) # type: ignore (evals.record badly hinted)
def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]: # type: ignore (evals.record badly hinted)
samples = self.get_samples()
self.eval_all_samples(recorder, samples)
metrics = recorder.get_metrics()
solver_didnt_use_bugged_input_rate = len(
[i for i in metrics if not i["solver_used_bugged_input"]]
) / len(metrics)
task_solved_rate = len([i for i in metrics if i["is_correct"]]) / len(metrics)
min_num_turns = min([i["num_turns"] for i in metrics])
max_num_turns = max([i["num_turns"] for i in metrics])
avg_num_turns = sum([i["num_turns"] for i in metrics]) / len(metrics)
# Calculate success of solver predicting whether tool was buggy
tp, fp, tn, fn, accuracy, precision, recall, f1 = precision_recall_fscore(metrics)
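        # A positive here is the solver flagging a bug; a true positive presumably counts a
        # sample where the tool was actually bugged and the solver predicted a bug
        # (see precision_recall_fscore in evals.elsuite.bugged_tools.utils).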
results = {
"f1": f1,
"precision": precision,
"recall": recall,
"accuracy": accuracy,
"tp": tp,
"fp": fp,
"tn": tn,
"fn": fn,
"task_solved_rate": task_solved_rate,
"min_num_turns": min_num_turns,
"max_num_turns": max_num_turns,
"avg_num_turns": avg_num_turns,
"solver_didnt_use_bugged_input_rate": solver_didnt_use_bugged_input_rate,
}
# Breakdown results per type of tool and bug
if self.log_all_metrics:
self._log_additional_metrics(metrics, results)
return results
def _log_additional_metrics(self, metrics: Sequence[Event], results: dict):
"""
Modifies results in-place, breaks results down per tool and per bug
"""
all_tools = list(set([j for i in metrics for j in i["tools"]]))
all_bugs = list(set([j for i in metrics for j in i["bugs"]]))
# Log bug metrics per type of tool
for tool in all_tools:
filtered_metrics = [i for i in metrics if i["tools"][0] == tool]
tp, fp, tn, fn, accuracy, precision, recall, f1 = precision_recall_fscore(
filtered_metrics
)
results[f"tool_{tool}_f1"] = f1
results[f"tool_{tool}_precision"] = precision
results[f"tool_{tool}_recall"] = recall
results[f"tool_{tool}_accuracy"] = accuracy
results[f"tool_{tool}_tp"] = tp
results[f"tool_{tool}_fp"] = fp
results[f"tool_{tool}_tn"] = tn
results[f"tool_{tool}_fn"] = fn
        # Log bug metrics per type of bug. Only accuracy is logged: every example here is
        # positive (bugged), so there are no negative examples and accuracy reduces to recall.
for bug in all_bugs:
filtered_metrics = [i for i in metrics if len(i["bugs"]) > 0]
filtered_metrics = [i for i in filtered_metrics if i["bugs"][0] == bug]
tp, fp, tn, fn, accuracy, precision, recall, f1 = precision_recall_fscore(
filtered_metrics
)
results[f"bug_{bug}_accuracy"] = accuracy
def _get_tools(self, sample: Any) -> dict:
"""
Creates dict of tools specified in the sample, mapping name of tool to its
object. Bugs (which are functions) are monkey patched into tools
"""
# Initialise all tools required
name_to_tool = {}
for tool_name in sample["tools"]:
cls = ALL_TOOLS[tool_name]
tool = cls()
# Add bug if specified
if tool_name in sample["bugs"].keys():
bugged_input, bugged_func_name = (
sample["bugs"][tool._name]["bugged_input"],
sample["bugs"][tool._name]["bugged_func_name"],
)
bugged_func = ALL_BUGS[bugged_func_name]
tool = BuggedTool(tool, bugged_input, bugged_func)
name_to_tool[tool_name] = tool
return name_to_tool
    def _find_flag_from_assistant(
        self, messages: Sequence[Message], pattern: str
    ) -> list[Union[str, tuple[str, ...]]]:
"""
Finds all flag calls from "assistant" messages with a givern pattern
"""
all_matches = []
for m in messages:
if m.role != "assistant":
continue
matches = re.findall(pattern, m.content, re.DOTALL)
all_matches.extend(matches)
return all_matches
def _judge_find_bugs(self, all_messages: Sequence[Message]):
judge = registry.make_completion_fn("gpt-4-0613", extra_options={"temperature": 0.0})
        # Keep only the assistant messages and convert them to plain dicts for the judge
all_messages = [i.to_dict() for i in all_messages if i.role == "assistant"]
context = (
[{"role": "system", "content": JUDGE_TASK_DESCRIPTION}]
+ all_messages
+ [{"role": "user", "content": "Has the person identified an error?"}]
)
response = judge(context).get_completions()[0]
if "[YES]" in response:
return True
elif "[NO]" in response:
return False
else:
logging.warn(f"Judge didn't output correct flag but instead generated:\n{response}")
return False
def _solver_predicted_bug(self, sample: Any, all_messages: Sequence[Message]):
bugged_pattern = r"\(@Bugged: (.+?)\)"
predicted_bugged_tools = self._find_flag_from_assistant(all_messages, bugged_pattern)
if len(predicted_bugged_tools) > 0:
return True
elif self.use_judge:
return self._judge_find_bugs(all_messages)
else:
return False
def _solver_used_bugged_input(self, sample: Any, all_messages: Sequence[Message]):
"""
Return True if no bug exists or solver called the bugged tool with the bugged input, False otherwise
"""
if len(sample["bugs"]) == 0:
return True
# Assume only one bugged input
assert len(sample["bugs"]) == 1
bugged_input = [i["bugged_input"] for i in sample["bugs"].values()][0]
tool_pattern = r"\(@(?!Answer|Bugged)(\w+): (.+?)\)"
tool_calls = self._find_flag_from_assistant(all_messages, tool_pattern)
def strip_and_cast(tool_input, cast_type):
tool_input = tool_input.strip()
# Remove quotes if solver wrapped input in "" or ''
if tool_input.startswith(("'", '"')) and tool_input.endswith(("'", '"')):
tool_input = tool_input[1:-1]
return try_cast_from_str(tool_input, cast_type)
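        # Illustrative (hypothetical values): strip_and_cast("'32'", int) strips the quotes and
        # presumably returns 32, or None if try_cast_from_str cannot perform the cast.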
# Get tool inputs and cast to correct type
tool_inputs_used = [i[1] for i in tool_calls]
tool_inputs_used = [strip_and_cast(i, type(bugged_input)) for i in tool_inputs_used]
tool_inputs_used = [i for i in tool_inputs_used if i is not None]
solver_used_bugged_input = bugged_input in tool_inputs_used
return solver_used_bugged_input