in automl21/scs_neural/solver/neural_scs_batched.py
def solve(self, multi_instance, train=True, max_iters=5000, use_scaling=True, scale=1,
          rho_x=1e-3, alpha=1.5, track_metrics=False, **kwargs):
    # lists that accumulate per-iteration iterates, losses, and diagnostics
    diffs_u, objectives, all_residuals, losses = [], [], [], []
    seq_tau, train_diff_u = [], []
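    # pre-factorize (Q + rho_x * I) and its transpose once per solve (LU
    # factors, per the `_lu` names); every fixed-point iteration reuses them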
    QplusI_lu, QplusI_t_lu = self._obtain_QplusI_matrices(multi_instance, rho_x)
    m, n, num_instances = multi_instance.get_sizes()
    total_size = m + n + 1
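    # u and v are the SCS iterates of the homogeneous self-dual embedding,
    # one row of size m + n + 1 per instance; the last coordinate is tau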
    u, v = self._initialize_iterates(m, n, num_instances)
    context = None
    init_diff_u, init_fp_u, init_scaled_u = self._compute_init_diff_u(
        u, v, total_size, multi_instance, rho_x, alpha
    )
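    # for a learned accelerator, optionally predict the initial iterate
    # and/or hidden state from a per-instance context; this also defines
    # u_orig, which the first loop iteration reads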
    if isinstance(self.accel, NeuralRec):
        if self.model_cfg.learn_init_iterate:
            diffs_u.append(init_diff_u.norm(dim=1))
        if self.model_cfg.learn_init_iterate or \
                self.model_cfg.learn_init_hidden:
            context = self._construct_context(multi_instance)
            context = context.to(self.model_cfg.device)
        u, tau, scaled_u = self._unscale_before_model(u)
        u = u.to(self.model_cfg.device)
        u, hidden = self.accel.init_instance(
            init_x=u, context=context)
        u = u.to(self.device)
        u_orig = self._rescale_after_model(u, tau)
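    # main loop: from iteration 1 on, apply the learned update to the pair
    # (u, f(u)), then take one step of the SCS fixed-point map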
    for j in range(max_iters):
        if j > 0:
            u_upd, tau, scaled_u = self._unscale_before_model(u)
            fp_u_upd, fp_tau, scaled_fp_u = self._unscale_before_model(fp_u)
            if scaled_u and scaled_fp_u:
                u_upd = u_upd.to(self.model_cfg.device)
                fp_u_upd = fp_u_upd.to(self.model_cfg.device)
                u_upd, _, hidden = self.accel.update(
                    fx=fp_u_upd, x=u_upd, hidden=hidden)
                u_upd = u_upd.to(self.device)
            else:
                # if either iterate could not be unscaled, skip the learned
                # update and fall back to the plain fixed-point iterate
                u_upd, tau = fp_u_upd, fp_tau
            u_orig = self._rescale_after_model(u_upd, tau)
        u, v = self._scale_iterates(u_orig, v, total_size)
        fp_u, fp_v = self._fixed_point_iteration(
            QplusI_lu, QplusI_t_lu, u, v, multi_instance, rho_x, alpha
        )
        v = fp_v
        diff_u = fp_u - u
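        # record the fixed-point residual without gradients for reporting,
        # and the differentiable loss from iteration 1 onward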
        with torch.no_grad():
            curr_train_diff_u = self._compute_scaled_loss(u_orig, fp_u)
            train_diff_u.append(curr_train_diff_u)
        if j > 0:
            curr_loss = self._compute_scaled_loss(u_orig, fp_u)
            if curr_loss is not None:
                losses.append(curr_loss)
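        # optional per-iteration diagnostics: residual norms, iterate
        # differences, tau, and objective values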
        if track_metrics:
            with torch.no_grad():
                res_p, res_d = self._compute_residuals(fp_u, fp_v,
                                                       multi_instance)
                all_residuals.append([res_p.norm(dim=1), res_d.norm(dim=1)])
                diffs_u.append(diff_u.norm(dim=1))
                seq_tau.append(fp_u[:, -1])
                objectives.append(self._get_objectives(fp_u, fp_v,
                                                       multi_instance))
    # check feasibility: tau <= 0 marks an infeasible (or diverged) instance
    all_tau = fp_u[:, -1]
    # if any instance is infeasible, zero out its contribution to the loss
    if (all_tau <= 0).any():
        curr_loss = self._compute_scaled_loss(u_orig, fp_u, include_bad_tau=True)
        if curr_loss is not None:
            losses.append(curr_loss)
        if (all_tau <= 0).all() or len(losses) == 0:
            return [], [], [], False
    # convert the final iterates to a solution; recompute residuals only if
    # the track_metrics block did not already produce them above
    if train and self.regularize > 0.:
        res_p, res_d = self._compute_residuals_sparse_for_backprop(fp_u, fp_v, multi_instance)
    elif not track_metrics:
        with torch.no_grad():
            res_p, res_d = self._compute_residuals(fp_u, fp_v, multi_instance)
    soln, diffu_counts = self._extract_solution(
        fp_u, fp_v, multi_instance, (res_p, res_d), losses, train_diff_u
    )
    metrics = {}
    if track_metrics:
        metrics = {"residuals": all_residuals, "objectives": objectives,
                   "diffs_u": diffs_u, "all_tau": seq_tau}
    # split the batched outputs into one entry per instance, to stay
    # consistent with SCS's per-problem interface
    soln, metrics = self._convert_to_sequential_list(soln, metrics, num_instances)
    if train:
        # trailing flag mirrors the failure return in the infeasible branch
        return soln, metrics, diffu_counts, True
    else:
        return soln, metrics
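
# A minimal usage sketch (hypothetical: `solver` is assumed to be a
# constructed instance of this class and `multi_instance` a batched problem
# object with the interface used above; both names are illustrative):
#
#     # training mode: (solutions, metrics, diff-u counts, success flag)
#     soln, metrics, diffu_counts, ok = solver.solve(
#         multi_instance, train=True, max_iters=500, track_metrics=True)
#
#     # evaluation mode: (solutions, metrics) only
#     soln, metrics = solver.solve(multi_instance, train=False)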