def step()

in online_attacks/utils/optimizer/extragradient.py [0:0]


    def step(self, closure=None):
        """Perform one half of an extragradient update and toggle the phase.

        Successive calls alternate between two phases, tracked by
        ``self.extrapolation_flag``:

        * Extrapolation (flag falsy): snapshot the values of every parameter
          of each param group into ``group["params_copy"]``, then let the
          wrapped ``self.optimizer`` take a step from the current point.
        * Update (flag truthy): restore every parameter to the snapshot taken
          during extrapolation, then let the wrapped optimizer take a step
          from that restored point using whatever gradients are currently
          stored on the parameters (presumably computed by the caller at the
          extrapolated point -- the usual extragradient contract; confirm
          against callers).

        Args:
            closure: Optional callable returning the loss. It is only used to
                produce the return value; the wrapped optimizer's ``step`` is
                called without it.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        if not self.extrapolation_flag:
            for group in self.param_groups:
                # Snapshot values only: detach().clone() keeps the copy out of
                # the autograd graph and, unlike deepcopy, does not also
                # duplicate each parameter's .grad buffer.
                group["params_copy"] = [
                    p.detach().clone() for p in group["params"]
                ]
            self.optimizer.step()
            self.extrapolation_flag = True
        else:
            for group in self.param_groups:
                for p, p_copy in zip(group["params"], group["params_copy"]):
                    # Copy in place rather than rebinding ``p.data``: rebinding
                    # would make ``p`` share storage with the snapshot, so the
                    # optimizer step below would silently overwrite it.
                    p.data.copy_(p_copy)
            self.optimizer.step()
            self.extrapolation_flag = False

        return loss