pyrit/score/question_answer_scorer.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import json
from typing import Optional, Sequence

from pyrit.models import PromptRequestResponse, QuestionAnsweringEntry, Score
from pyrit.models.prompt_request_piece import PromptRequestPiece
from pyrit.prompt_target.batch_helper import batch_task_async
from pyrit.score.scorer import Scorer


class QuestionAnswerScorer(Scorer):
"""
A class that represents a question answering scorer.
"""
def __init__(
self,
*,
category: str = "",
) -> None:
"""
Initializes the QuestionAnswerScorer object.
Args:
category (str): an optional parameter to the category metadata
"""
self._score_category = category
self.scorer_type = "true_false"
async def score_async( # type: ignore[override]
self, *, request_response: PromptRequestPiece, task: QuestionAnsweringEntry
) -> list[Score]:
"""
Score the request_reponse using the QuestionAnsweringEntry
and return a single score object
Args:
request_response (PromptRequestPiece): The answer given by the target
task (QuestionAnsweringEntry): The entry containing the original prompt and the correct answer
Returns:
Score: A single Score object representing the result
"""
answer = request_response.converted_value
try:
# This is the case where the model response is an integer, which is the index of the answer.
answer = task.choices[int(answer)].text
except ValueError:
# If the model response is not an integer, then the model might have returned the answer as a string
pass
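        # task.correct_answer holds the index of the correct choice; resolve it to the choice text.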
correct_answer = task.choices[int(task.correct_answer)].text
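        # Store the question, expected answer, and received answer in the score metadata so that
        # report_scores can display them later.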
metadata_json = {"question": str(task.question), "correct_answer": correct_answer, "scored_answer": answer}
metadata = json.dumps(metadata_json)
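        # The answer is marked correct when the text of the correct choice appears in the resolved answer.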
score = [
Score(
score_value=str(correct_answer in answer),
score_type=self.scorer_type,
score_value_description="",
score_metadata=metadata,
score_category=self._score_category,
score_rationale="",
scorer_class_identifier=self.get_identifier(),
prompt_request_response_id=request_response.id,
task=task.question,
)
]
request_response.scores = score
        return score

    async def score_prompts_with_tasks_batch_async(  # type: ignore[override]
        self,
        *,
        request_responses: Sequence[PromptRequestPiece],
        tasks: Sequence[QuestionAnsweringEntry],
        batch_size: int = 10,
    ) -> list[Score]:
        """
        Scores a batch of answers against their corresponding QuestionAnsweringEntry tasks.

        Args:
            request_responses (Sequence[PromptRequestPiece]): The answers given by the target.
            tasks (Sequence[QuestionAnsweringEntry]): The corresponding entries with the correct answers.
            batch_size (int): The number of scoring calls grouped into each batch.

        Returns:
            list[Score]: One Score per request/task pair.
        """
if not tasks:
raise ValueError("Tasks must be provided.")
if len(request_responses) != len(tasks):
raise ValueError(
f"Number of tasks ({len(tasks)}) must match number of provided answers ({len(request_responses)})."
)
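        # This scorer does not own a prompt target; forward one to the batch helper only if a
        # subclass happens to define _prompt_target.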
prompt_target = getattr(self, "_prompt_target", None)
results = await batch_task_async(
task_func=self.score_async,
task_arguments=["request_response", "task"],
prompt_target=prompt_target,
batch_size=batch_size,
items_to_batch=[request_responses, tasks],
)
        return results

    def validate(self, request_response: PromptRequestPiece, *, task: Optional[str] = None):
        """
        Validates the request_response piece to score, since some scorers may require
        specific PromptRequestPiece types or values.

        Args:
            request_response (PromptRequestPiece): The request response to be validated.
            task (str): The task based on which the text should be scored (the original attacker model's objective).
        """
        if request_response.converted_value_data_type != "text":
            raise ValueError("Question Answer Scorer only supports text data type")

def report_scores(self, responses: list[PromptRequestResponse]) -> None:
"""
Reports the score values from the list of prompt request responses
Checks for presence of scores in reponse before scoring
Args:
responses (list[PromptRequestResponse]): The list of responses to be reported on
"""
correct_count = 0
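        # Every response must already carry a true/false score before results can be summarized.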
if any(not response.request_pieces[0].scores for response in responses):
raise ValueError("Not all responses have scores, please score all responses before reporting")
if any(response.request_pieces[0].scores[0].score_type != "true_false" for response in responses):
raise ValueError("Score types are not 'true_false'")
        for response in responses:
            score = response.request_pieces[0].scores[0]
            score_metadata = json.loads(score.score_metadata)
            correct_answer = score_metadata["correct_answer"]
            received_answer = score_metadata["scored_answer"]
            print(f"Was answer correct: {score.score_value}")
            print(f"Correct Answer: {correct_answer}")
            print(f"Answer Received: {received_answer}")
            correct_count += int(score.score_value == "True")
        print(f"Correct / Total: {correct_count} / {len(responses)}")