in online_attacks/utils/optimizer/extragradient.py [0:0]
def step(self, closure=None):
    """Perform one half of an extragradient update.

    Successive calls alternate between two phases, tracked by
    ``self.extrapolation_flag``:

    * Extrapolation (flag falsy): snapshot the values of every parameter in
      ``self.param_groups`` into ``group["params_copy"]``. Then step the
      wrapped ``self.optimizer`` with the current gradients and set the
      flag to True.
    * Update (flag truthy): copy each parameter's snapshot back into it in
      place. Then step the wrapped optimizer with the gradients currently
      stored on the parameters, drop the snapshot and set the flag back to
      False. The gradients are presumably recomputed by the caller at the
      extrapolated point; this method does not compute them.

    NOTE(review): assumes ``self.optimizer`` updates the same parameter
    objects that appear in ``self.param_groups``; confirm against __init__.

    Args:
        closure: optional callable invoked once, before either phase.

    Returns:
        Whatever ``closure`` returned, or ``None`` if no closure was given.
    """
    loss = None
    if closure is not None:
        loss = closure()

    if not self.extrapolation_flag:
        for group in self.param_groups:
            # Detached clones hold only the values. deepcopy would also
            # duplicate each Parameter wrapper and its ``.grad`` tensor.
            group["params_copy"] = [p.detach().clone() for p in group["params"]]
        self.optimizer.step()
        self.extrapolation_flag = True
    else:
        for group in self.param_groups:
            # Copy in place so each parameter keeps its own storage instead
            # of aliasing the snapshot (as ``p.data = p_copy.data`` did).
            # Popping also frees the snapshot once it has been consumed.
            for p, p_copy in zip(group["params"], group.pop("params_copy")):
                p.data.copy_(p_copy)
        self.optimizer.step()
        self.extrapolation_flag = False
    return loss