def _eval_retrieval_sample()

in evals/elsuite/skill_acquisition/eval.py [0:0]


    def _eval_retrieval_sample(self, solver: Solver, sample: Dict, *_) -> Dict[str, Any]:
        """Evaluates the given sample using retrieval.

        The retrieval logic itself is implemented in ``self._conversation_loop``;
        this method builds the initial task state (exposing the knowledge-base
        files to the solver), runs the loop, parses the final output and
        assembles the per-sample metrics record.

        Args:
            solver (Solver): any compatible solver, instantiated just for this sample.
            sample (Dict): input to evaluate on. Read keys: ``"input"`` (a list of
                message dicts; the last one's ``"content"`` is used to classify
                the question type) and ``"ideal"`` (the expected answer).

        Returns:
            Dict[str, Any]: metrics collected during evaluation.

        Raises:
            FileNotFoundError: if any file listed in ``self.files_available`` does
                not exist under ``self.knowledge_base_directory``.
        """
        files_available_paths = [
            self.knowledge_base_directory / file for file in self.files_available
        ]
        # Explicit check instead of `assert`, which is stripped under `python -O`,
        # and report exactly which files are missing.
        missing_files = [str(path) for path in files_available_paths if not path.exists()]
        if missing_files:
            raise FileNotFoundError(f"Knowledge base files not found: {missing_files}")

        task_state = TaskState(
            task_description=self.task_description,
            messages=[Message(**msg) for msg in sample["input"]],
            current_state={"files": files_available_paths},
        )

        output, metrics = self._conversation_loop(solver, task_state)

        # An output equal to this exact string is treated as a context-length
        # overflow (both for logging below and for the `ctx_len_exceeded` metric).
        ctx_len_exceeded = output == "Context length exceeded."

        if answer_detected(output):
            answer = process_answer(output)
            logger.debug("Model answered %s", answer)
        elif ctx_len_exceeded:
            answer = "NO ANSWER DETECTED"
            logger.warning("Current interaction exceeded model context length.")
        else:
            answer = "NO ANSWER DETECTED"
            logger.debug("Model timed out after %s replies.", metrics["current_replies"])

        ideal = sample["ideal"]
        picked = evals.record_and_check_match(
            prompt=sample["input"],
            sampled=answer,
            expected=[ideal],
        )

        out_obj = {
            "prompt": sample["input"],
            "raw_output": output,
            "parsed_output": answer,
            "expected": [ideal],
            "correct": picked is not None,
            "bleu": get_bleu_score(ideal, answer),
            "ctx_len_exceeded": ctx_len_exceeded,
            "interaction_timed_out": metrics["current_replies"] >= self.max_replies,
            "question_type": get_question_type(sample["input"][-1]["content"]),
            "lesson_retrieval_calls": metrics["lesson_retrieval_calls"],
            "correct_retrieval_calls": metrics["correct_retrieval_calls"],
            # Every retrieval call that was not counted as correct is counted as invalid.
            "invalid_retrieval_calls": metrics["total_retrieval_calls"]
            - metrics["correct_retrieval_calls"],
            "total_retrieval_calls": metrics["total_retrieval_calls"],
        }
        return out_obj