# pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py

def _update_ids(self):
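    # Render the full conversation (user turn: goal + control suffix, assistant turn: target)
    # and tokenize it once; the branches below recover a token-index slice for each segment.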
    self.conv_template.append_message(self.conv_template.roles[0], f"{self.goal} {self.control}")
    self.conv_template.append_message(self.conv_template.roles[1], f"{self.target}")
    prompt = self.conv_template.get_prompt()
    encoding = self.tokenizer(prompt)
    toks = encoding.input_ids
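
    # For Llama chat templates, rebuild the conversation piece by piece and measure the
    # tokenized length after each addition to locate every segment's token range.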
    if self.conv_template.name in ("llama-2", "llama-3"):
        self.conv_template.messages = []

        self.conv_template.append_message(self.conv_template.roles[0], None)
        toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
        self._user_role_slice = slice(None, len(toks))

        self.conv_template.update_last_message(f"{self.goal}")
        toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
        self._goal_slice = slice(self._user_role_slice.stop, max(self._user_role_slice.stop, len(toks)))

        separator = " " if self.goal else ""
        self.conv_template.update_last_message(f"{self.goal}{separator}{self.control}")
        toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
        self._control_slice = slice(self._goal_slice.stop, len(toks))

        self.conv_template.append_message(self.conv_template.roles[1], None)
        toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
        self._assistant_role_slice = slice(self._control_slice.stop, len(toks))
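
        # The target span trims the trailing template tokens; the loss slice is the same
        # window shifted left by one so each logit is scored against the next target token.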
self.conv_template.update_last_message(f"{self.target}")
toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
self._target_slice = slice(self._assistant_role_slice.stop, len(toks) - 2)
self._loss_slice = slice(self._assistant_role_slice.stop - 1, len(toks) - 3)
    else:
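        # Probe whether this tokenizer supports character-to-token mapping; slow
        # (pure-Python) tokenizers do not, so fall back to the incremental rebuild below.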
        python_tokenizer = self.conv_template.name == "oasst_pythia"
        try:
            encoding.char_to_token(len(prompt) - 1)
        except Exception:
            python_tokenizer = True

        if python_tokenizer:
            # This is specific to the vicuna and pythia tokenizer and conversation prompt.
            # It will not work with other tokenizers or prompts.
            self.conv_template.messages = []

            self.conv_template.append_message(self.conv_template.roles[0], None)
            toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
            self._user_role_slice = slice(None, len(toks))
self.conv_template.update_last_message(f"{self.goal}")
toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
self._goal_slice = slice(self._user_role_slice.stop, max(self._user_role_slice.stop, len(toks) - 1))
separator = " " if self.goal else ""
self.conv_template.update_last_message(f"{self.goal}{separator}{self.control}")
toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
self._control_slice = slice(self._goal_slice.stop, len(toks) - 1)
self.conv_template.append_message(self.conv_template.roles[1], None)
toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
self._assistant_role_slice = slice(self._control_slice.stop, len(toks))
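
            # As in the llama branch: drop the trailing template token from the target span
            # and shift the loss slice left by one for next-token prediction.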
self.conv_template.update_last_message(f"{self.target}")
toks = self.tokenizer(self.conv_template.get_prompt()).input_ids
self._target_slice = slice(self._assistant_role_slice.stop, len(toks) - 1)
self._loss_slice = slice(self._assistant_role_slice.stop - 1, len(toks) - 2)
        else:
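            # Fast-tokenizer path: locate each segment by its character offsets in the
            # rendered prompt and convert those offsets to token indices with char_to_token.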
            self._system_slice = slice(None, encoding.char_to_token(len(self.conv_template.system)))
            self._user_role_slice = slice(
                encoding.char_to_token(prompt.find(self.conv_template.roles[0])),
                encoding.char_to_token(
                    prompt.find(self.conv_template.roles[0]) + len(self.conv_template.roles[0]) + 1
                ),
            )
            self._goal_slice = slice(
                encoding.char_to_token(prompt.find(self.goal)),
                encoding.char_to_token(prompt.find(self.goal) + len(self.goal)),
            )
            self._control_slice = slice(
                encoding.char_to_token(prompt.find(self.control)),
                encoding.char_to_token(prompt.find(self.control) + len(self.control)),
            )
            self._assistant_role_slice = slice(
                encoding.char_to_token(prompt.find(self.conv_template.roles[1])),
                encoding.char_to_token(
                    prompt.find(self.conv_template.roles[1]) + len(self.conv_template.roles[1]) + 1
                ),
            )
            self._target_slice = slice(
                encoding.char_to_token(prompt.find(self.target)),
                encoding.char_to_token(prompt.find(self.target) + len(self.target)),
            )
            self._loss_slice = slice(
                encoding.char_to_token(prompt.find(self.target)) - 1,
                encoding.char_to_token(prompt.find(self.target) + len(self.target)) - 1,
            )
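
    # Keep only the token ids up to the end of the target, then reset the template so the
    # next call starts from an empty conversation.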
    self.input_ids = torch.tensor(toks[: self._target_slice.stop], device="cpu")
    self.conv_template.messages = []