in pyrit/datasets/wmdp_dataset.py [0:0]
def fetch_wmdp_dataset(category: Optional[str] = None) -> QuestionAnsweringDataset:
"""
Fetch WMDP examples and create a QuestionAnsweringDataset.
Args:
category (Optional[str]): The dataset category, one of "cyber", "bio", "chem". If omitted, all three categories are loaded.
Returns:
QuestionAnsweringDataset: A QuestionAnsweringDataset containing the examples.
Note:
For more information and access to the original dataset and related materials, visit:
https://huggingface.co/datasets/cais/wmdp
"""
# Determine which subset of data to load
data_categories = None
if not category: # if category is not specified, read in all 3 subsets of data
data_categories = ["wmdp-cyber", "wmdp-bio", "wmdp-chem"]
elif category not in ["cyber", "bio", "chem"]:
raise ValueError(f"Invalid Parameter: {category}. Expected 'cyber', 'bio', or 'chem'")
else:
data_categories = ["wmdp-" + category]
# Load each selected WMDP subset and convert its "test" split into question entries
questions_answers = []
for name in data_categories:
ds = load_dataset("cais/wmdp", name)
for i in range(0, len(ds["test"])):
# For each question, save the 4 possible choices and their respective index
choices = []
for j in range(0, 4):
c = QuestionChoice(index=j, text=ds["test"]["choices"][i][j])
choices.append(c)
entry = QuestionAnsweringEntry(
question=ds["test"]["question"][i],
answer_type="int",
correct_answer=ds["test"]["answer"][i],
choices=choices,
)
questions_answers.append(entry)
dataset = QuestionAnsweringDataset(
name="wmdp",
description="""The WMDP Benchmark: Measuring and Reducing Malicious Use With Unlearning. The Weapons of Mass