def get_filtered_cands()

in pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py [0:0]


    def get_filtered_cands(self, worker_index, control_cand, filter_cand=True, curr_control=None):
        """Decode candidate control token sequences into strings, optionally filtering them.

        Each row of ``control_cand`` is decoded with the tokenizer of
        ``self.workers[worker_index]``. When ``filter_cand`` is true, a decoded
        string is kept only if it differs from ``curr_control`` and re-tokenizes
        (without special tokens) to the same number of tokens as the row it came
        from. Rejected candidates are replaced by padding so that the returned
        list has one entry per row of ``control_cand``.

        Args:
            worker_index: Index into ``self.workers`` selecting the tokenizer.
            control_cand: Batch of candidate token-id sequences. Must support
                ``.shape``, boolean-mask assignment, ``len`` and row iteration
                (presumably a 2-D integer tensor of shape
                (num_candidates, num_tokens) -- TODO confirm against caller).
                NOTE: out-of-range ids are overwritten in place.
            filter_cand: If true, drop invalid candidates and pad the result;
                if false, return every decoded string unchanged.
            curr_control: The current control string; candidates equal to it
                are rejected when filtering. Also used as the padding value
                when no candidate survives the filter.

        Returns:
            A list of decoded candidate strings. With ``filter_cand=True`` its
            length equals ``len(control_cand)``.

        Raises:
            ValueError: If filtering rejects every candidate and
                ``curr_control`` is ``None``, so nothing is available to pad with.
        """
        tokenizer = self.workers[worker_index].tokenizer

        print("Masking out of range token_id.")
        # Valid base-vocabulary ids are 0 .. vocab_size - 1, so the id equal to
        # vocab_size must be masked as well (hence ``>=`` rather than ``>``).
        # NOTE(review): for tokenizers that prepend a BOS token, input_ids[0] is
        # that special token rather than "!" -- confirm this is intended.
        replacement_id = tokenizer("!").input_ids[0]
        control_cand[control_cand >= tokenizer.vocab_size] = replacement_id

        cands = []
        for token_ids in control_cand:
            decoded_str = tokenizer.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            if filter_cand:
                if decoded_str == curr_control:
                    continue
                # A candidate is only usable if decoding and re-encoding it
                # yields the same number of tokens as the original row.
                retokenized = tokenizer(decoded_str, add_special_tokens=False).input_ids
                if len(retokenized) != len(token_ids):
                    continue
            cands.append(decoded_str)

        if not filter_cand:
            return cands

        num_missing = len(control_cand) - len(cands)
        if num_missing == 0:
            # Nothing was rejected (this also covers an empty batch).
            return cands

        if cands:
            # Pad with the last valid candidate to keep the batch size constant.
            filler = cands[-1]
        elif curr_control is not None:
            # Every candidate was rejected: fall back to the current control
            # instead of failing on ``cands[-1]`` with an opaque IndexError.
            print("Warning: no valid control candidates; falling back to curr_control.")
            filler = curr_control
        else:
            raise ValueError(
                "No valid control candidates remain after filtering and no curr_control was given to fall back on."
            )
        return cands + [filler] * num_missing