# pyrit/orchestrator/multi_turn/tree_of_attacks_node.py (184 lines of code) (raw):
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import json
import logging
import uuid
from typing import Optional
from pyrit.exceptions import (
InvalidJsonException,
pyrit_json_retry,
remove_markdown_json,
)
from pyrit.memory import CentralMemory, MemoryInterface
from pyrit.models import Score, SeedPrompt, SeedPromptGroup
from pyrit.prompt_converter import PromptConverter
from pyrit.prompt_normalizer import PromptNormalizer
from pyrit.prompt_normalizer.prompt_converter_configuration import (
PromptConverterConfiguration,
)
from pyrit.prompt_target import PromptChatTarget
from pyrit.score.scorer import Scorer
logger = logging.getLogger(__name__)
class TreeOfAttacksNode:
    """
    Creates a Node to be used with Tree of Attacks with Pruning.

    Each node owns two conversations — one with the objective target and one with
    the adversarial (red teaming) chat — and tracks the branch state used for
    pruning: whether a prompt was sent, whether the turn completed, the objective
    score, and whether the generated prompt was judged off-topic.
    """

    # Shared memory backend, resolved from CentralMemory in __init__.
    _memory: MemoryInterface

    def __init__(
        self,
        *,
        objective_target: PromptChatTarget,
        adversarial_chat: PromptChatTarget,
        adversarial_chat_seed_prompt: SeedPrompt,
        adversarial_chat_prompt_template: SeedPrompt,
        adversarial_chat_system_seed_prompt: SeedPrompt,
        desired_response_prefix: str,
        objective_scorer: Scorer,
        on_topic_scorer: Scorer,
        prompt_converters: list[PromptConverter],
        orchestrator_id: dict[str, str],
        memory_labels: Optional[dict[str, str]] = None,
        parent_id: Optional[str] = None,
    ) -> None:
        """
        Args:
            objective_target: The target under attack; receives the generated prompts.
            adversarial_chat: The red teaming chat target that generates attack prompts.
            adversarial_chat_seed_prompt: Template for the first-turn red teaming prompt.
            adversarial_chat_prompt_template: Template for subsequent-turn red teaming prompts.
            adversarial_chat_system_seed_prompt: Template for the adversarial chat system prompt.
            desired_response_prefix: Prefix the attack tries to elicit; rendered into the system prompt.
            objective_scorer: Scores the objective target's response against the objective.
            on_topic_scorer: Optional scorer used to prune off-topic prompts (falsy to disable).
            prompt_converters: Converters applied to prompts sent to the objective target.
            orchestrator_id: Identifier of the owning orchestrator, attached to all requests.
            memory_labels: Labels applied to every memory entry created by this node.
            parent_id: node_id of the parent node when this node was created by duplication.
        """
        self._objective_target = objective_target
        self._adversarial_chat = adversarial_chat
        self._objective_scorer = objective_scorer
        self._adversarial_chat_seed_prompt = adversarial_chat_seed_prompt
        self._desired_response_prefix = desired_response_prefix
        self._adversarial_chat_prompt_template = adversarial_chat_prompt_template
        self._adversarial_chat_system_seed_prompt = adversarial_chat_system_seed_prompt
        self._on_topic_scorer = on_topic_scorer
        self._prompt_converters = prompt_converters
        self._orchestrator_id = orchestrator_id

        self._memory = CentralMemory.get_memory_instance()
        self._global_memory_labels = memory_labels or {}
        self._prompt_normalizer = PromptNormalizer()

        self.parent_id = parent_id
        self.node_id = str(uuid.uuid4())

        # Each node keeps its own pair of conversations; duplicate() replaces
        # these ids with duplicated conversations.
        self.objective_target_conversation_id = str(uuid.uuid4())
        self.adversarial_chat_conversation_id = str(uuid.uuid4())

        self.prompt_sent = False
        self.completed = False
        # No score exists until send_prompt_async completes a turn.
        self.score: Optional[Score] = None
        self.off_topic = False

    async def send_prompt_async(self, objective: str) -> None:
        """Executes one turn of a branch of a tree of attacks with pruning.

        This includes a few steps. At first, the red teaming target generates a prompt for the prompt target.
        If on-topic checking is enabled, the branch will get pruned if the generated prompt is off-topic.
        If it is on-topic or on-topic checking is not enabled, the prompt is sent to the prompt target.
        The response from the prompt target is finally scored by the scorer.

        Args:
            objective: The attack objective the red teaming chat is steering toward.
        """
        self.prompt_sent = True

        try:
            prompt = await self._generate_red_teaming_prompt_async(objective=objective)
        except InvalidJsonException as e:
            # Without a usable red teaming prompt there is nothing to send;
            # leave the node incomplete so the orchestrator prunes this branch.
            logger.error(f"Failed to generate a prompt for the prompt target: {e}")
            logger.info("Pruning the branch since we can't proceed without red teaming prompt.")
            return

        if self._on_topic_scorer:
            on_topic_score = (await self._on_topic_scorer.score_text_async(text=prompt))[0]

            # If the prompt is not on topic we prune the branch.
            if not on_topic_score.get_value():
                self.off_topic = True
                return

        seed_prompt_group = SeedPromptGroup(prompts=[SeedPrompt(value=prompt, data_type="text")])

        converters = PromptConverterConfiguration(converters=self._prompt_converters)

        response = (
            await self._prompt_normalizer.send_prompt_async(
                seed_prompt_group=seed_prompt_group,
                request_converter_configurations=[converters],
                conversation_id=self.objective_target_conversation_id,
                target=self._objective_target,
                labels=self._global_memory_labels,
                orchestrator_identifier=self._orchestrator_id,
            )
        ).request_pieces[0]

        logger.debug(f"saving score with prompt_request_response_id: {response.id}")

        self.score = (
            await self._objective_scorer.score_async(
                request_response=response,
                task=objective,
            )
        )[0]

        self.completed = True

    def duplicate(self) -> TreeOfAttacksNode:
        """
        Creates a duplicate of the provided instance
        with incremented iteration and new conversations ids (but duplicated conversations)

        Returns:
            A new node whose parent_id is this node's node_id and whose
            conversations are memory-level copies of this node's conversations.
        """
        duplicate_node = TreeOfAttacksNode(
            objective_target=self._objective_target,
            adversarial_chat=self._adversarial_chat,
            adversarial_chat_seed_prompt=self._adversarial_chat_seed_prompt,
            adversarial_chat_prompt_template=self._adversarial_chat_prompt_template,
            adversarial_chat_system_seed_prompt=self._adversarial_chat_system_seed_prompt,
            objective_scorer=self._objective_scorer,
            on_topic_scorer=self._on_topic_scorer,
            prompt_converters=self._prompt_converters,
            orchestrator_id=self._orchestrator_id,
            memory_labels=self._global_memory_labels,
            desired_response_prefix=self._desired_response_prefix,
            parent_id=self.node_id,
        )

        # Replace the fresh conversation ids with duplicated conversation history
        # so the child continues from this node's state.
        duplicate_node.objective_target_conversation_id = self._memory.duplicate_conversation(
            conversation_id=self.objective_target_conversation_id
        )

        duplicate_node.adversarial_chat_conversation_id = self._memory.duplicate_conversation(
            conversation_id=self.adversarial_chat_conversation_id,
        )

        return duplicate_node

    @pyrit_json_retry
    async def _generate_red_teaming_prompt_async(self, objective: str) -> str:
        """Generates the next attack prompt via the adversarial chat.

        Uses the red teaming target to generate a prompt for the attack target.
        The prompt for the red teaming target needs to include the latest message from
        the prompt target. A special case is the very first message, in which case there
        are no prior messages, so the initial red teaming seed prompt is used instead.

        Raises:
            InvalidJsonException: If the adversarial chat response is not valid JSON
                (retried by the pyrit_json_retry decorator).
            RuntimeError: If the existing conversation has no assistant responses.
        """
        target_messages = self._memory.get_conversation(conversation_id=self.objective_target_conversation_id)

        if not target_messages:
            # First turn: set the adversarial chat system prompt and use the seed prompt.
            system_prompt = self._adversarial_chat_system_seed_prompt.render_template_value(
                objective=objective, desired_prefix=self._desired_response_prefix
            )

            self._adversarial_chat.set_system_prompt(
                system_prompt=system_prompt,
                conversation_id=self.adversarial_chat_conversation_id,
                orchestrator_identifier=self._orchestrator_id,
                labels=self._global_memory_labels,
            )

            logger.debug("Using the specified initial red teaming prompt for the first turn.")

            prompt_text = self._adversarial_chat_seed_prompt.render_template_value(objective=objective)
        else:
            # Subsequent turns: feed the latest assistant response (and its score,
            # if one was recorded) back into the red teaming prompt template.
            assistant_responses = [r for r in target_messages if r.request_pieces[0].role == "assistant"]
            if not assistant_responses:
                logger.error(
                    f"No assistant responses found in the conversation {self.objective_target_conversation_id}."
                )
                raise RuntimeError("Cannot proceed without an assistant response.")

            target_response = assistant_responses[-1]
            target_response_piece = target_response.request_pieces[0]
            logger.debug(f"target_response_piece.id: {target_response_piece.id}")

            scores = self._memory.get_scores_by_prompt_ids(prompt_request_response_ids=[str(target_response_piece.id)])

            if scores:
                score = scores[0].get_value()
            else:
                score = "unavailable"

            prompt_text = self._adversarial_chat_prompt_template.render_template_value(
                target_response=target_response_piece.converted_value,
                objective=objective,
                score=str(score),
            )

        # Request JSON output so the response can be parsed for the "prompt" key.
        prompt_metadata: dict[str, str | int] = {"response_format": "json"}
        seed_prompt_group = SeedPromptGroup(
            prompts=[SeedPrompt(value=prompt_text, data_type="text", metadata=prompt_metadata)]
        )

        adversarial_chat_response = (
            await self._prompt_normalizer.send_prompt_async(
                seed_prompt_group=seed_prompt_group,
                conversation_id=self.adversarial_chat_conversation_id,
                target=self._adversarial_chat,
                labels=self._global_memory_labels,
                orchestrator_identifier=self._orchestrator_id,
            )
        ).get_value()

        return self._parse_red_teaming_response(adversarial_chat_response)

    def _parse_red_teaming_response(self, red_teaming_response: str) -> str:
        """Extracts the "prompt" value from the adversarial chat's JSON response.

        The red teaming response should be in JSON format with two keys: "prompt"
        and "improvement". Only "prompt" is parsed and returned.

        Raises:
            InvalidJsonException: If the response is not valid JSON or lacks a "prompt" key.
        """
        # If the JSON is valid in Markdown format, remove the Markdown formatting.
        red_teaming_response = remove_markdown_json(red_teaming_response)
        try:
            red_teaming_response_dict = json.loads(red_teaming_response)
        except json.JSONDecodeError as e:
            logger.error(f"The response from the red teaming chat is not in JSON format: {red_teaming_response}")
            # Chain the original error so the decode position is preserved for debugging.
            raise InvalidJsonException(message="The response from the red teaming chat is not in JSON format.") from e

        try:
            return red_teaming_response_dict["prompt"]
        except KeyError as e:
            logger.error(f"The response from the red teaming chat does not contain a prompt: {red_teaming_response}")
            raise InvalidJsonException(
                message="The response from the red teaming chat does not contain a prompt."
            ) from e

    def __str__(self) -> str:
        # score is None until a turn completes; guard so printing a fresh or
        # pruned node does not raise AttributeError.
        score_value = self.score.get_value() if self.score is not None else None
        return (
            "TreeOfAttacksNode("
            f"completed={self.completed}, "
            f"score={score_value}, "
            f"node_id={self.node_id}, "
            f"objective_target_conversation_id={self.objective_target_conversation_id})"
        )

    __repr__ = __str__