in evals/elsuite/cant_do_that_anymore/eval.py [0:0]
def eval_sample(self, solver: Solver, sample: Any, rng: random.Random):
    """Evaluate ``solver`` on one sample under both the standard and the variant rules.

    The sample's move history is replayed onto two board controllers, one built
    from ``PIECE_ID_TO_INSTANCE`` and one from ``VARIANT_PIECE_ID_TO_INSTANCE``.
    The solver is queried once per rule set with the matching task description.
    Per-sample metrics are recorded with ``evals.record.record_metrics``:
    whether each stripped prediction is in ``next_filtered_moves``, plus whatever
    ``self.get_violations`` reports for each rule set.

    Args:
        solver: Solver queried for the next move.
        sample: Mapping with ``"previous_moves"`` (the moves replayed onto both
            boards and sent to the solver as messages) and
            ``"next_filtered_moves"`` (the collection each prediction is checked
            against with ``in``).
        rng: Not used by this method; kept for the eval interface.
    """
    previous_moves = sample["previous_moves"]
    next_filtered_moves = sample["next_filtered_moves"]

    def construct_controller(piece_id_to_instance: Dict[int, Piece]) -> BoardController:
        """Build a board controller for the given piece set and replay ``previous_moves`` on it."""
        controller = BoardController(
            default_board_init,
            piece_id_to_instance,
            PIECE_STR_TO_ID,
            PIECE_ID_TO_STR,
            AlgebraicNotationParser(PIECE_STR_TO_ID, PIECE_ID_TO_STR),
        )
        for move in previous_moves:
            controller.update_board(move)
        return controller

    def get_solver_pred(task_description: str) -> SolverResult:
        """Ask the solver for the next move.

        The solver sees only the task description and the previous moves (as
        messages). No list of legal moves is passed, so the solver has to work
        them out itself.
        """
        task_state = TaskState(
            task_description,
            messages=construct_messages(previous_moves),
        )
        # The reply is capped at 4 tokens.
        return solver(task_state, max_tokens=4)

    # Build both boards before querying the solver, so an invalid move history
    # fails before any solver call is made.
    default_controller = construct_controller(PIECE_ID_TO_INSTANCE)
    variant_controller = construct_controller(VARIANT_PIECE_ID_TO_INSTANCE)

    solver_result = get_solver_pred(TASK_DESCRIPTION)
    solver_result_variant = get_solver_pred(TASK_DESCRIPTION_VARIANT)

    # Surrounding whitespace is ignored when checking the prediction against the
    # accepted moves.
    predicted = solver_result.output.strip()
    predicted_variant = solver_result_variant.output.strip()

    metrics = {
        "move": next_filtered_moves,
        "predicted_move": predicted in next_filtered_moves,
        "predicted_move_in_variant": predicted_variant in next_filtered_moves,
        "num_previous_moves": len(previous_moves),
        "previous_moves": previous_moves,
    }

    # Add rule-violation metrics for each rule set.
    # NOTE(review): get_violations receives the raw (unstripped) solver output,
    # unlike the membership checks above; confirm it tolerates surrounding whitespace.
    metrics.update(
        self.get_violations(
            default_controller, solver_result.output, previous_moves, "standard"
        )
    )
    metrics.update(
        self.get_violations(
            variant_controller, solver_result_variant.output, previous_moves, "variant"
        )
    )
    evals.record.record_metrics(**metrics)