pyrit/prompt_converter/token_smuggling/base.py (31 lines of code) (raw):

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import abc
import logging
from typing import Literal, Tuple

from pyrit.models import PromptDataType
from pyrit.prompt_converter import ConverterResult, PromptConverter

logger = logging.getLogger(__name__)


class SmugglerConverter(PromptConverter, abc.ABC):
    """
    Abstract base class for token smuggling converters.

    Provides the common asynchronous conversion interface and enforces implementation of
    ``encode_message`` and ``decode_message`` in subclasses. The direction of conversion
    is fixed at construction time via the ``action`` attribute.
    """

    def __init__(self, action: Literal["encode", "decode"] = "encode") -> None:
        """
        Initialize the converter with the direction of conversion.

        Args:
            action (Literal["encode", "decode"]): Whether ``convert_async`` encodes the prompt
                (via ``encode_message``) or decodes it (via ``decode_message``).
                Defaults to "encode".

        Raises:
            ValueError: If ``action`` is neither "encode" nor "decode".
        """
        # Validate before touching base-class state so a bad argument fails fast.
        if action not in ("encode", "decode"):
            raise ValueError("Action must be either 'encode' or 'decode'")
        # Cooperative initialization: let PromptConverter (and its bases) run their own setup.
        super().__init__()
        self.action = action

    async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult:
        """
        Convert the prompt by either encoding or decoding it based on the configured action.

        Args:
            prompt (str): The prompt to be processed.
            input_type (PromptDataType): Type of input; only "text" is supported.

        Returns:
            ConverterResult: The result containing the output text and its type ("text").

        Raises:
            ValueError: If the input type is unsupported.
        """
        if not self.input_supported(input_type):
            raise ValueError("Input type not supported")

        if self.action == "encode":
            summary, encoded = self.encode_message(message=prompt)
            # Lazy %-style arguments: the summary is only formatted if INFO logging is enabled.
            logger.info("Encoded message summary: %s", summary)
            return ConverterResult(output_text=encoded, output_type="text")

        # Any non-"encode" action decodes; __init__ only admits "encode" or "decode".
        decoded = self.decode_message(message=prompt)
        return ConverterResult(output_text=decoded, output_type="text")

    def input_supported(self, input_type: PromptDataType) -> bool:
        """Return True if the input type is 'text'."""
        return input_type == "text"

    def output_supported(self, output_type: PromptDataType) -> bool:
        """Return True if the output type is 'text'."""
        return output_type == "text"

    @abc.abstractmethod
    def encode_message(self, *, message: str) -> Tuple[str, str]:
        """
        Encode the given message. Must be implemented by subclasses.

        Args:
            message (str): The message to encode.

        Returns:
            Tuple[str, str]: A tuple containing a summary and the encoded message.
                The summary is logged at INFO level by ``convert_async``.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def decode_message(self, *, message: str) -> str:
        """
        Decode the given message. Must be implemented by subclasses.

        Args:
            message (str): The encoded message.

        Returns:
            str: The decoded message.
        """
        raise NotImplementedError