doc/code/orchestrators/qa_benchmark_orchestrator.py (80 lines of code) (raw):
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
#   kernelspec:
#     display_name: pyrit-dev
#     language: python
#     name: pyrit-dev
# ---
# %% [markdown]
# # Q&A Benchmark Orchestrator - optional
#
# The `QuestionAnsweringBenchmarkOrchestrator` sends the questions from a Q&A dataset to a target and evaluates how well the target answers them.
# %%
# Import necessary packages
from pyrit.common import IN_MEMORY, initialize_pyrit
from pyrit.datasets import fetch_wmdp_dataset
from pyrit.memory import CentralMemory
from pyrit.models import (
QuestionAnsweringDataset,
QuestionAnsweringEntry,
QuestionChoice,
)
from pyrit.orchestrator import QuestionAnsweringBenchmarkOrchestrator
from pyrit.prompt_target import OpenAIChatTarget
from pyrit.score.question_answer_scorer import QuestionAnswerScorer
# %%
# Start PyRIT: load the environment files and register an in-memory
# database as the central memory used by everything below.
initialize_pyrit(memory_db_type=IN_MEMORY)
# %%
# The (Azure) OpenAI chat target whose question-answering ability is benchmarked.
target = OpenAIChatTarget()
# %%
# Create demo dataset for Q/A Model
qa_ds = QuestionAnsweringDataset(
name="demo dataset",
version="1.0",
description="A dataset for testing the question answering model",
author="PyRIT Team",
group="",
source="",
questions=[
QuestionAnsweringEntry(
question="What is the capital of France?",
answer_type="str",
correct_answer="0",
choices=[
QuestionChoice(index=0, text="Paris"),
QuestionChoice(index=1, text="London"),
QuestionChoice(index=2, text="Berlin"),
QuestionChoice(index=3, text="Madrid"),
],
),
QuestionAnsweringEntry(
question="What is the capital of Germany?",
answer_type="str",
correct_answer="2",
choices=[
QuestionChoice(index=0, text="Paris"),
QuestionChoice(index=1, text="London"),
QuestionChoice(index=2, text="Berlin"),
QuestionChoice(index=3, text="Madrid"),
],
),
QuestionAnsweringEntry(
question="Which of the following is a prime number?",
answer_type="str",
correct_answer="3",
choices=[
QuestionChoice(index=0, text="507963"),
QuestionChoice(index=1, text="507962"),
QuestionChoice(index=2, text="507960"),
QuestionChoice(index=3, text="507961"),
],
),
],
)
# Create the scorer
qa_scorer = QuestionAnswerScorer()
# Create the orchestrator with scorer and demo dataset
benchmark_orchestrator = QuestionAnsweringBenchmarkOrchestrator(
objective_target=target, scorers=[qa_scorer], verbose=True
)
# Sends prompts associated with dataset
responses = await benchmark_orchestrator.send_prompts_async(dataset=qa_ds) # type: ignore
# %%
qa_scorer.report_scores(responses=responses)
# %%
# Fetch WMDP dataset for Q/A Model Testing
wmdp_ds = fetch_wmdp_dataset()
wmdp_ds.questions = wmdp_ds.questions[:3]
# Evaluate the Q/A Model response
responses = await benchmark_orchestrator.send_prompts_async(dataset=wmdp_ds) # type: ignore
# %%
# Output if the results are correct
qa_scorer.report_scores(responses=responses)
# %%
# Fetch WMDP dataset for Q/A Model Testing - Chem Subset
wmdp_ds = fetch_wmdp_dataset(category="chem")
wmdp_ds.questions = wmdp_ds.questions[:3]
# Evaluate the Q/A Model response
responses = await benchmark_orchestrator.send_prompts_async(dataset=wmdp_ds) # type: ignore
# %%
# Output if the results are correct
qa_scorer.report_scores(responses=responses)
# %%
# Fetch WMDP dataset for Q/A Model Testing - Bio Subset
wmdp_ds = fetch_wmdp_dataset(category="bio")
wmdp_ds.questions = wmdp_ds.questions[:3]
# Evaluate the Q/A Model response
responses = await benchmark_orchestrator.send_prompts_async(dataset=wmdp_ds) # type: ignore
# %%
# Output if the results are correct
qa_scorer.report_scores(responses=responses)
# %%
# Fetch WMDP dataset for Q/A Model Testing - Cyber Subset
wmdp_ds = fetch_wmdp_dataset(category="cyber")
wmdp_ds.questions = wmdp_ds.questions[:3]
# Evaluate the Q/A Model response
responses = await benchmark_orchestrator.send_prompts_async(dataset=wmdp_ds) # type: ignore
# %%
# Output if the results are correct
qa_scorer.report_scores(responses=responses)
# %%
# The demo is finished: fetch the central memory instance that was
# registered by initialize_pyrit and dispose of its database engine.
memory = CentralMemory.get_memory_instance()
memory.dispose_engine()