in pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py [0:0]
def run(model, tasks, results):
while True:
task = tasks.get()
if task is None:
break
ob, fn, args, kwargs = task
if fn == "grad":
with torch.enable_grad():
results.put(ob.grad(*args, **kwargs))
else:
with torch.no_grad():
if fn == "logits":
results.put(ob.logits(*args, **kwargs))
elif fn == "contrast_logits":
results.put(ob.contrast_logits(*args, **kwargs))
elif fn == "test":
results.put(ob.test(*args, **kwargs))
elif fn == "test_loss":
results.put(ob.test_loss(*args, **kwargs))
else:
results.put(fn(*args, **kwargs))
tasks.task_done()