# pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py
def get_workers(params, eval=False):
    """Instantiate tokenizers, conversation templates, and model workers.

    For each configured model path, loads the matching tokenizer (applying
    per-model tokenizer quirks), builds the matching conversation template,
    and wraps everything in a ``ModelWorker``. Unless ``eval`` is True, each
    worker's background process is started immediately.

    Args:
        params: Config object exposing parallel lists ``tokenizer_paths``,
            ``tokenizer_kwargs``, ``conversation_templates``, ``model_paths``,
            ``model_kwargs``, ``devices``, plus ``token`` and optionally
            ``num_train_models``. The lists are assumed to be the same
            length (one entry per model) — TODO confirm upstream validation.
        eval: If True, workers are created but not started. (Name shadows the
            builtin; kept for backward compatibility with existing callers.)

    Returns:
        Tuple ``(train_workers, test_workers)`` split at
        ``params.num_train_models`` (defaults to all workers being train).

    Raises:
        ValueError: If a conversation template name is not recognized.
    """
    tokenizers = []
    for path, kwargs in zip(params.tokenizer_paths, params.tokenizer_kwargs):
        tokenizer = AutoTokenizer.from_pretrained(
            path, token=params.token, trust_remote_code=False, **kwargs
        )
        # Per-model tokenizer fixups: some checkpoints ship with missing or
        # wrong special-token ids / padding sides.
        if "oasst-sft-6-llama-30b" in path:
            tokenizer.bos_token_id = 1
            tokenizer.unk_token_id = 0
        if "guanaco" in path:
            tokenizer.eos_token_id = 2
            tokenizer.unk_token_id = 0
        if "llama-2" in path:
            tokenizer.pad_token = tokenizer.unk_token
            tokenizer.padding_side = "left"
        if "falcon" in path:
            tokenizer.padding_side = "left"
        if "Phi-3-mini-4k-instruct" in path:
            tokenizer.bos_token_id = 1
            tokenizer.eos_token_id = 32000
            tokenizer.unk_token_id = 0
            tokenizer.padding_side = "left"
        # Fallback so batched generation never fails on a missing pad token.
        if not tokenizer.pad_token:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizers.append(tokenizer)
    print(f"Loaded {len(tokenizers)} tokenizers")

    raw_conv_templates = []
    for template in params.conversation_templates:
        if template in ["llama-2", "mistral", "llama-3-8b", "vicuna"]:
            # Fix: the original statement ended in a stray trailing comma,
            # building and discarding a 1-tuple around the append call.
            raw_conv_templates.append(get_conversation_template(template))
        elif template in ["phi-3-mini"]:
            # fastschat has no builtin phi-3 template, so construct one here.
            conv_template = Conversation(
                name="phi-3-mini",
                system_template="<|system|>\n{system_message}",
                system_message="",
                roles=("<|user|>", "<|assistant|>"),
                sep_style=SeparatorStyle.CHATML,
                sep="<|end|>",
                stop_token_ids=[32000, 32001, 32007],
            )
            raw_conv_templates.append(conv_template)
        else:
            raise ValueError("Conversation template not recognized")

    conv_templates = []
    for conv in raw_conv_templates:
        # Template-specific normalization before use.
        if conv.name == "zero_shot":
            conv.roles = tuple("### " + r for r in conv.roles)
            conv.sep = "\n"
        elif conv.name == "llama-2":
            conv.sep2 = conv.sep2.strip()
        conv_templates.append(conv)
    print(f"Loaded {len(conv_templates)} conversation templates")

    # One worker per model; the config lists are assumed parallel, so zip
    # pairs each model with its tokenizer, template, and device.
    workers = [
        ModelWorker(model_path, params.token, model_kwargs, tokenizer, conv, device)
        for model_path, model_kwargs, tokenizer, conv, device in zip(
            params.model_paths,
            params.model_kwargs,
            tokenizers,
            conv_templates,
            params.devices,
        )
    ]
    if not eval:
        for worker in workers:
            worker.start()

    num_train_models = getattr(params, "num_train_models", len(workers))
    print(f"Loaded {num_train_models} train models")
    print(f"Loaded {len(workers) - num_train_models} test models")
    # First slice trains the attack; the remainder is held out for testing.
    return workers[:num_train_models], workers[num_train_models:]