def eval_sample()

in evals/elsuite/cant_do_that_anymore/eval.py [0:0]


    def eval_sample(self, solver: Solver, sample: Any, rng: random.Random):
        """Evaluate one sample under both the standard and the variant piece rules.

        The same game prefix (``sample["previous_moves"]``) is replayed on two
        ``BoardController`` instances: one built from ``PIECE_ID_TO_INSTANCE`` and
        one from ``VARIANT_PIECE_ID_TO_INSTANCE``. The solver is then queried
        once with ``TASK_DESCRIPTION`` and once with ``TASK_DESCRIPTION_VARIANT``.
        Whether each (stripped) prediction appears in
        ``sample["next_filtered_moves"]`` is recorded. The rule-violation metrics
        returned by ``self.get_violations`` for each controller are recorded too.
        Everything goes out through ``evals.record.record_metrics``.

        Args:
            solver: Solver queried for the next move; called with ``max_tokens=4``.
            sample: Mapping providing at least ``"previous_moves"`` (replayed via
                ``BoardController.update_board``) and ``"next_filtered_moves"``
                (the container predictions are checked against).
            rng: Not used by this method.
        """
        previous_moves = sample["previous_moves"]
        next_filtered_moves = sample["next_filtered_moves"]

        def construct_controller(piece_id_to_instance: Dict[int, Piece]) -> BoardController:
            """Build a controller for the given piece set and replay the game prefix on it."""
            controller = BoardController(
                default_board_init,
                piece_id_to_instance,
                PIECE_STR_TO_ID,
                PIECE_ID_TO_STR,
                AlgebraicNotationParser(PIECE_STR_TO_ID, PIECE_ID_TO_STR),
            )
            for move in previous_moves:
                controller.update_board(move)
            return controller

        default_controller = construct_controller(PIECE_ID_TO_INSTANCE)
        variant_controller = construct_controller(VARIANT_PIECE_ID_TO_INSTANCE)

        def get_solver_pred(task_description: str) -> SolverResult:
            """Ask the solver for the next move given only the description and move history.

            No board state or legal-move list is handed to the solver; it must work
            the legal moves out itself from the messages built from ``previous_moves``.
            """
            task_state = TaskState(
                task_description,
                messages=construct_messages(previous_moves),
            )
            # Tiny token budget: the answer is presumably a single move in
            # algebraic notation (it is compared against next_filtered_moves).
            return solver(task_state, max_tokens=4)

        solver_result = get_solver_pred(TASK_DESCRIPTION)
        solver_result_variant = get_solver_pred(TASK_DESCRIPTION_VARIANT)

        output = solver_result.output
        output_variant = solver_result_variant.output

        metrics = {
            "move": next_filtered_moves,
            "predicted_move": output.strip() in next_filtered_moves,
            "predicted_move_in_variant": output_variant.strip() in next_filtered_moves,
            "num_previous_moves": len(previous_moves),
            "previous_moves": previous_moves,
        }

        # Add rule-violation metrics for each rule set.
        # NOTE(review): get_violations receives the raw (unstripped) output, unlike
        # the membership checks above; confirm get_violations normalizes it itself.
        metrics.update(
            self.get_violations(default_controller, output, previous_moves, "standard")
        )
        metrics.update(
            self.get_violations(variant_controller, output_variant, previous_moves, "variant")
        )

        evals.record.record_metrics(**metrics)