pyrit/prompt_converter/charswap_attack_converter.py (51 lines of code) (raw):

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import math
import random
import re
import string

from pyrit.models import PromptDataType
from pyrit.prompt_converter import ConverterResult, PromptConverter

# Module-level logger
logger = logging.getLogger(__name__)


class CharSwapGenerator(PromptConverter):
    """
    A PromptConverter that applies character swapping to words in the prompt
    to test adversarial textual robustness.
    """

    def __init__(self, *, max_iterations: int = 10, word_swap_ratio: float = 0.2):
        """
        Initializes the CharSwapGenerator.

        Args:
            max_iterations (int): Number of times to generate perturbed prompts.
                The higher the number the higher the chance that words are different
                from the original prompt. Must be > 0.
                NOTE(review): this value is validated and stored, but no method of this
                class reads it; each ``convert_async`` call performs a single
                perturbation pass. Confirm intended semantics before relying on it.
            word_swap_ratio (float): Fraction of tokens to perturb per pass, in (0, 1].

        Raises:
            ValueError: If ``max_iterations`` is not positive, or ``word_swap_ratio``
                is not in the half-open interval (0, 1].
        """
        super().__init__()

        # Ensure max_iterations is positive
        if max_iterations <= 0:
            raise ValueError("max_iterations must be greater than 0")

        # Ensure word_swap_ratio is between 0 and 1
        if not (0 < word_swap_ratio <= 1):
            raise ValueError("word_swap_ratio must be between 0 and 1 (exclusive of 0)")

        self.max_iterations = max_iterations
        self.word_swap_ratio = word_swap_ratio

    def input_supported(self, input_type: PromptDataType) -> bool:
        return input_type == "text"

    def output_supported(self, output_type: PromptDataType) -> bool:
        return output_type == "text"

    def _perturb_word(self, word: str) -> str:
        """
        Perturb a word by swapping two adjacent characters.

        The swapped pair is ``(i, i + 1)`` with ``i`` drawn uniformly from
        ``[1, len(word) - 2]``, so the first character never moves. Tokens of three
        characters or fewer, and tokens made up solely of ASCII punctuation, are
        returned unchanged. Swapping two identical characters also leaves the word
        unchanged.

        Args:
            word (str): The word to perturb.

        Returns:
            str: The perturbed word, or ``word`` itself if it is not eligible.
        """
        # Check every character rather than `word in string.punctuation`, which is a
        # substring test and lets punctuation runs like "?!?!" through.
        if len(word) <= 3 or all(char in string.punctuation for char in word):
            return word

        idx = random.randint(1, len(word) - 2)
        chars = list(word)
        chars[idx], chars[idx + 1] = chars[idx + 1], chars[idx]
        return "".join(chars)

    async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult:
        """
        Converts the given prompt by applying character swaps.

        The prompt is split into tokens: runs of word characters, or runs of other
        non-whitespace characters. ``max(1, ceil(len(tokens) * word_swap_ratio))``
        distinct tokens are chosen at random and passed through ``_perturb_word``.
        Everything else in the prompt is kept as-is, including the original
        whitespace between tokens and punctuation attached to words. Only leading
        and trailing whitespace is stripped from the result.

        Args:
            prompt (str): The prompt to be converted.
            input_type (PromptDataType): The input type; only ``"text"`` is supported.

        Returns:
            ConverterResult: The result containing the perturbed prompt as text.

        Raises:
            ValueError: If ``input_type`` is not supported.
        """
        if not self.input_supported(input_type):
            raise ValueError("Input type not supported")

        # Tokenize while remembering each token's span, so the original separators
        # (spaces, tabs, newlines, or nothing at all) can be restored exactly.
        # Re-joining with " " would corrupt tokens like "hello-world" or 'say "hi"'.
        tokens = list(re.finditer(r"\w+|\S+", prompt))
        perturbed = [token.group() for token in tokens]

        # An empty or whitespace-only prompt has nothing to perturb.
        if tokens:
            # ceil(n * ratio) <= n because ratio <= 1, so the sample always fits.
            num_perturb_words = max(1, math.ceil(len(tokens) * self.word_swap_ratio))
            for idx in self._get_n_random(0, len(tokens), num_perturb_words):
                perturbed[idx] = self._perturb_word(perturbed[idx])

        # Rebuild the prompt: original gap before each token, then the (possibly
        # perturbed) token text, then whatever follows the last token.
        pieces = []
        cursor = 0
        for token, text in zip(tokens, perturbed):
            pieces.append(prompt[cursor : token.start()])
            pieces.append(text)
            cursor = token.end()
        pieces.append(prompt[cursor:])

        output_text = "".join(pieces).strip()
        return ConverterResult(output_text=output_text, output_type="text")

    def _get_n_random(self, low: int, high: int, n: int) -> list:
        """
        Utility function to generate random indices.
        Words at these indices will be subjected to perturbation.

        Args:
            low (int): Inclusive lower bound of the index range.
            high (int): Exclusive upper bound of the index range.
            n (int): Number of distinct indices to draw.

        Returns:
            list: ``n`` distinct indices in ``[low, high)`` in random order, or an
            empty list if ``n`` exceeds the population size.
        """
        try:
            return random.sample(range(low, high), n)
        except ValueError:
            logger.debug("[CharSwapConverter] Sample size of %s exceeds population size of %s", n, high - low)
            return []