project/nanoeval/nanoeval/solvers/computer_tasks/solver.py
from __future__ import annotations
import asyncio
import itertools
import os
import time
from abc import ABC, abstractmethod
from asyncio import CancelledError
from collections import defaultdict
from contextlib import asynccontextmanager, contextmanager
from typing import Any, AsyncGenerator, Generator, Sequence
import numpy as np
import structlog.stdlib
from pydantic import BaseModel
from typing_extensions import override
import chz
import nanoeval
from nanoeval.asyncio_utils import HasAsyncContextManager, generator_with_cleanup
from nanoeval.eval import RetryableSystemError
from nanoeval.metrics.agents import get_summary_error_aware
from nanoeval.recorder import RecorderProtocol, get_recorder
from nanoeval.solvers.computer_tasks.steps import (
FinalResult,
FinalResultSuccessful,
FinalResultWithException,
Step,
)
from nanoeval.solvers.computer_tasks.task import ComputerTask, Grade
logger = structlog.stdlib.get_logger(component=__name__)
@chz.chz
class PythonCodingSolver(ABC, HasAsyncContextManager):
"""
A solver for tasks that run on a container.
"""
@abstractmethod
def shortname(self) -> str:
pass
@abstractmethod
def run(self, task: ComputerTask) -> AsyncGenerator[Step | FinalResult, None]:
"""
Runs the solver on the given task.
"""
pass
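
# Illustrative sketch (not part of the module): callers typically enter the solver's
# async context and iterate run() until it yields a FinalResult, mirroring what
# PythonCodingEval._evaluate_inner does below. `drive` is a hypothetical helper name.
#
#   async def drive(solver: PythonCodingSolver, task: ComputerTask) -> FinalResult:
#       async with solver:
#           async for step in solver.run(task):
#               if isinstance(step, FinalResult):
#                   return step
#       raise ValueError("Solver did not yield a final result")
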
@chz.chz
class DummyPythonCodingSolver(PythonCodingSolver):
"""
A dummy solver that always returns correct. NOTE that this doesn't exercise
environment setup - for that, something like the ACE gold solver (to be written)
should be used.
"""
@override
def shortname(self) -> str:
return "dummy"
@override
async def run(self, task: ComputerTask) -> AsyncGenerator[Step | FinalResult, None]:
del task
yield FinalResultSuccessful(
grade=Grade(score=1, grader_log="Dummy solver always returns correct")
)
def strip_all_metadata(
convo: BaseModel, allowed_metadata_fields: list[str] | None = None
) -> BaseModel:
"""
Strip all metadata from a conversation. We can't use evallib's remove_all_metadata
function because it recently introduced a change that enforces metadata is
serializable, but evidently some UUIDs end up in metadata eventually.
"""
# Delete all the metadata from the conversation and messages.
if allowed_metadata_fields is None:
allowed_metadata_fields = []
convo = convo.model_copy(
update={
"metadata": {k: v for k, v in convo.metadata.items() if k in allowed_metadata_fields} # type: ignore
}
)
for msg in convo.messages: # type: ignore
msg.metadata = {k: v for k, v in msg.metadata.items() if k in allowed_metadata_fields}
return convo
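
# Example (illustrative, not part of the module): keep only a hypothetical "model"
# metadata key on the conversation and each of its messages.
#
#   stripped = strip_all_metadata(convo, allowed_metadata_fields=["model"])
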
@contextmanager
def simple_timer(name: str) -> Generator[None, None, None]:
logger.info(f"[{name}] started")
start = time.monotonic()
try:
yield
logger.info(f"[{name}] finished", elapsed=time.monotonic() - start)
except Exception:
logger.warning(f"[{name}] failed", elapsed=time.monotonic() - start)
raise
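
# Example (illustrative): wrap a block to log when it starts and how long it took
# to finish or fail. The block name "rollout" and the wrapped call are hypothetical.
#
#   with simple_timer("rollout"):
#       do_rollout()

# Message ids that have already been recorded as samplings, so repeated convo
# snapshots across steps are not logged twice (see _log below).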
logged_messages = set()
def _log(
recorder: RecorderProtocol,
step: Step | FinalResult,
    record_pretty: bool = False,  # Records each message as a pretty-printed sampling; may truncate some messages or information.
) -> None:
    """
    Log a step or final result to the recorder.

    With record_pretty=False, the step is recorded as a single JSON blob with
    metadata stripped (intermediate Steps keep only their last 5 messages). With
    record_pretty=True, each not-yet-seen message is recorded as a pretty-printed
    sampling via record_message.
    """
assert recorder is not None
if not record_pretty:
# Clean all metadata to make the recording shorter.
if step.convo:
# TODO(kevinliu) restore these
step.convo = strip_all_metadata(step.convo)
# For intermediate steps, only log the last 5 messages of the convo to save space.
# The final result will always have the full conversation.
if isinstance(step, Step):
step.convo = step.convo.model_copy()
step.convo.messages = step.convo.messages[-5:]
try:
recorder.record_extra(data=step.model_dump(mode="json"))
except Exception:
logger.warning("Failed to record extra data", exc_info=True)
else:
# Log all messages, but skip ones we've already recorded
        if step.convo is not None:
            for message in step.convo.messages[-15:]:
                if message.id in logged_messages:
                    continue
                logged_messages.add(message.id)
try:
record_message(recorder, message)
except Exception as e:
logger.warning("Failed to record message", error=e)
def record_message(recorder: RecorderProtocol, message: Any) -> None:
# TODO(kevinliu/extract) fix logging here
text = str(message)
author = str(message.author.role)
    status = str(message.status)
    end_turn = message.end_turn
recipient = message.recipient
prompt = f"""
Author: {author}
Status: {status}
End Turn: {end_turn}
Recipient: {recipient}
"""
sampled = text
recorder.record_sampling(
prompt=prompt,
sampled=sampled,
)
def _log_results(task: ComputerTask, result: FinalResult) -> None:
"""
Log the results of a task.
"""
recorder = get_recorder()
score = result.grade.score
log = result.grade.grader_log
recorder.record_sampling(
prompt="",
sampled=log,
sample_id=task.question_id,
group_id=str(task.attempt_id),
)
recorder.record_match(
correct=bool(score),
group_id=str(task.attempt_id),
)
@chz.chz
class PythonCodingEval(nanoeval.Eval[ComputerTask, FinalResult]):
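    """
    Base eval that runs a PythonCodingSolver over ComputerTask instances.

    Subclasses implement get_instances(); get_tasks() then duplicates each instance
    n_tries times. record_pretty and log_at_end control how much detail is written
    to the recorder.
    """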
solver: PythonCodingSolver = DummyPythonCodingSolver()
n_tries: int = 1
record_pretty: bool = False
log_at_end: bool = False
@override
@asynccontextmanager
async def _context(self) -> AsyncGenerator[None, None]:
async with self.solver:
yield
@abstractmethod
async def get_instances(self) -> Sequence[ComputerTask]:
pass
@override
async def get_tasks(self) -> Sequence[ComputerTask]:
questions = await self.get_instances()
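        # Each question is copied once per attempt; with n_tries=2 and questions
        # [A, B], this yields A@attempt 0, B@attempt 0, A@attempt 1, B@attempt 1.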
tasks = []
for attempt_idx, (_q_idx, question) in itertools.product(
range(self.n_tries), enumerate(questions)
):
tasks.append(
question.model_copy(
update=dict(attempt_id=attempt_idx, question_id=question.question_id)
)
)
return tasks
async def _evaluate_inner(self, task: ComputerTask) -> FinalResult:
recorder = get_recorder()
try:
async with generator_with_cleanup(self.solver.run(task)) as gen:
async for step in gen:
assert not isinstance(step, FinalResultWithException), (
"FinalResultWithException has been deprecated"
)
await asyncio.to_thread(
_log,
recorder,
step,
self.record_pretty,
)
if isinstance(step, FinalResultSuccessful):
logger.info(
"Final result: %s (%d messages)",
step.correct,
len(step.convo.messages) if step.convo else 0,
)
# For compatibility with simple evalboard vis, log a match.
await asyncio.to_thread(recorder.record_match, correct=step.correct)
return step
except CancelledError as e:
logger.exception("Cancelled error detected - this is clearly a bug")
raise RetryableSystemError("Cancelled error detected - this is clearly a bug") from e
raise ValueError("Solver did not return a final result! This is a programming error.")
@override
async def evaluate(self, task: ComputerTask) -> FinalResult:
        # Log process info, useful for debugging in a multiprocess setting.
logger.info("PID: %d", os.getpid())
logger.info("To dump stack traces: $ py-spy dump --pid %d", os.getpid())
logger.info("PythonCodingEval.evaluate() started")
res = await self._evaluate_inner(task)
if self.log_at_end:
logger.info("PythonCodingEval.evaluate() logging")
await asyncio.to_thread(_log_results, task, res)
logger.info("PythonCodingEval.evaluate() finished")
return res
def process_invalid(self, task: ComputerTask) -> FinalResult:
return FinalResultSuccessful(grade=Grade(score=0, grader_log="Task was invalid"))
@override
async def update_progress(
self,
partial_results: list[
tuple[ComputerTask, FinalResultSuccessful | FinalResultWithException]
],
pbar: Any,
) -> None:
summary: dict[str, Any] = {
"num_correct": 0,
"num_incorrect": 0,
"num_incorrect_with_error": 0,
"num_incorrect_max_steps_reached": 0,
"error_breakdown": defaultdict(int),
}
for _task, result in partial_results:
if result.correct:
summary["num_correct"] += 1
else:
summary["num_incorrect"] += 1
if isinstance(result, FinalResultWithException):
summary["error_breakdown"][result.exception] += 1
summary["num_incorrect_with_error"] += 1
elif result.max_steps_reached:
summary["num_incorrect_max_steps_reached"] += 1
pbar.set_postfix(
corr=summary["num_correct"],
errs=summary["num_incorrect_with_error"],
fail=summary["num_incorrect"] - summary["num_incorrect_with_error"],
)
def _get_convo_len_stats(
self, results: list[tuple[ComputerTask, FinalResult | RetryableSystemError]]
) -> dict[str, Any]:
"""
Get conversation length statistics.
"""
completions = [result for _, result in results if isinstance(result, FinalResultSuccessful)]
if not completions:
return {}
frac_correct = sum(1 for result in completions if result.correct) / len(completions)
incorrect_completions = [result for result in completions if not result.correct]
frac_max_time = sum(1 for result in incorrect_completions if result.max_time_reached) / len(
completions
)
frac_max_steps = sum(
1 for result in incorrect_completions if result.max_steps_reached
) / len(completions)
frac_max_tokens = sum(
1 for result in incorrect_completions if result.max_tokens_reached
) / len(completions)
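        # Incorrect runs where the model stopped on its own, i.e. none of the
        # time / step / token limits were hit.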
frac_model_ended = sum(
1
for result in incorrect_completions
if not (
result.max_time_reached or result.max_steps_reached or result.max_tokens_reached
)
) / len(completions)
convos = [result.convo for result in completions if result.convo is not None]
summary_dict = {
"frac_correct": frac_correct,
"frac_max_time": frac_max_time,
"frac_max_steps": frac_max_steps,
"frac_max_tokens": frac_max_tokens,
"frac_model_ended": frac_model_ended,
}
if not convos:
return summary_dict
convo_lens = [len(convo.messages) for convo in convos]
# compute percentiles
convo_lens = np.array(convo_lens)
percentiles = np.percentile(convo_lens, [25, 50, 75])
summary_dict["convo_len_percentiles"] = percentiles.tolist() # type: ignore
        return summary_dict
@override
async def get_full_summary(
self, results: list[tuple[ComputerTask, FinalResult | RetryableSystemError]]
) -> dict[str, Any]:
"""
How are results classified?
- FinalResultSuccessful -> goes in correct/incorrect
- FinalResultWithException -> shouldn't exist anymore
- RetryableSystemError -> marked as has_error, ignored in default metrics, but counted in metrics_including_errors
"""
for _, result in results:
assert not isinstance(result, FinalResultWithException), (
"FinalResultWithException has been deprecated in favor of nanoeval system retries"
)
summary = await asyncio.to_thread(
get_summary_error_aware,
[
(
task,
(
result.correct
if isinstance(result, (FinalResultSuccessful, FinalResultWithException))
else result
),
)
for task, result in results
],
)
summary["length_stats"] = self._get_convo_len_stats(results)
return summary