in pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py [0:0]
def test(self, workers, prompts, include_loss=False):
for j, worker in enumerate(workers):
worker(prompts[j], "test", worker.model)
model_tests = np.array([worker.results.get() for worker in workers])
model_tests_jb = model_tests[..., 0].tolist()
model_tests_mb = model_tests[..., 1].tolist()
model_tests_loss = []
if include_loss:
for j, worker in enumerate(workers):
worker(prompts[j], "test_loss", worker.model)
model_tests_loss = [worker.results.get() for worker in workers]
return model_tests_jb, model_tests_mb, model_tests_loss