in online_attacks/utils/optimizer/sls/sls.py [0:0]
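The excerpt omits the module-level imports the method relies on; a minimal sketch is given here, assuming the standard-library modules used in the body and that the line-search helpers live in a sibling `utils` module imported as `ut` (the exact import path is a guess from the file location):

import copy
import time

import torch

from . import utils as ut  # assumed path; provides random_seed_torch, get_grad_list, reset_step, etc.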
def step(self, closure):
    # deterministic closure
    seed = time.time()

    def closure_deterministic():
        with ut.random_seed_torch(int(seed)):
            return closure()
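
    # Note: re-using the same seed on every call means each re-evaluation of the loss
    # during the line search sees the same randomness (e.g. dropout, sampling), so the
    # sufficient-decrease comparisons below are made on the same function.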
    batch_step_size = self.state["step_size"]

    # get loss and compute gradients
    loss = closure_deterministic()
    loss.backward()

    # increment # forward-backward calls
    self.state["n_forwards"] += 1
    self.state["n_backwards"] += 1
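
    # `loss` is the reference value at the current iterate; each trial step below is
    # accepted or rejected by comparing `loss_next` against it.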
    # loop over parameter groups
    for group in self.param_groups:
        params = group["params"]

        # save the current parameters:
        params_current = copy.deepcopy(params)
        grad_current = ut.get_grad_list(params)
        grad_norm = ut.compute_grad_norm(grad_current)

        step_size = ut.reset_step(
            step_size=batch_step_size,
            n_batches_per_epoch=group["n_batches_per_epoch"],
            gamma=group["gamma"],
            reset_option=group["reset_option"],
            init_step_size=group["init_step_size"],
        )
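
        # reset_step proposes the trial step size for this batch from the previous one
        # (whether it is kept, grown, or re-initialized depends on reset_option); the
        # loop below then adjusts it until the chosen line-search condition holds.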
        # only do the check if the gradient norm is big enough
        with torch.no_grad():
            if grad_norm >= 1e-8:
                # check if condition is satisfied
                found = 0
                step_size_old = step_size

                for e in range(100):
                    # try a prospective step
                    ut.try_sgd_update(
                        params, step_size, params_current, grad_current
                    )

                    # compute the loss at the next step; no need to compute gradients.
                    loss_next = closure_deterministic()
                    self.state["n_forwards"] += 1
                    # =================================================
                    # Line search
                    if group["line_search_fn"] == "armijo":
                        armijo_results = ut.check_armijo_conditions(
                            step_size=step_size,
                            step_size_old=step_size_old,
                            loss=loss,
                            grad_norm=grad_norm,
                            loss_next=loss_next,
                            c=group["c"],
                            beta_b=group["beta_b"],
                        )
                        found, step_size, step_size_old = armijo_results
                        if found == 1:
                            break
                    elif group["line_search_fn"] == "goldstein":
                        goldstein_results = ut.check_goldstein_conditions(
                            step_size=step_size,
                            loss=loss,
                            grad_norm=grad_norm,
                            loss_next=loss_next,
                            c=group["c"],
                            beta_b=group["beta_b"],
                            beta_f=group["beta_f"],
                            bound_step_size=group["bound_step_size"],
                            eta_max=group["eta_max"],
                        )
                        found = goldstein_results["found"]
                        step_size = goldstein_results["step_size"]
                        if found == 3:
                            break
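
                # `found` is the success flag returned by the checks above (1 for
                # Armijo, 3 for Goldstein); 0 means no trial step size was accepted.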
                # if line search exceeds max_epochs
                if found == 0:
                    ut.try_sgd_update(params, 1e-6, params_current, grad_current)

        # save the new step-size
        self.state["step_size"] = step_size
        self.state["step"] += 1

    return loss
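
For orientation, a minimal usage sketch follows; the class name `Sls`, its import path, and the constructor keywords are assumptions (inferred from the file location and the `group[...]` keys read above). The contract visible in `step()` is that the closure returns the loss without calling `backward()`, since `step()` runs the backward pass and the line search itself.

import torch

from online_attacks.utils.optimizer.sls.sls import Sls  # assumed class name / location

model = torch.nn.Linear(10, 2)
X, Y = torch.randn(64, 10), torch.randint(0, 2, (64,))

# Constructor keywords mirror the per-group settings read in step(); they are assumptions.
opt = Sls(model.parameters(), n_batches_per_epoch=100, init_step_size=1.0)

for _ in range(10):
    opt.zero_grad()
    # The closure returns the loss only; step() calls backward() and re-evaluates
    # the closure during the backtracking line search.
    closure = lambda: torch.nn.functional.cross_entropy(model(X), Y)
    loss = opt.step(closure)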