in pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py [0:0]
def get_filtered_cands(self, worker_index, control_cand, filter_cand=True, curr_control=None):
    """Decode candidate control token sequences into strings, optionally filtering them.

    Args:
        worker_index: Index into ``self.workers``; that worker's tokenizer is
            used for decoding and re-tokenizing.
        control_cand: Array-like of token ids holding one candidate sequence
            per row. Ids outside the tokenizer's base vocabulary are
            overwritten IN PLACE before decoding, so the caller's object is
            modified.
        filter_cand: When true, keep only candidates that differ from
            ``curr_control`` and that re-tokenize to the same number of tokens
            as their row, then pad the result back to one string per row.
            When false, every decoded row is returned unfiltered.
        curr_control: Current control string; candidates equal to it are
            dropped when filtering. Also used as the padding value when no
            candidate survives the filter.

    Returns:
        A list of decoded candidate strings. With ``filter_cand`` enabled the
        list has one entry per row of ``control_cand``; rejected rows are
        replaced by the last accepted candidate.

    Raises:
        ValueError: If filtering rejects every candidate and ``curr_control``
            is ``None``, so there is nothing to pad with.
    """
    tokenizer = self.workers[worker_index].tokenizer

    print("Masking out of range token_id.")
    # NOTE(review): the tokenizer is called with its default special-token
    # handling, so input_ids[0] may be a BOS id rather than "!" for tokenizers
    # that prepend special tokens -- confirm against the models in use.
    replacement_id = tokenizer("!").input_ids[0]
    # Base-vocabulary ids run from 0 to vocab_size - 1, so an id equal to
    # vocab_size is already out of range and must be masked as well.
    control_cand[control_cand >= tokenizer.vocab_size] = replacement_id

    cands = []
    for row in control_cand:
        decoded_str = tokenizer.decode(row, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        if not filter_cand:
            cands.append(decoded_str)
            continue
        if decoded_str == curr_control:
            continue
        # Keep only candidates whose decoded text round-trips to the same
        # number of tokens as the row it was decoded from.
        if len(tokenizer(decoded_str, add_special_tokens=False).input_ids) == len(row):
            cands.append(decoded_str)

    if filter_cand:
        missing = len(control_cand) - len(cands)
        if missing:
            if cands:
                filler = cands[-1]
            elif curr_control is not None:
                # Every candidate was rejected: fall back to the current
                # control instead of failing on an empty list.
                filler = curr_control
            else:
                raise ValueError(
                    "No valid control candidates remain after filtering and no curr_control was given to pad with."
                )
            cands.extend([filler] * missing)
    return cands