def solve()

in automl21/scs_neural/solver/neural_scs_batched.py [0:0]


    def solve(self, multi_instance, train=True, max_iters=5000, use_scaling=True, scale=1,
              rho_x=1e-3, alpha=1.5, track_metrics=False, **kwargs):
        """Run the batched fixed-point solver, with ``self.accel`` proposing iterates.

        Each iteration applies ``self._fixed_point_iteration`` to the current
        ``(u, v)``. From the second iteration on, when both ``u`` and the previous
        fixed-point output ``fp_u`` are successfully unscaled by
        ``self._unscale_before_model``, ``self.accel.update`` proposes the next
        iterate. Otherwise the plain fixed-point output is used.

        Args:
            multi_instance: batch of problem instances; must provide
                ``get_sizes() -> (m, n, num_instances)``.
            train: selects the return shape (see Returns). When ``True`` and
                ``self.regularize > 0``, residuals are computed with
                ``_compute_residuals_sparse_for_backprop``.
            max_iters: number of fixed-point iterations; must be >= 1.
            use_scaling, scale, **kwargs: not read by this method. They are
                presumably kept for signature compatibility with other solvers;
                verify against callers.
            rho_x, alpha: forwarded to the matrix setup / fixed-point helpers.
            track_metrics: if ``True``, record per-iteration residual norms,
                ``||fp_u - u||``, the last coordinate (tau) of ``fp_u`` and
                objectives, and return them in ``metrics``.

        Returns:
            ``(soln, metrics, diffu_counts, True)`` when ``train`` is true, else
            ``(soln, metrics)``. If every instance ends with tau <= 0, or no loss
            terms were collected, returns ``([], [], [], False)`` when ``train``
            is true and ``([], [])`` otherwise.

        Raises:
            ValueError: if ``max_iters < 1``.
        """
        if max_iters < 1:
            # fp_u / fp_v are only bound inside the loop; with zero iterations
            # the post-loop code would otherwise fail with UnboundLocalError.
            raise ValueError(f"max_iters must be >= 1, got {max_iters}")

        # Per-iteration histories (lists, one entry per recorded iteration).
        diffs_u, objectives, all_residuals, losses = [], [], [], []
        seq_tau, train_diff_u = [], []

        QplusI_lu, QplusI_t_lu = self._obtain_QplusI_matrices(multi_instance, rho_x)
        m, n, num_instances = multi_instance.get_sizes()
        total_size = m + n + 1
        u, v = self._initialize_iterates(m, n, num_instances)

        context = None
        # Only the initial diff is used here; the other two outputs are ignored.
        init_diff_u, _, _ = self._compute_init_diff_u(
            u, v, total_size, multi_instance, rho_x, alpha
        )

        if isinstance(self.accel, NeuralRec):
            if self.model_cfg.learn_init_iterate:
                diffs_u.append(init_diff_u.norm(dim=1))

            if self.model_cfg.learn_init_iterate or \
                    self.model_cfg.learn_init_hidden:
                context = self._construct_context(multi_instance)
                context = context.to(self.model_cfg.device)

        # Let the accelerator produce the starting iterate (and its hidden
        # state) on the model device, then move back to the solver device.
        u, tau, _ = self._unscale_before_model(u)
        u = u.to(self.model_cfg.device)
        u, hidden = self.accel.init_instance(
            init_x=u, context=context)
        u = u.to(self.device)
        u_orig = self._rescale_after_model(u, tau)

        for j in range(max_iters):
            if j > 0:
                u_upd, tau, scaled_u = self._unscale_before_model(u)
                fp_u_upd, fp_tau, scaled_fp_u = self._unscale_before_model(fp_u)
                if scaled_u and scaled_fp_u:
                    u_upd, fp_u_upd = u_upd.to(self.model_cfg.device), fp_u_upd.to(self.model_cfg.device)
                    u_upd, _, hidden = self.accel.update(
                        fx=fp_u_upd, x=u_upd, hidden=hidden)
                    u_upd = u_upd.to(self.device)
                else:
                    # Unscaling failed for one of the iterates: fall back to
                    # the plain fixed-point output without calling the model.
                    u_upd, tau = fp_u_upd, fp_tau
                u_orig = self._rescale_after_model(u_upd, tau)

            u, v = self._scale_iterates(u_orig, v, total_size)
            fp_u, fp_v = self._fixed_point_iteration(
                QplusI_lu, QplusI_t_lu, u, v, multi_instance, rho_x, alpha
            )
            v = fp_v

            with torch.no_grad():
                curr_train_diff_u = self._compute_scaled_loss(u_orig, fp_u)
                train_diff_u.append(curr_train_diff_u)
            if j > 0:
                # Same quantity as above, but kept in the autograd graph.
                curr_loss = self._compute_scaled_loss(u_orig, fp_u)
                if curr_loss is not None:
                    losses.append(curr_loss)
            if track_metrics:
                with torch.no_grad():
                    diff_u = fp_u - u
                    res_p, res_d = self._compute_residuals(fp_u, fp_v,
                                                           multi_instance)
                    all_residuals.append([res_p.norm(dim=1), res_d.norm(dim=1)])
                    diffs_u.append(diff_u.norm(dim=1))
                    seq_tau.append(fp_u[:, -1])
                    objectives.append(self._get_objectives(fp_u, fp_v,
                                                           multi_instance))

        # Check whether each instance's final tau (last coordinate) is positive.
        all_tau = fp_u[:, -1]
        # If any instance has tau <= 0, add the loss term computed with
        # include_bad_tau=True.
        if (all_tau <= 0).any():
            curr_loss = self._compute_scaled_loss(u_orig, fp_u, include_bad_tau=True)
            if curr_loss is not None:
                losses.append(curr_loss)
        if (all_tau <= 0).all() or len(losses) == 0:
            # Keep the failure return the same arity as the success return for
            # the given `train` flag so callers can always unpack it.
            if train:
                return [], [], [], False
            return [], []

        # convert iterates to solution, and track all solutions
        if train and self.regularize > 0.:
            res_p, res_d = self._compute_residuals_sparse_for_backprop(fp_u, fp_v, multi_instance)
        else:
            if not track_metrics:
                with torch.no_grad():
                    res_p, res_d = self._compute_residuals(fp_u, fp_v, multi_instance)
            # When track_metrics is on, res_p / res_d still hold the residuals
            # computed in the last loop iteration from these same fp_u / fp_v.
        soln, diffu_counts = self._extract_solution(
            fp_u, fp_v, multi_instance, (res_p, res_d), losses, train_diff_u
        )
        metrics = {}
        if track_metrics:
            metrics = {"residuals": all_residuals, "objectives": objectives,
                       "diffs_u": diffs_u, "all_tau": seq_tau}
        # do this to be consistent with SCS
        soln, metrics = self._convert_to_sequential_list(soln, metrics, num_instances)
        if train:
            return soln, metrics, diffu_counts, True
        else:
            return soln, metrics