# drop_eval.py (excerpt)
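# Assumed context, not shown in this excerpt: `common`, `ANSWER_PATTERN`,
# `HTML_JINJA`, `drop_metric`, `fuzzy_match`, `SamplerBase`, `EvalResult`,
# and `SingleEvalResult` are imported or defined elsewhere in the module.
import random
import re
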
def __call__(self, sampler: SamplerBase) -> EvalResult:
    rng = random.Random(self.seed)
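
    # Grade one test example end-to-end: assemble a few-shot prompt, query
    # the sampler, extract the final answer, and score it.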
    def fn(example: dict[str, str]):
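        # Few-shot demonstrations are drawn with the seeded RNG so runs are
        # reproducible.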
        stuffing = rng.sample(self.train_samples, self._train_samples_per_prompt)
# prompt = """TASK: Read the provided passage, then identify the correct answer to questions below."""
prompt = """You will be asked to read a passage and answer a question. Some examples of passages and Q&A are provided below."""
prompt += "\n\n# Examples"
samples = stuffing + [example]
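        # Demonstrations come first; the test example is appended last.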
        for i, sample in enumerate(samples):
            is_test = i == len(stuffing)
            prompt += "\n# Your Task\n" if is_test else ""
prompt += f"""
---
{sample["context"]} """
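            # ref_text packs every acceptable answer into one "|"-separated string.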
            a = sample["completion"]
            correct_answers = sample["ref_text"].split("|")
            if not is_test:
                prompt += a + "\n"
            else:
                prompt += """\n
Think step by step, then write a line of the form "Answer: $ANSWER" at the end of your response.
"""
        prompt_messages = [sampler._pack_message(content=prompt, role="user")]
        response_text = sampler(prompt_messages)
        match = re.search(ANSWER_PATTERN, response_text)
        extracted_answer = match.group(1) if match else response_text
        em_score, f1_score = drop_metric(extracted_answer, correct_answers)
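        # em_score/f1_score are the standard DROP metrics; the pass/fail score
        # below is a fuzzy match against any of the reference answers.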
        matches = [
            fuzzy_match(extracted_answer, correct_answer)
            for correct_answer in correct_answers
        ]
        extracted_answers = [
            extracted_answer for i in range(len(correct_answers)) if matches[i]
        ]
        score = any(matches)
        html = common.jinja_env.from_string(HTML_JINJA).render(
            prompt_messages=prompt_messages,
            next_message=dict(content=extracted_answer, role="assistant"),
            score=score,
            correct_answer=correct_answers,
            extracted_answer=extracted_answers,
        )
        convo = prompt_messages + [dict(content=extracted_answer, role="assistant")]
        return SingleEvalResult(
            html=html,
            score=score,
            convo=convo,
            metrics={"em_score": em_score, "f1_score": f1_score},
        )
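
    # Score every test example and fold the per-example results into a
    # single EvalResult.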
    results = common.map_with_progress(fn, self.test_samples)
    return common.aggregate_results(results)
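
# Minimal usage sketch (illustrative, not from the source): assumes this
# method belongs to a `DropEval` class and that `sampler` is any
# SamplerBase implementation.
#
#     drop_eval = DropEval(...)      # constructor args omitted
#     result = drop_eval(sampler)
#     print(result.score, result.metrics)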