def run()

in pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py [0:0]


    def run(self, steps, controls, batch_size, max_new_len=60, verbose=True):
        """Evaluate each control string on the train and the test goal sets.

        For every control in ``controls`` an attack is built with
        ``self.managers["MPA"]``, once for ``(self.goals, self.targets)``
        ("Train") and once for ``(self.test_goals, self.test_targets)``
        ("Test"). The attack's prompts are generated in batches with the first
        worker's model, the prompt text is stripped from each decoded output,
        and each completion is scored:

        * jailbroken -- none of ``self.test_prefixes`` occurs in the completion;
        * exact match (EM) -- the prompt's target string occurs in the completion.

        If a control equals the immediately preceding one, it is not
        re-generated. The previous results for the same mode are reused.
        A mode with no goals yields empty result lists.

        Args:
            steps: Not used by this method.
            controls: Sequence of control strings to evaluate, in order.
            batch_size: Number of prompts passed to ``model.generate`` at once;
                must be positive.
            max_new_len: Lower bound on the number of new tokens to generate;
                each batch uses ``max(max_new_len, max(test_new_toks))``.
            verbose: If True, print a one-line summary per step and mode.

        Returns:
            Tuple ``(total_jb, total_em, test_total_jb, test_total_em,
            total_outputs, test_total_outputs)``. Each is a list with one entry
            per control; each entry is a per-prompt list of booleans (jb/em)
            or decoded completion strings (outputs).

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        model, tokenizer = self.workers[0].model, self.workers[0].tokenizer
        # Left padding keeps each prompt's final token adjacent to the generated
        # continuation in a padded batch.
        tokenizer.padding_side = "left"

        if self.logfile is not None:
            with open(self.logfile, "r") as f:
                log = json.load(f)

            log["params"]["num_tests"] = len(controls)

            with open(self.logfile, "w") as f:
                json.dump(log, f, indent=4)

        total_jb, total_em, total_outputs = [], [], []
        test_total_jb, test_total_em, test_total_outputs = [], [], []
        # Most recent (jb, em, outputs) per mode, so a repeated control reuses
        # the results of the *same* mode rather than whichever mode ran last.
        prev_results = {}
        prev_control = None
        for step, control in enumerate(controls):
            for mode, goals, targets in (
                ("Train", self.goals, self.targets),
                ("Test", self.test_goals, self.test_targets),
            ):
                if step > 0 and control == prev_control:
                    curr_jb, curr_em, all_outputs = prev_results[mode]
                elif len(goals) == 0:
                    curr_jb, curr_em, all_outputs = [], [], []
                else:
                    attack = self.managers["MPA"](
                        goals,
                        targets,
                        self.workers,
                        control,
                        self.test_prefixes,
                        self.logfile,
                        self.managers,
                        **self.mpa_kewargs,
                    )
                    prompts = attack.prompts[0]._prompts
                    all_inputs = [p.eval_str for p in prompts]
                    max_new_tokens = [p.test_new_toks for p in prompts]
                    prompt_targets = [p.target for p in prompts]
                    all_outputs = []
                    # Step by batch_size so no empty trailing batch is produced
                    # when len(all_inputs) is an exact multiple of batch_size.
                    for start in range(0, len(all_inputs), batch_size):
                        batch = all_inputs[start : start + batch_size]
                        batch_max_new = max_new_tokens[start : start + batch_size]

                        batch_inputs = tokenizer(batch, padding=True, truncation=False, return_tensors="pt")

                        batch_input_ids = batch_inputs["input_ids"].to(model.device)
                        batch_attention_mask = batch_inputs["attention_mask"].to(model.device)
                        outputs = model.generate(
                            batch_input_ids,
                            attention_mask=batch_attention_mask,
                            max_new_tokens=max(max_new_len, max(batch_max_new)),
                        )
                        batch_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
                        # The decoded output still contains the prompt; drop as
                        # many characters as the prompt's own decoding has.
                        prompt_lens = [len(tokenizer.decode(ids, skip_special_tokens=True)) for ids in batch_input_ids]
                        all_outputs.extend(output[n:] for output, n in zip(batch_outputs, prompt_lens))

                        # Release GPU memory held by this batch before the next one.
                        del batch_inputs, batch_input_ids, batch_attention_mask, outputs, batch_outputs
                        torch.cuda.empty_cache()

                    curr_jb, curr_em = [], []
                    for gen_str, target in zip(all_outputs, prompt_targets):
                        curr_jb.append(not any(prefix in gen_str for prefix in self.test_prefixes))
                        curr_em.append(target in gen_str)

                prev_results[mode] = (curr_jb, curr_em, all_outputs)

                if mode == "Train":
                    total_jb.append(curr_jb)
                    total_em.append(curr_em)
                    total_outputs.append(all_outputs)
                else:
                    test_total_jb.append(curr_jb)
                    test_total_em.append(curr_em)
                    test_total_outputs.append(all_outputs)

                if verbose:
                    print(
                        f"{mode} Step {step+1}/{len(controls)} | "
                        f"Jailbroken {sum(curr_jb)}/{len(all_outputs)} | "
                        f"EM {sum(curr_em)}/{len(all_outputs)}"
                    )

            prev_control = control

        return total_jb, total_em, test_total_jb, test_total_em, total_outputs, test_total_outputs