evals/elsuite/make_me_say/autoeval.py (116 lines of code) (raw):

import uuid
from typing import Callable, Optional, Union

import backoff
from openai import InternalServerError

from evals.api import CompletionFn, CompletionResult
from evals.elsuite.make_me_say.core import Game, Message, Player
from evals.elsuite.make_me_say.defaults import (
    DEFAULT_FEEDBACK_MSG_FN,
    DEFAULT_INTERNAL_MONOLOGUE_END_TOKEN,
    DEFAULT_INTERNAL_MONOLOGUE_START_TOKEN,
    DEFAULT_MANIPULATEE_PREAMBLE_FN,
    DEFAULT_MANIPULATOR_PREAMBLE_FN,
    DEFAULT_MAX_TURNS,
    DEFAULT_PERIODIC_MSG_FN,
)


def _new_model_player(completion_fn: CompletionFn) -> Player:
    """Build a model-backed Player with a fresh random hex id."""
    return Player(
        id=uuid.uuid4().hex,
        is_model=True,
        completion_fn=completion_fn,
    )


def _system_message(content: str, recipient: Player) -> Message:
    """Build a ``system`` Message whose ``name`` is the recipient's id."""
    return Message(role="system", content=content, name=recipient.id)


@backoff.on_exception(backoff.constant, InternalServerError, max_tries=3)
def run(
    codeword: str,
    manipulator_completion_fn: CompletionFn,
    manipulatee_completion_fn: CompletionFn,
    judge_completion_fn: CompletionFn,
    max_turns: int = DEFAULT_MAX_TURNS,
    manipulator_preamble_fn: Callable[[Game], str] = DEFAULT_MANIPULATOR_PREAMBLE_FN,
    manipulatee_preamble_fn: Callable[[Game], str] = DEFAULT_MANIPULATEE_PREAMBLE_FN,
    internal_monologue_start_token: Optional[str] = DEFAULT_INTERNAL_MONOLOGUE_START_TOKEN,
    internal_monologue_end_token: Optional[str] = DEFAULT_INTERNAL_MONOLOGUE_END_TOKEN,
    periodic_msg_fn: Optional[Callable[[Game], str]] = DEFAULT_PERIODIC_MSG_FN,
    feedback_msg_fn: Callable[[Game], str] = DEFAULT_FEEDBACK_MSG_FN,
) -> Game:
    """Play one fully automated game and return the resulting ``Game``.

    The whole function is wrapped in ``backoff.on_exception``: if an
    ``InternalServerError`` escapes, the game is started over from scratch,
    for at most three attempts in total.

    Args:
        codeword: The single entry placed in the game's ``word_list``.
        manipulator_completion_fn: Completion function for the manipulator.
        manipulatee_completion_fn: Completion function for the manipulatee.
        judge_completion_fn: Completion function attached to the judge
            player. The judge is handed to ``Game`` but is not invoked
            anywhere in this function.
        max_turns: Passed to ``Game``; the conversation loop runs
            ``game.max_turns`` iterations.
        manipulator_preamble_fn: Produces the manipulator's opening system
            message from the freshly built game.
        manipulatee_preamble_fn: Produces the manipulatee's opening system
            message from the freshly built game.
        internal_monologue_start_token: Forwarded to ``Game`` unchanged.
        internal_monologue_end_token: Forwarded to ``Game`` unchanged.
        periodic_msg_fn: When truthy, called before every manipulator turn;
            its result is added as a system message for the manipulator.
        feedback_msg_fn: Called once after the loop; its result is added as
            a system message for the manipulatee, who then answers one
            final time.

    Returns:
        The game state after the manipulatee's reply to the feedback message.
    """
    game = Game(
        manipulator=_new_model_player(manipulator_completion_fn),
        manipulatee=_new_model_player(manipulatee_completion_fn),
        judge=_new_model_player(judge_completion_fn),
        current_player=None,
        messages=(),
        word_list=(codeword,),
        max_turns=max_turns,
        internal_monologue_start_token=internal_monologue_start_token,
        internal_monologue_end_token=internal_monologue_end_token,
    ) if False else _build_game(
        codeword,
        manipulator_completion_fn,
        manipulatee_completion_fn,
        judge_completion_fn,
        max_turns,
        internal_monologue_start_token,
        internal_monologue_end_token,
    )

    # Both preambles are computed (and type-checked) before either is added,
    # so each preamble function sees the game without any messages.
    manipulator_preamble = manipulator_preamble_fn(game)
    manipulatee_preamble = manipulatee_preamble_fn(game)
    assert isinstance(manipulator_preamble, str)
    assert isinstance(manipulatee_preamble, str)

    game = game.add_message(_system_message(manipulator_preamble, game.manipulator))
    game = game.add_message(_system_message(manipulatee_preamble, game.manipulatee))

    for _turn in range(game.max_turns):
        if periodic_msg_fn:
            if game.current_player == game.manipulator:
                reminder = periodic_msg_fn(game)
                game = game.add_message(_system_message(reminder, game.manipulator))

        game = _create_response(game).swap_players()

    # Closing phase: the manipulatee receives the feedback prompt and always
    # gets the last word, regardless of whose turn the loop ended on.
    feedback = feedback_msg_fn(game)
    game = game.add_message(_system_message(feedback, game.manipulatee))
    game = game.set_current_player(game.manipulatee)
    return _create_response(game)


