def step()

in online_attacks/utils/optimizer/sls/sls.py

Performs a single SLS (stochastic line search) optimization step: the closure is evaluated once to obtain the loss and its gradients, and the step size is then backtracked until the configured line-search condition (Armijo or Goldstein) is satisfied. `ut` refers to the line-search utilities module that ships alongside this file; `copy`, `time`, and `torch` come from the file's imports.

    def step(self, closure):
        # draw one seed per step; re-seeding before every closure call makes
        # the closure deterministic, so each trial step size is evaluated on
        # the same minibatch randomness (dropout masks, augmentation, ...)
        seed = time.time()

        def closure_deterministic():
            with ut.random_seed_torch(int(seed)):
                return closure()

        # step size carried over from the previous batch
        batch_step_size = self.state["step_size"]

        # get loss and compute gradients
        loss = closure_deterministic()
        loss.backward()

        # increment the forward/backward call counters
        self.state["n_forwards"] += 1
        self.state["n_backwards"] += 1

        # loop over parameter groups
        for group in self.param_groups:
            params = group["params"]

            # snapshot the current parameters and gradients so each trial
            # step starts from the same point:
            params_current = copy.deepcopy(params)
            grad_current = ut.get_grad_list(params)

            # gradient norm used by the line-search conditions below
            grad_norm = ut.compute_grad_norm(grad_current)

            # warm-start the line search: depending on reset_option, keep,
            # reset, or slightly grow the step size accepted on the last batch
            step_size = ut.reset_step(
                step_size=batch_step_size,
                n_batches_per_epoch=group["n_batches_per_epoch"],
                gamma=group["gamma"],
                reset_option=group["reset_option"],
                init_step_size=group["init_step_size"],
            )

            # run the line search only if the gradient norm is non-negligible
            with torch.no_grad():
                if grad_norm >= 1e-8:
                    # check if condition is satisfied
                    found = 0
                    step_size_old = step_size

                    # backtrack for at most 100 trial step sizes
                    for _ in range(100):
                        # try a prospective step
                        ut.try_sgd_update(
                            params, step_size, params_current, grad_current
                        )

                        # compute the loss at the next step; no need to compute gradients.
                        loss_next = closure_deterministic()
                        self.state["n_forwards"] += 1

                        # =================================================
                        # Line search
                        if group["line_search_fn"] == "armijo":
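                            # sufficient-decrease (Armijo) condition:
                            #   loss_next <= loss - c * step_size * grad_norm**2
                            # on failure the helper backtracks:
                            #   step_size <- beta_b * step_size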
                            armijo_results = ut.check_armijo_conditions(
                                step_size=step_size,
                                step_size_old=step_size_old,
                                loss=loss,
                                grad_norm=grad_norm,
                                loss_next=loss_next,
                                c=group["c"],
                                beta_b=group["beta_b"],
                            )
                            found, step_size, step_size_old = armijo_results
                            if found == 1:
                                break

                        elif group["line_search_fn"] == "goldstein":
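                            # Goldstein conditions: the Armijo-style sufficient
                            # decrease plus a lower bound that rejects overly
                            # small steps; beta_f grows the step size when the
                            # decrease is too conservative, beta_b shrinks it,
                            # and found == 3 means both inequalities hold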
                            goldstein_results = ut.check_goldstein_conditions(
                                step_size=step_size,
                                loss=loss,
                                grad_norm=grad_norm,
                                loss_next=loss_next,
                                c=group["c"],
                                beta_b=group["beta_b"],
                                beta_f=group["beta_f"],
                                bound_step_size=group["bound_step_size"],
                                eta_max=group["eta_max"],
                            )

                            found = goldstein_results["found"]
                            step_size = goldstein_results["step_size"]

                            if found == 3:
                                break

                    # if the line search did not find an acceptable step size
                    # within the 100-iteration budget, fall back to a tiny
                    # SGD step
                    if found == 0:
                        ut.try_sgd_update(params, 1e-6, params_current, grad_current)

            # save the accepted step size; it warm-starts the next call
            self.state["step_size"] = step_size
            self.state["step"] += 1

        return loss
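
A minimal usage sketch. The constructor arguments are assumptions inferred from the `group[...]` keys read above (`n_batches_per_epoch`, `init_step_size`, `c`, `beta_b`, `line_search_fn`, ...), and `Sls` is the assumed class name. The closure only recomputes the loss on the current batch, which is what lets `step()` re-evaluate it at several trial step sizes; `step()` itself calls `backward()`:

    import torch
    import torch.nn.functional as F
    from online_attacks.utils.optimizer.sls.sls import Sls  # assumed class name

    model = torch.nn.Linear(10, 1)
    loader = [(torch.randn(32, 10), torch.randn(32, 1)) for _ in range(5)]
    opt = Sls(
        model.parameters(),
        n_batches_per_epoch=len(loader),
        init_step_size=1.0,
        c=0.1,
        beta_b=0.9,
        line_search_fn="armijo",
    )

    for x, y in loader:
        opt.zero_grad()

        # the closure returns the loss without calling backward();
        # step() evaluates it repeatedly while backtracking the step size
        def closure():
            return F.mse_loss(model(x), y)

        loss = opt.step(closure)

For reference, a hypothetical stand-in for `ut.check_armijo_conditions`, matching how its 3-tuple result is unpacked above and assuming the standard SLS sufficient-decrease rule; the repo's actual helper may differ in details:

    def check_armijo_sketch(step_size, step_size_old, loss, grad_norm,
                            loss_next, c, beta_b):
        # accept if the loss decreased by at least c * step_size * ||grad||^2
        if loss_next <= loss - c * step_size * grad_norm ** 2:
            return 1, step_size, step_size  # found == 1: accept
        # otherwise backtrack and let the caller try again
        return 0, step_size * beta_b, step_size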