pyrit/cli/scanner_config.py (189 lines of code) (raw):

# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import inspect from copy import deepcopy from importlib import import_module from typing import Any, List, Literal, Optional, Type, get_args import yaml from pydantic import BaseModel, Field, field_validator, model_validator from pyrit.common.initialization import MemoryDatabaseType from pyrit.prompt_converter.prompt_converter import PromptConverter SupportedExecutionTypes = Literal["local"] def load_class(module_name: str, class_name: str, error_context: str) -> Type[Any]: """ Dynamically import a class from a module by name. """ try: mod = import_module(module_name) cls = getattr(mod, class_name) if not inspect.isclass(cls): raise TypeError(f"The attribute {class_name} in module {module_name} is not a class.") except Exception as ex: raise RuntimeError(f"Failed to import {class_name} from {module_name} for {error_context}: {ex}") from ex return cls class DatabaseConfig(BaseModel): """ Configuration for the database used by the scanner. """ db_type: MemoryDatabaseType = Field( ..., alias="type", description=f"Which database to use. Supported values: {list(get_args(MemoryDatabaseType))}", ) memory_labels: dict = Field(default_factory=dict, description="Labels that will be stored in memory to tag runs.") class ScenarioConfig(BaseModel, extra="allow"): """ Configuration for a single scenario orchestrator. """ scenario_type: str = Field( ..., alias="type", description="Scenario orchestrator class/type (e.g. 'PromptSendingOrchestrator')." ) @model_validator(mode="after") def check_scenario_type(self) -> "ScenarioConfig": """ Robustness check to ensure the user actually provided a scenario_type in the YAML. Pydantic already enforces requiredness, but we are adding more checks here. """ if not self.scenario_type: raise ValueError("Scenario 'type' must not be empty.") return self def create_orchestrator( self, objective_target: Any, adversarial_chat: Optional[Any] = None, prompt_converters: Optional[List[Any]] = None, scoring_target: Optional[Any] = None, objective_scorer: Optional[Any] = None, ) -> Any: """ Load and instantiate the orchestrator class, injecting top-level objects (targets, scorers) as needed. """ # Loading the orchestrator class by name, e.g 'RedTeamingOrchestrator' orchestrator_class = load_class( module_name="pyrit.orchestrator", class_name=self.scenario_type, error_context="scenario" ) # Converting scenario fields into a dict for the orchestrator constructor scenario_args = self.model_dump(exclude={"scenario_type"}) scenario_args = deepcopy(scenario_args) # Inspecting the orchestrator constructor so we can inject the optional arguments if they exist constructor_arg_names = [ param.name for param in inspect.signature(orchestrator_class.__init__).parameters.values() ] # Building a map of complex top-level objects that belong outside the scenario complex_args = { "objective_target": objective_target, "adversarial_chat": adversarial_chat, "prompt_converters": prompt_converters, "scoring_target": scoring_target, "objective_scorer": objective_scorer, } # Disallowing scenario-level overrides for these complex args for key in complex_args: if key in scenario_args: raise ValueError(f"{key} must be configured at the top-level of the config, not inside a scenario.") # If the orchestrator constructor expects any of these, inject them for key, value in complex_args.items(): if key in constructor_arg_names and value is not None: scenario_args[key] = value # And the instantiation of the orchestrator try: return orchestrator_class(**scenario_args) except Exception as ex: raise ValueError(f"Failed to instantiate scenario '{self.scenario_type}': {ex}") from ex class TargetConfig(BaseModel): """ Configuration for a prompt target (e.g. OpenAIChatTarget). """ class_name: str = Field(..., alias="type", description="Prompt target class name (e.g. 'OpenAIChatTarget').") def create_instance(self) -> Any: """ Dynamically instantiate the underlying target class. """ target_class = load_class( module_name="pyrit.prompt_target", class_name=self.class_name, error_context="TargetConfig" ) init_kwargs = self.model_dump(exclude={"class_name"}) return target_class(**init_kwargs) class ObjectiveScorerConfig(BaseModel): """ Configuration for an objective scorer """ type: str = Field(..., description="Scorer class (e.g. 'SelfAskRefusalScorer').") def create_scorer(self, scoring_target_obj: Optional[Any]) -> Any: """ Load and instantiate the scorer class. """ scorer_class = load_class(module_name="pyrit.score", class_name=self.type, error_context="objective_scorer") init_kwargs = self.model_dump(exclude={"type"}) signature = inspect.signature(scorer_class.__init__) chat_target_key: str = "chat_target" if chat_target_key in signature.parameters: if scoring_target_obj is None: raise KeyError( "Scorer requires a scoring_target to be defined. " "Alternatively, the adversarial_target can be used for scoring purposes, " "but none was provided." ) init_kwargs[chat_target_key] = scoring_target_obj return scorer_class(**init_kwargs) class ScoringConfig(BaseModel): """ Configuration for the scoring setup, including optional override of the default adversarial chat with a 'scoring_target' and/or an 'objective_scorer'. """ scoring_target: Optional[TargetConfig] = Field( None, description="If provided, use this target for scoring instead of 'adversarial_chat'." ) objective_scorer: Optional[ObjectiveScorerConfig] = Field( None, description="Details for the objective scorer, if any." ) def create_objective_scorer(self, scoring_target_obj: Optional[Any]) -> Optional[Any]: # If the user did not provide an objective_scorer config block (meaning the YAML lacks that section), # we simply return None – no scorer to instantiate. if not self.objective_scorer: return None return self.objective_scorer.create_scorer(scoring_target_obj=scoring_target_obj) class ConverterConfig(BaseModel): """ Configuration for a single prompt converter, e.g. type: "Base64Converter" """ class_name: str = Field(..., alias="type", description="The prompt converter class name (e.g. 'Base64Converter').") def create_instance(self) -> Any: """ Dynamically load and instantiate the converter class """ converter_class = load_class( module_name="pyrit.prompt_converter", class_name=self.class_name, error_context="prompt_converter" ) init_kwargs = self.model_dump(exclude={"class_name"}) return converter_class(**init_kwargs) class ExecutionSettings(BaseModel): """ Configuration for how the scanner is executed (e.g. locally or via AzureML). """ type: SupportedExecutionTypes = Field( "local", description=f"Execution environment. Supported values: {list(get_args(SupportedExecutionTypes))}" ) parallel_nodes: Optional[int] = Field(None, description="How many scenarios to run in parallel.") class ScannerConfig(BaseModel): """ Top-level configuration for the entire scanner. """ datasets: List[str] = Field(..., description="List of dataset YAML paths to load seed prompts from.") scenarios: List[ScenarioConfig] = Field(..., description="List of scenario orchestrators to execute.") objective_target: TargetConfig = Field(..., description="Configuration of the main (objective) chat target.") adversarial_chat: Optional[TargetConfig] = Field( None, description="Configuration of the adversarial chat target (if any)." ) scoring: Optional[ScoringConfig] = Field(None, description="Scoring configuration (if any).") converters: Optional[List[ConverterConfig]] = Field(None, description="List of prompt converters to apply.") execution_settings: ExecutionSettings = Field( default_factory=lambda: ExecutionSettings.model_validate({}), description="Settings for how the scan is executed.", ) database: DatabaseConfig = Field( ..., description="Database configuration for storing memory or results, including memory_labels.", ) @field_validator("objective_target", mode="before") def check_objective_target_is_dict(cls, value): """ Ensure the user actually provides a dict. Pydantic will run this validator before it attempts to parse the value into the TargetConfig model """ if not isinstance(value, dict): raise ValueError( "Field 'objective_target' must be a dictionary.\n" "Example:\n" " objective_target:\n" " type: OpenAIChatTarget" ) return value @model_validator(mode="after") def fill_scoring_target(self) -> "ScannerConfig": """ If config.scoring exists but doesn't explicitly define a scoring_target, default it to the adversarial_chat """ if self.scoring: if self.scoring.scoring_target is None and self.adversarial_chat is not None: self.scoring.scoring_target = self.adversarial_chat return self @classmethod def from_yaml(cls, path: str) -> "ScannerConfig": """ Loads configuration from a YAML file and validates it using Pydantic. """ with open(path, "r", encoding="utf-8") as f: raw_dict = yaml.safe_load(f) return cls(**raw_dict) def create_objective_scorer(self) -> Optional[Any]: """ if there's an objective scorer configured, instantiate it using 'scoring_target' (which might be adversarial_chat). """ if not self.scoring: return None scoring_target = None if self.scoring.scoring_target: scoring_target = self.scoring.scoring_target.create_instance() return self.scoring.create_objective_scorer(scoring_target_obj=scoring_target) def create_prompt_converters(self) -> List[PromptConverter]: """ Instantiates each converter defined in 'self.converters' (if any). Returns a list of converter objects. """ if not self.converters: return [] instances = [] for converter_cfg in self.converters: instances.append(converter_cfg.create_instance()) return instances def create_orchestrators(self, prompt_converters: Optional[List[PromptConverter]] = None) -> List[Any]: """ Helper method to instantiate all orchestrators from the scenario configs, injecting objective_target, adversarial_chat, scoring_target, objective_scorer, etc. """ # Instantiate the top-level targets objective_target_obj = self.objective_target.create_instance() adversarial_chat_obj = None if self.adversarial_chat: adversarial_chat_obj = self.adversarial_chat.create_instance() # If there is a scoring_target or an objective_scorer: scoring_target_obj = None objective_scorer_obj = None if self.scoring: # fill_scoring_target might have already assigned it to self.scoring.scoring_target if self.scoring.scoring_target: scoring_target_obj = self.scoring.scoring_target.create_instance() # create the actual scorer objective_scorer_obj = self.scoring.create_objective_scorer(scoring_target_obj=scoring_target_obj) # Now each scenario can create its orchestrator orchestrators = [] for scenario in self.scenarios: orch = scenario.create_orchestrator( objective_target=objective_target_obj, adversarial_chat=adversarial_chat_obj, prompt_converters=prompt_converters, scoring_target=scoring_target_obj, objective_scorer=objective_scorer_obj, ) orchestrators.append(orch) return orchestrators