# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
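"""Train entry point for the Greedy Coordinate Gradient (GCG) attack.

Builds a config from keyword arguments, spins up model workers, runs the
attack, and logs progress to MLflow.
"""
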
import time
from typing import Union
import mlflow
import numpy as np
import torch.multiprocessing as mp
from ml_collections import config_dict
import pyrit.auxiliary_attacks.gcg.attack.gcg.gcg_attack as attack_lib
from pyrit.auxiliary_attacks.gcg.attack.base.attack_manager import (
IndividualPromptAttack,
ProgressiveMultiPromptAttack,
get_goals_and_targets,
get_workers,
)
from pyrit.auxiliary_attacks.gcg.experiments.log import (
log_gpu_memory,
log_params,
log_train_goals,
)


class GreedyCoordinateGradientAdversarialSuffixGenerator:
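    """Wraps the GCG attack pipeline behind a single ``generate_suffix`` call."""
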
def __init__(self):
        # CUDA tensors cannot be shared across fork-started processes, so require
        # the "spawn" start method before any workers are created
        if mp.get_start_method(allow_none=True) != "spawn":
            mp.set_start_method("spawn")

    def generate_suffix(
self,
*,
token: str = "",
tokenizer_paths: list = [],
model_name: str = "",
model_paths: list = [],
conversation_templates: list = [],
result_prefix: str = "",
train_data: str = "",
control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !",
n_train_data: int = 50,
n_steps: int = 500,
test_steps: int = 50,
batch_size: int = 512,
transfer: bool = False,
target_weight: float = 1.0,
control_weight: float = 0.0,
progressive_goals: bool = False,
progressive_models: bool = False,
anneal: bool = False,
incr_control: bool = False,
stop_on_success: bool = False,
verbose: bool = True,
allow_non_ascii: bool = False,
num_train_models: int = 1,
devices: list = ["cuda:0"],
model_kwargs: list = [{"low_cpu_mem_usage": True, "use_cache": False}],
tokenizer_kwargs: list = [{"use_fast": False}],
n_test_data: int = 0,
test_data: str = "",
lr: float = 0.01,
topk: int = 256,
temp: int = 1,
filter_cand: bool = True,
gbda_deterministic: bool = True,
logfile: str = "",
random_seed: int = 42,
):
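        """Configure and run a GCG attack with the given settings.

        Attack results are written to ``{result_prefix}_{timestamp}.json`` and
        parameters and GPU stats are logged to the active MLflow run.
        """
        # Collect every setting into an ml_collections ConfigDict so it can be
        # logged and handed to the attack helpers as a single object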
params = config_dict.ConfigDict()
params.result_prefix = result_prefix
params.train_data = train_data
params.control_init = control_init
params.n_train_data = n_train_data
params.n_steps = n_steps
params.test_steps = test_steps
params.batch_size = batch_size
params.transfer = transfer
params.target_weight = target_weight
params.control_weight = control_weight
params.progressive_goals = progressive_goals
params.progressive_models = progressive_models
params.anneal = anneal
params.incr_control = incr_control
params.stop_on_success = stop_on_success
params.verbose = verbose
params.allow_non_ascii = allow_non_ascii
params.num_train_models = num_train_models
params.tokenizer_paths = tokenizer_paths
params.tokenizer_kwargs = tokenizer_kwargs
params.model_name = model_name
params.model_paths = model_paths
params.model_kwargs = model_kwargs
params.conversation_templates = conversation_templates
params.devices = devices
params.n_test_data = n_test_data
params.test_data = test_data
params.lr = lr
params.topk = topk
params.temp = temp
params.filter_cand = filter_cand
params.gbda_deterministic = gbda_deterministic
params.token = token
params.logfile = logfile
params.random_seed = random_seed
        print(params)

        # Start MLflow logging
        mlflow.start_run()
        log_gpu_memory(step=0)
        log_params(params=params)

        # Load the train/test goal and target strings described by the config
        train_goals, train_targets, test_goals, test_targets = get_goals_and_targets(params)
        log_train_goals(train_goals=train_goals)
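
        # Randomly rewrite each target with one of two normalizers, e.g.
        # "Sure, here is ..." becomes either "Here is ..." or "Sure, here's ...",
        # varying the exact target phrasing the suffix is optimized against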
        def process_fn(s):
            return s.replace("Sure, h", "H")

        def process_fn2(s):
            return s.replace("Sure, here is", "Sure, here's")

        train_targets = [process_fn(t) if np.random.random() < 0.5 else process_fn2(t) for t in train_targets]
        test_targets = [process_fn(t) if np.random.random() < 0.5 else process_fn2(t) for t in test_targets]
workers, test_workers = get_workers(params)
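
        # Concrete GCG classes the attack managers instantiate: one per attack
        # prompt ("AP"), prompt manager ("PM"), and multi-prompt attack ("MPA")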
managers = {
"AP": attack_lib.GCGAttackPrompt,
"PM": attack_lib.GCGPromptManager,
"MPA": attack_lib.GCGMultiPromptAttack,
}
timestamp = time.strftime("%Y%m%d-%H%M%S")
attack: Union[ProgressiveMultiPromptAttack, IndividualPromptAttack]
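        # Transfer mode optimizes one suffix across (optionally progressively
        # added) models and goals; otherwise each goal/target pair gets its own
        # independent attack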
if params.transfer:
attack = ProgressiveMultiPromptAttack(
train_goals,
train_targets,
workers,
progressive_models=params.progressive_models,
progressive_goals=params.progressive_goals,
control_init=params.control_init,
logfile=f"{params.result_prefix}_{timestamp}.json",
managers=managers,
test_goals=test_goals,
test_targets=test_targets,
test_workers=test_workers,
mpa_deterministic=params.gbda_deterministic,
mpa_lr=params.lr,
mpa_batch_size=params.batch_size,
mpa_n_steps=params.n_steps,
)
else:
attack = IndividualPromptAttack(
train_goals,
train_targets,
workers,
control_init=params.control_init,
logfile=f"{params.result_prefix}_{timestamp}.json",
managers=managers,
test_goals=getattr(params, "test_goals", []),
test_targets=getattr(params, "test_targets", []),
test_workers=test_workers,
mpa_deterministic=params.gbda_deterministic,
mpa_lr=params.lr,
mpa_batch_size=params.batch_size,
mpa_n_steps=params.n_steps,
)
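
        # Run the optimization loop; progress is evaluated every `test_steps` steps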
attack.run(
n_steps=params.n_steps,
batch_size=params.batch_size,
topk=params.topk,
temp=params.temp,
target_weight=params.target_weight,
control_weight=params.control_weight,
test_steps=getattr(params, "test_steps", 1),
anneal=params.anneal,
incr_control=params.incr_control,
stop_on_success=params.stop_on_success,
verbose=params.verbose,
filter_cand=params.filter_cand,
allow_non_ascii=params.allow_non_ascii,
)

        # Shut down the model workers and close the MLflow run started above
        for worker in workers + test_workers:
            worker.stop()
        mlflow.end_run()
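

# A minimal usage sketch (illustrative values only; the model path, conversation
# template name, and CSV of goals/targets are assumptions, not defaults shipped
# with PyRIT):
#
#     generator = GreedyCoordinateGradientAdversarialSuffixGenerator()
#     generator.generate_suffix(
#         model_name="vicuna",
#         model_paths=["lmsys/vicuna-7b-v1.5"],
#         tokenizer_paths=["lmsys/vicuna-7b-v1.5"],
#         conversation_templates=["vicuna"],
#         train_data="advbench/harmful_behaviors.csv",
#         result_prefix="results/individual_vicuna",
#         n_train_data=10,
#         n_steps=500,
#         batch_size=512,
#     )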