# ---
# jupyter:
#   jupytext:
#     cell_metadata_filter: -all
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.3
#   kernelspec:
#     display_name: pyrit-dev
#     language: python
#     name: python3
# ---
# %% [markdown]
# # 8. Seed Prompt Database
#
# Apart from storing results in memory, it is also useful to store datasets of seed prompts
# and seed prompt templates that we may want to use at a later point.
# This helps us curate prompts with custom metadata such as harm categories.
# As with all memory, we can use a local DuckDBMemory, or AzureSQLMemory in Azure to get the
# benefits of sharing with other users and persisting data.
# %%
from pyrit.common import IN_MEMORY, initialize_pyrit
initialize_pyrit(memory_db_type=IN_MEMORY)
# %% [markdown]
# ## Adding prompts to the database
# %%
import pathlib
from pyrit.common.path import DATASETS_PATH
from pyrit.memory import CentralMemory
from pyrit.models import SeedPromptDataset
seed_prompt_dataset = SeedPromptDataset.from_yaml_file(
    pathlib.Path(DATASETS_PATH) / "seed_prompts" / "illegal-multimodal-dataset.prompt"
)
print(seed_prompt_dataset.prompts[0])
# Render user-defined values for yaml template
seed_prompt_dataset.render_template_value(stolen_item="a car")
memory = CentralMemory.get_memory_instance()
await memory.add_seed_prompts_to_memory_async(prompts=seed_prompt_dataset.prompts, added_by="test") # type: ignore
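# %% [markdown]
# To make the template rendering step above more concrete, here is a minimal standalone sketch of
# placeholder substitution. It assumes the templates use Jinja-style `{{ name }}` placeholders and
# swaps in a simple regular expression for a real template engine, so it is not PyRIT's implementation.
# The helper `render_placeholders` and the example template text are hypothetical.
# %%
import re


def render_placeholders(template: str, **values: str) -> str:
    """Hypothetical helper: replace {{ name }} placeholders with the supplied values."""

    def substitute(match: re.Match) -> str:
        name = match.group(1)
        # Leave unknown placeholders untouched so they can still be rendered later.
        return values.get(name, match.group(0))

    return re.sub(r"\{\{\s*(\w+)\s*\}\}", substitute, template)


# Assumed example template; real seed prompt templates live in the YAML files.
example_template = "The item in question is {{ stolen_item }}, seen at {{ location }}."
rendered_example = render_placeholders(example_template, stolen_item="a car")
print(rendered_example)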
# %% [markdown]
# ## Retrieving prompts from the database
#
# First, let's get an idea of what datasets are represented in the database.
# %%
memory.get_seed_prompt_dataset_names()
# %% [markdown]
# The dataset we just uploaded (called "test illegal") is also represented.
# To get all seed prompts from that dataset, we can query as follows:
# %%
dataset_name = "test illegal"
prompts = memory.get_seed_prompts(dataset_name=dataset_name)
print(f"Total number of prompts with dataset name '{dataset_name}':", len(prompts))
for prompt in prompts:
    print(prompt.__dict__)
# %% [markdown]
# ## Adding multimodal seed prompt groups to the database
#
# In the following example, we add a seed prompt group containing text, image, audio, and video prompts.
# When non-text seed prompts are added to memory, encoding data is automatically populated in each seed prompt's
# `metadata` field. This includes `format` (e.g. png, mp4, wav) and, for audio and video files, additional
# metadata: `bitrate` (kBits/s as int), `samplerate` (samples/second as int), `bitdepth` (as int),
# `filesize` (bytes as int), and `duration` (seconds as int), provided the file type is supported by TinyTag.
# Supported file types include MP3, MP4, M4A, and WAV. These fields are useful for filtering because some
# targets have specific input prompt requirements.
# %%
import pathlib
from pyrit.common.path import DATASETS_PATH
from pyrit.models import SeedPromptGroup
seed_prompt_group = SeedPromptGroup.from_yaml_file(
    pathlib.Path(DATASETS_PATH) / "seed_prompts" / "illegal-multimodal-group.prompt"
)
# Render user-defined values for yaml template
seed_prompt_group.render_template_value(stolen_item="a car")
await memory.add_seed_prompt_groups_to_memory(prompt_groups=[seed_prompt_group], added_by="test multimodal illegal") # type: ignore
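# %% [markdown]
# To get a feel for what the audio metadata fields described above represent, the sketch below derives
# comparable values for a WAV file using only the standard library `wave` module. This is an illustration
# under stated assumptions, not how PyRIT populates metadata (the text above attributes that to TinyTag).
# The helper `describe_wav` and the generated file `silence.wav` are hypothetical.
# %%
import os
import tempfile
import wave


def describe_wav(path: str) -> dict:
    """Hypothetical helper: derive metadata-like fields from an uncompressed WAV file."""
    with wave.open(path, "rb") as wav_file:
        samplerate = wav_file.getframerate()  # samples/second
        bitdepth = wav_file.getsampwidth() * 8  # bytes per sample -> bits per sample
        channels = wav_file.getnchannels()
        duration = wav_file.getnframes() / samplerate  # seconds
    return {
        "format": "wav",
        "samplerate": samplerate,
        "bitdepth": bitdepth,
        # For uncompressed PCM audio, bitrate follows from the other fields (kBits/s).
        "bitrate": samplerate * bitdepth * channels // 1000,
        "filesize": os.path.getsize(path),  # bytes
        "duration": int(duration),
    }


# Write one second of 16-bit mono silence at 24000 samples/second, then inspect it.
with tempfile.TemporaryDirectory() as tmp_dir:
    example_path = os.path.join(tmp_dir, "silence.wav")
    with wave.open(example_path, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(24000)
        wav_file.writeframes(b"\x00\x00" * 24000)
    example_metadata = describe_wav(example_path)

print(example_metadata)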
# %% [markdown]
# ## Retrieving seed prompt groups from memory with `dataset_name` set to "TestMultimodalTextImageAudioVideo"
# %%
multimodal_dataset_name = "TestMultimodalTextImageAudioVideo"
seed_prompt_groups = memory.get_seed_prompt_groups(dataset_name=multimodal_dataset_name)
print(f"Total number of seed prompt groups with dataset name '{multimodal_dataset_name}':", len(seed_prompt_groups))
# Print the auto-populated metadata for each seed prompt in the multimodal seed prompt group.
for seed_prompt in seed_prompt_group.prompts:
    print(f"SeedPrompt value: {seed_prompt.value}, SeedPrompt metadata: {seed_prompt.metadata}")
# %% [markdown]
# ## Filtering seed prompts by metadata
# %%
# Filter by metadata to get seed prompts in .wav format with a sample rate of 24000 samples/second
memory.get_seed_prompts(metadata={"format": "wav", "samplerate": 24000})
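# %% [markdown]
# The filter above can be pictured as an exact match on key/value pairs. The sketch below shows that idea
# on plain dictionaries, assuming every requested key must be present with an equal value. The actual
# query runs inside the memory database and may behave differently. The helper `matches_metadata` and the
# example records are hypothetical.
# %%
from typing import Any, Dict, List


def matches_metadata(candidate: Dict[str, Any], required: Dict[str, Any]) -> bool:
    """Hypothetical helper: True if every required key is present with an equal value."""
    return all(key in candidate and candidate[key] == value for key, value in required.items())


# Assumed example records standing in for stored seed prompts.
example_records: List[Dict[str, Any]] = [
    {"value": "a.wav", "metadata": {"format": "wav", "samplerate": 24000, "bitdepth": 16}},
    {"value": "b.wav", "metadata": {"format": "wav", "samplerate": 44100, "bitdepth": 16}},
    {"value": "c.mp4", "metadata": {"format": "mp4", "samplerate": 24000}},
]

required_metadata = {"format": "wav", "samplerate": 24000}
matching_values = [
    record["value"] for record in example_records if matches_metadata(record["metadata"], required_metadata)
]
print(matching_values)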
# %%
from pyrit.memory import CentralMemory
memory = CentralMemory.get_memory_instance()
memory.dispose_engine()