in pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py
# Requires the module-level `import gc` and `import torch`.
def logits(self, model, test_controls=None, return_ids=False):
    # Sentinel: stays -1 unless string candidates are tokenized and padded below.
    pad_tok = -1
    if test_controls is None:
        test_controls = self.control_toks
    if isinstance(test_controls, torch.Tensor):
        # Already tokenized: promote a single candidate to a batch of one.
        if len(test_controls.shape) == 1:
            test_controls = test_controls.unsqueeze(0)
        test_ids = test_controls.to(model.device)
    else:
        # Wrap a lone candidate so the string branch below always sees a list
        # and test_ids is defined on every path.
        if not isinstance(test_controls, list):
            test_controls = [test_controls]
        if isinstance(test_controls[0], str):
            # Tokenize each candidate control string, truncated to the width
            # of the control slice in the prompt.
            max_len = self._control_slice.stop - self._control_slice.start
            test_ids = [
                torch.tensor(
                    self.tokenizer(control, add_special_tokens=False).input_ids[:max_len],
                    device=model.device,
                )
                for control in test_controls
            ]
            # Choose the smallest token id absent from both the prompt and every
            # candidate, so padded positions are unambiguous in the attention mask.
            pad_tok = 0
            while pad_tok in self.input_ids or any(pad_tok in ids for ids in test_ids):
                pad_tok += 1
            # Right-pad the variable-length candidates into an (n, max_len) batch.
            nested_ids = torch.nested.nested_tensor(test_ids)
            test_ids = torch.nested.to_padded_tensor(nested_ids, pad_tok, (len(test_ids), max_len))
        else:
            raise ValueError(
                f"test_controls must be a list of strings or a tensor of token ids, got {type(test_controls)}"
            )
    if test_ids[0].shape[0] != self._control_slice.stop - self._control_slice.start:
        raise ValueError(
            f"test_controls must have shape "
            f"(n, {self._control_slice.stop - self._control_slice.start}), "
            f"got {test_ids.shape}"
        )

    # Scatter each candidate into its own copy of the prompt at the control
    # slice, producing one full input sequence per candidate.
    locs = (
        torch.arange(self._control_slice.start, self._control_slice.stop)
        .repeat(test_ids.shape[0], 1)
        .to(model.device)
    )
    ids = torch.scatter(
        self.input_ids.unsqueeze(0).repeat(test_ids.shape[0], 1).to(model.device), 1, locs, test_ids
    )
    # Mask out padded positions; pad_tok < 0 means no padding was introduced.
    if pad_tok >= 0:
        attn_mask = (ids != pad_tok).type(ids.dtype)
    else:
        attn_mask = None
    if return_ids:
        del locs, test_ids
        gc.collect()
        return model(input_ids=ids, attention_mask=attn_mask).logits, ids
    else:
        del locs, test_ids
        logits = model(input_ids=ids, attention_mask=attn_mask).logits
        del ids
        gc.collect()
        return logits
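
# --- Illustration ------------------------------------------------------------
# A minimal, self-contained sketch of the batching trick used by logits() above:
# each candidate control sequence is scattered into its own copy of the prompt
# at the control slice, so all candidates are evaluated in one forward pass.
# The names below (prompt_ids, control_slice, candidates) are hypothetical
# stand-ins for this sketch, not PyRIT APIs.
import torch

prompt_ids = torch.tensor([101, 7592, 0, 0, 0, 1012, 102])  # toy 7-token prompt
control_slice = slice(2, 5)                                  # control occupies positions 2..4
candidates = torch.tensor([[11, 12, 13],                     # two candidate controls,
                           [21, 22, 23]])                    # each 3 tokens wide

n = candidates.shape[0]
# One row of target positions per candidate, then a batched in-place splice.
locs = torch.arange(control_slice.start, control_slice.stop).repeat(n, 1)
batch = torch.scatter(prompt_ids.unsqueeze(0).repeat(n, 1), 1, locs, candidates)
print(batch)
# tensor([[ 101, 7592,   11,   12,   13, 1012,  102],
#         [ 101, 7592,   21,   22,   23, 1012,  102]])
#
# This splice only works when every candidate is exactly as wide as the control
# slice, which is why logits() rejects candidates of any other shape.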