def _compute_residuals_sparse_for_backprop()

in automl21/scs_neural/solver/neural_scs_batched.py [0:0]


    def _compute_residuals_sparse_for_backprop(self, u, v, multi_instance):
        """Compute residuals for backpropagation for use with regularize
            Ensure that nan never gets computed.
        """
        all_A, b, c = multi_instance.A, multi_instance.b, multi_instance.c
        m, n, num_instances = multi_instance.get_sizes()
        all_tau = u[:, -1]
        bad_tau = (all_tau <= 0)
        clean_tau = (all_tau > 0) * all_tau + 1 * bad_tau
        clean_tau = clean_tau.unsqueeze(1)
        x, y = u[:, :n]/clean_tau, u[:, n:n+m]/clean_tau
        s = v[:, n:n+m]/clean_tau

        # compute primal & dual residuals
        if hasattr(multi_instance, 'A_tensor'):
            A = multi_instance.A_tensor
        else:
            all_A_dense = np.stack([curr_A.toarray() for curr_A in all_A])
            A = torch.from_numpy(all_A_dense)
            multi_instance.A_tensor = A 

        x_expand, y_expand = x.unsqueeze(2), y.unsqueeze(2)
        prim_res = (A @ x_expand).squeeze() + s - b
        dual_res = (A.transpose(1, 2) @ y_expand).squeeze() + c

        orig_b, orig_c = b, c

        if multi_instance.scaled:
            D, E, sigma, rho = multi_instance.D, multi_instance.E, \
                multi_instance.sigma, multi_instance.rho
            prim_res = (D * prim_res) / sigma
            dual_res = (E * dual_res) / rho
            orig_b = multi_instance.orig_b
            orig_c = multi_instance.orig_c

        prim_res = prim_res / (1 + orig_b.norm(dim=1).unsqueeze(1))
        dual_res = dual_res / (1 + orig_c.norm(dim=1).unsqueeze(1))

        bad_tau_filler = bad_tau.unsqueeze(1)
        prim_res = prim_res * (~bad_tau_filler) + 0.0 * bad_tau_filler
        dual_res = dual_res * (~bad_tau_filler) + 0.0 * bad_tau_filler

        return prim_res, dual_res