in automl21/scs_neural/solver/neural_scs_batched.py [0:0]
def _compute_residuals_sparse_for_backprop(self, u, v, multi_instance):
    """Compute normalized primal/dual residuals for backpropagation.

    Meant for use with the regularizer, so the computation is arranged
    such that an unusable tau never produces NaN: instances whose tau is
    not strictly positive (including NaN tau) are divided by 1 instead of
    tau and their residual rows are returned as exact zeros.

    Args:
        u: 2-D tensor with one row per instance; columns ``[:n]`` are
            read as x, ``[n:n+m]`` as y and the last column as tau.
        v: 2-D tensor with one row per instance; columns ``[n:n+m]`` are
            read as s.
        multi_instance: problem batch exposing ``A`` (iterable of sparse
            matrices with ``toarray()``), ``b``, ``c``, ``get_sizes()``
            and ``scaled``; when ``scaled`` is truthy also ``D``, ``E``,
            ``sigma``, ``rho``, ``orig_b`` and ``orig_c``. A dense stack
            of ``A`` is cached on it as ``A_tensor`` on first use.

    Returns:
        Tuple ``(prim_res, dual_res)``: ``A x + s - b`` and ``A^T y + c``
        per instance, un-scaled when the instance is scaled, and divided
        by ``1 + ||b||`` / ``1 + ||c||`` (norms of the original data).
    """
    b, c = multi_instance.b, multi_instance.c
    m, n, _ = multi_instance.get_sizes()

    all_tau = u[:, -1]
    good_tau = all_tau > 0
    # ``~good_tau`` (rather than ``all_tau <= 0``) also flags NaN tau.
    bad_tau = ~good_tau
    # Select instead of multiplying by a mask: ``0 * inf`` would be NaN.
    clean_tau = torch.where(good_tau, all_tau, torch.ones_like(all_tau))
    clean_tau = clean_tau.unsqueeze(1)

    x, y = u[:, :n] / clean_tau, u[:, n:n + m] / clean_tau
    s = v[:, n:n + m] / clean_tau

    # compute primal & dual residuals
    if hasattr(multi_instance, 'A_tensor'):
        A = multi_instance.A_tensor
    else:
        all_A_dense = np.stack(
            [curr_A.toarray() for curr_A in multi_instance.A]
        )
        A = torch.from_numpy(all_A_dense)
        multi_instance.A_tensor = A
    # ``from_numpy`` keeps NumPy's dtype on the CPU; align with the iterate
    # (no-op when they already match).
    A = A.to(dtype=x.dtype, device=x.device)

    x_expand, y_expand = x.unsqueeze(2), y.unsqueeze(2)
    # Squeeze only the trailing matmul dimension. A bare ``squeeze()`` also
    # drops the batch/row dimension when it has size 1, which then
    # broadcasts against ``s``/``c`` into a wrongly shaped result.
    prim_res = (A @ x_expand).squeeze(2) + s - b
    dual_res = (A.transpose(1, 2) @ y_expand).squeeze(2) + c

    orig_b, orig_c = b, c
    if multi_instance.scaled:
        D, E, sigma, rho = multi_instance.D, multi_instance.E, \
            multi_instance.sigma, multi_instance.rho
        prim_res = (D * prim_res) / sigma
        dual_res = (E * dual_res) / rho
        orig_b = multi_instance.orig_b
        orig_c = multi_instance.orig_c

    prim_res = prim_res / (1 + orig_b.norm(dim=1).unsqueeze(1))
    dual_res = dual_res / (1 + orig_c.norm(dim=1).unsqueeze(1))

    # Zero out rows with unusable tau; masked_fill yields exact zeros even
    # if the discarded entries are non-finite.
    bad_tau_filler = bad_tau.unsqueeze(1)
    prim_res = prim_res.masked_fill(bad_tau_filler, 0.0)
    dual_res = dual_res.masked_fill(bad_tau_filler, 0.0)
    return prim_res, dual_res