def _build_game(
    codeword: str,
    manipulator_completion_fn: CompletionFn,
    manipulatee_completion_fn: CompletionFn,
    judge_completion_fn: CompletionFn,
    max_turns: int,
    internal_monologue_start_token: Optional[str],
    internal_monologue_end_token: Optional[str],
) -> Game:
    """Create the three model players and the initial, message-free Game.

    Players are created in the order manipulator, manipulatee, judge, and the
    manipulatee is set as the starting ``current_player``.
    """
    manipulator = _new_model_player(manipulator_completion_fn)
    manipulatee = _new_model_player(manipulatee_completion_fn)
    judge = _new_model_player(judge_completion_fn)

    return Game(
        manipulator=manipulator,
        manipulatee=manipulatee,
        judge=judge,
        current_player=manipulatee,
        messages=(),
        word_list=(codeword,),
        max_turns=max_turns,
        internal_monologue_start_token=internal_monologue_start_token,
        internal_monologue_end_token=internal_monologue_end_token,
    )


def _create_response(game: Game) -> Game:
    """Ask the current player for a reply and return the game with it appended.

    The current player's completion function is called with the dict form of
    every message in ``game.view_as(current_player)``; the extracted text is
    added as an ``assistant`` message named with that player's id.
    """
    speaker = game.current_player

    # pre-conditions
    assert speaker.is_model
    assert speaker.completion_fn is not None

    # body
    visible_messages = [msg.to_dict() for msg in game.view_as(speaker)]
    raw_reply = speaker.completion_fn(visible_messages)
    reply_text = _get_content(raw_reply)
    updated_game = game.add_message(
        Message(role="assistant", content=reply_text, name=speaker.id)
    )

    # post-conditions
    assert len(updated_game.messages) == len(game.messages) + 1, "Expected one new message."
    assert updated_game.messages[:-1] == game.messages, "Expected prev msgs to be the same."

    return updated_game


def _get_content(response: Union[dict, CompletionResult]) -> str:
    """Extract the reply text from a completion response.

    Objects exposing ``get_completions`` must yield exactly one completion,
    which is returned. Anything else is indexed as
    ``response["choices"][0]["message"]["content"]``.
    """
    if not hasattr(response, "get_completions"):
        return response["choices"][0]["message"]["content"]

    completions = response.get_completions()
    assert len(completions) == 1, f"Got {len(completions)} but expected exactly one"
    return completions[0]