pyrit/cli/scanner_config.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import inspect
from copy import deepcopy
from importlib import import_module
from typing import Any, List, Literal, Optional, Type, get_args
import yaml
from pydantic import BaseModel, Field, field_validator, model_validator
from pyrit.common.initialization import MemoryDatabaseType
from pyrit.prompt_converter.prompt_converter import PromptConverter
SupportedExecutionTypes = Literal["local"]
def load_class(module_name: str, class_name: str, error_context: str) -> Type[Any]:
"""
Dynamically import a class from a module by name.
"""
try:
mod = import_module(module_name)
cls = getattr(mod, class_name)
if not inspect.isclass(cls):
raise TypeError(f"The attribute {class_name} in module {module_name} is not a class.")
except Exception as ex:
raise RuntimeError(f"Failed to import {class_name} from {module_name} for {error_context}: {ex}") from ex
return cls
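# For example (illustrative, using names that appear elsewhere in this file):
# load_class("pyrit.prompt_target", "OpenAIChatTarget", "TargetConfig") returns the
# OpenAIChatTarget class, which the caller can then instantiate with keyword arguments.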
class DatabaseConfig(BaseModel):
"""
Configuration for the database used by the scanner.
"""
db_type: MemoryDatabaseType = Field(
...,
alias="type",
description=f"Which database to use. Supported values: {list(get_args(MemoryDatabaseType))}",
)
memory_labels: dict = Field(default_factory=dict, description="Labels that will be stored in memory to tag runs.")
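# Illustrative YAML for the 'database' section handled by DatabaseConfig above. The 'type'
# value must be one of MemoryDatabaseType; the value and labels shown are placeholders:
#
#   database:
#     type: InMemory
#     memory_labels:
#       operation: my_scan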
class ScenarioConfig(BaseModel, extra="allow"):
"""
Configuration for a single scenario orchestrator.
"""
scenario_type: str = Field(
..., alias="type", description="Scenario orchestrator class/type (e.g. 'PromptSendingOrchestrator')."
)
@model_validator(mode="after")
def check_scenario_type(self) -> "ScenarioConfig":
"""
Robustness check to ensure the user actually provided a scenario_type in the YAML.
Pydantic already enforces requiredness, but we are adding more checks here.
"""
if not self.scenario_type:
raise ValueError("Scenario 'type' must not be empty.")
return self
def create_orchestrator(
self,
objective_target: Any,
adversarial_chat: Optional[Any] = None,
prompt_converters: Optional[List[Any]] = None,
scoring_target: Optional[Any] = None,
objective_scorer: Optional[Any] = None,
) -> Any:
"""
Load and instantiate the orchestrator class,
injecting top-level objects (targets, scorers) as needed.
"""
        # Loading the orchestrator class by name, e.g. 'RedTeamingOrchestrator'
orchestrator_class = load_class(
module_name="pyrit.orchestrator", class_name=self.scenario_type, error_context="scenario"
)
# Converting scenario fields into a dict for the orchestrator constructor
scenario_args = self.model_dump(exclude={"scenario_type"})
scenario_args = deepcopy(scenario_args)
# Inspecting the orchestrator constructor so we can inject the optional arguments if they exist
constructor_arg_names = [
param.name for param in inspect.signature(orchestrator_class.__init__).parameters.values()
]
# Building a map of complex top-level objects that belong outside the scenario
complex_args = {
"objective_target": objective_target,
"adversarial_chat": adversarial_chat,
"prompt_converters": prompt_converters,
"scoring_target": scoring_target,
"objective_scorer": objective_scorer,
}
# Disallowing scenario-level overrides for these complex args
for key in complex_args:
if key in scenario_args:
raise ValueError(f"{key} must be configured at the top-level of the config, not inside a scenario.")
# If the orchestrator constructor expects any of these, inject them
for key, value in complex_args.items():
if key in constructor_arg_names and value is not None:
scenario_args[key] = value
        # Finally, instantiate the orchestrator
try:
return orchestrator_class(**scenario_args)
except Exception as ex:
raise ValueError(f"Failed to instantiate scenario '{self.scenario_type}': {ex}") from ex
class TargetConfig(BaseModel):
"""
Configuration for a prompt target (e.g. OpenAIChatTarget).
"""
class_name: str = Field(..., alias="type", description="Prompt target class name (e.g. 'OpenAIChatTarget').")
def create_instance(self) -> Any:
"""
Dynamically instantiate the underlying target class.
"""
target_class = load_class(
module_name="pyrit.prompt_target", class_name=self.class_name, error_context="TargetConfig"
)
init_kwargs = self.model_dump(exclude={"class_name"})
return target_class(**init_kwargs)
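# Illustrative YAML for a target section handled by TargetConfig above (the same shape is
# used for 'objective_target' and 'adversarial_chat'; the class name mirrors the example
# in the field description):
#
#   objective_target:
#     type: OpenAIChatTarget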
class ObjectiveScorerConfig(BaseModel):
"""
Configuration for an objective scorer
"""
type: str = Field(..., description="Scorer class (e.g. 'SelfAskRefusalScorer').")
def create_scorer(self, scoring_target_obj: Optional[Any]) -> Any:
"""
Load and instantiate the scorer class.
"""
scorer_class = load_class(module_name="pyrit.score", class_name=self.type, error_context="objective_scorer")
init_kwargs = self.model_dump(exclude={"type"})
signature = inspect.signature(scorer_class.__init__)
chat_target_key: str = "chat_target"
if chat_target_key in signature.parameters:
if scoring_target_obj is None:
                raise KeyError(
                    "Scorer requires a scoring_target to be defined. "
                    "Alternatively, the adversarial_chat target can be used for scoring purposes, "
                    "but neither was provided."
                )
init_kwargs[chat_target_key] = scoring_target_obj
return scorer_class(**init_kwargs)
class ScoringConfig(BaseModel):
"""
Configuration for the scoring setup, including optional
override of the default adversarial chat with a 'scoring_target'
and/or an 'objective_scorer'.
"""
scoring_target: Optional[TargetConfig] = Field(
None, description="If provided, use this target for scoring instead of 'adversarial_chat'."
)
objective_scorer: Optional[ObjectiveScorerConfig] = Field(
None, description="Details for the objective scorer, if any."
)
def create_objective_scorer(self, scoring_target_obj: Optional[Any]) -> Optional[Any]:
        # If the user did not provide an 'objective_scorer' config block (i.e. the YAML lacks
        # that section), there is no scorer to instantiate, so return None.
if not self.objective_scorer:
return None
return self.objective_scorer.create_scorer(scoring_target_obj=scoring_target_obj)
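# Illustrative YAML for the 'scoring' section handled by ScoringConfig above. 'scoring_target'
# is optional; when omitted, ScannerConfig.fill_scoring_target falls back to the top-level
# 'adversarial_chat' target. Class names mirror the examples used elsewhere in this file:
#
#   scoring:
#     scoring_target:
#       type: OpenAIChatTarget
#     objective_scorer:
#       type: SelfAskRefusalScorer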
class ConverterConfig(BaseModel):
"""
Configuration for a single prompt converter, e.g. type: "Base64Converter"
"""
class_name: str = Field(..., alias="type", description="The prompt converter class name (e.g. 'Base64Converter').")
def create_instance(self) -> Any:
"""
Dynamically load and instantiate the converter class
"""
converter_class = load_class(
module_name="pyrit.prompt_converter", class_name=self.class_name, error_context="prompt_converter"
)
init_kwargs = self.model_dump(exclude={"class_name"})
return converter_class(**init_kwargs)
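# Illustrative YAML for the 'converters' list handled by ConverterConfig above (class name
# taken from the example in the docstring):
#
#   converters:
#     - type: Base64Converter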
class ExecutionSettings(BaseModel):
"""
Configuration for how the scanner is executed (e.g. locally or via AzureML).
"""
type: SupportedExecutionTypes = Field(
"local", description=f"Execution environment. Supported values: {list(get_args(SupportedExecutionTypes))}"
)
parallel_nodes: Optional[int] = Field(None, description="How many scenarios to run in parallel.")
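# Illustrative YAML for the optional 'execution_settings' section handled by ExecutionSettings
# above (it defaults to local execution; the parallel_nodes value is a placeholder):
#
#   execution_settings:
#     type: local
#     parallel_nodes: 4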
class ScannerConfig(BaseModel):
"""
Top-level configuration for the entire scanner.
"""
datasets: List[str] = Field(..., description="List of dataset YAML paths to load seed prompts from.")
scenarios: List[ScenarioConfig] = Field(..., description="List of scenario orchestrators to execute.")
objective_target: TargetConfig = Field(..., description="Configuration of the main (objective) chat target.")
adversarial_chat: Optional[TargetConfig] = Field(
None, description="Configuration of the adversarial chat target (if any)."
)
scoring: Optional[ScoringConfig] = Field(None, description="Scoring configuration (if any).")
converters: Optional[List[ConverterConfig]] = Field(None, description="List of prompt converters to apply.")
execution_settings: ExecutionSettings = Field(
default_factory=lambda: ExecutionSettings.model_validate({}),
description="Settings for how the scan is executed.",
)
database: DatabaseConfig = Field(
...,
description="Database configuration for storing memory or results, including memory_labels.",
)
@field_validator("objective_target", mode="before")
def check_objective_target_is_dict(cls, value):
"""
Ensure the user actually provides a dict.
Pydantic will run this validator before it attempts to parse the value into the TargetConfig model
"""
if not isinstance(value, dict):
raise ValueError(
"Field 'objective_target' must be a dictionary.\n"
"Example:\n"
" objective_target:\n"
" type: OpenAIChatTarget"
)
return value
@model_validator(mode="after")
def fill_scoring_target(self) -> "ScannerConfig":
"""
If config.scoring exists but doesn't explicitly define a scoring_target,
default it to the adversarial_chat
"""
if self.scoring:
if self.scoring.scoring_target is None and self.adversarial_chat is not None:
self.scoring.scoring_target = self.adversarial_chat
return self
@classmethod
def from_yaml(cls, path: str) -> "ScannerConfig":
"""
Loads configuration from a YAML file and validates it using Pydantic.
"""
with open(path, "r", encoding="utf-8") as f:
raw_dict = yaml.safe_load(f)
return cls(**raw_dict)
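    # Minimal illustrative config file for from_yaml (the dataset path and database type are
    # placeholders; only the required top-level sections are shown):
    #
    #   datasets:
    #     - path/to/seed_prompts.yaml
    #   scenarios:
    #     - type: PromptSendingOrchestrator
    #   objective_target:
    #     type: OpenAIChatTarget
    #   database:
    #     type: InMemory
    #
    # Loaded with: config = ScannerConfig.from_yaml("scanner_config.yaml")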
def create_objective_scorer(self) -> Optional[Any]:
"""
if there's an objective scorer configured,
instantiate it using 'scoring_target' (which might be adversarial_chat).
"""
if not self.scoring:
return None
scoring_target = None
if self.scoring.scoring_target:
scoring_target = self.scoring.scoring_target.create_instance()
return self.scoring.create_objective_scorer(scoring_target_obj=scoring_target)
def create_prompt_converters(self) -> List[PromptConverter]:
"""
Instantiates each converter defined in 'self.converters' (if any).
Returns a list of converter objects.
"""
if not self.converters:
return []
instances = []
for converter_cfg in self.converters:
instances.append(converter_cfg.create_instance())
return instances
def create_orchestrators(self, prompt_converters: Optional[List[PromptConverter]] = None) -> List[Any]:
"""
Helper method to instantiate all orchestrators from the scenario configs,
injecting objective_target, adversarial_chat, scoring_target, objective_scorer, etc.
"""
# Instantiate the top-level targets
objective_target_obj = self.objective_target.create_instance()
adversarial_chat_obj = None
if self.adversarial_chat:
adversarial_chat_obj = self.adversarial_chat.create_instance()
# If there is a scoring_target or an objective_scorer:
scoring_target_obj = None
objective_scorer_obj = None
if self.scoring:
# fill_scoring_target might have already assigned it to self.scoring.scoring_target
if self.scoring.scoring_target:
scoring_target_obj = self.scoring.scoring_target.create_instance()
# create the actual scorer
objective_scorer_obj = self.scoring.create_objective_scorer(scoring_target_obj=scoring_target_obj)
# Now each scenario can create its orchestrator
orchestrators = []
for scenario in self.scenarios:
orch = scenario.create_orchestrator(
objective_target=objective_target_obj,
adversarial_chat=adversarial_chat_obj,
prompt_converters=prompt_converters,
scoring_target=scoring_target_obj,
objective_scorer=objective_scorer_obj,
)
orchestrators.append(orch)
return orchestrators
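# A sketch of how this module is typically driven (the file name is a placeholder; the
# surrounding CLI wiring is omitted):
#
#   config = ScannerConfig.from_yaml("scanner_config.yaml")
#   converters = config.create_prompt_converters()
#   orchestrators = config.create_orchestrators(prompt_converters=converters)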