def _normalize_A_sparse()

in automl21/scs_neural/solver/neural_scs_batched.py [0:0]


    def _normalize_A_sparse(self, all_A, boundaries, scale):
        """
           Normalize the A matrix. This code comes from:
           https://github.com/bodono/scs-python/blob/master/test/test_scs_python_linsys.py
        """
        updated_A, all_instance_D, all_instance_E = [], [], []
        all_instance_row_norm, all_instance_col_norm = [], []

        for A in all_A:
            m, n = A.shape
            D_all = np.ones(m)
            E_all = np.ones(n)

            min_scale, max_scale = (1e-4, 1e4)
            n_passes = 10

            for i in range(n_passes):
                D = np.sqrt(sla.norm(A, float('inf'), axis=1))
                E = np.sqrt(sla.norm(A, float('inf'), axis=0))
                D[D < min_scale] = 1.0
                E[E < min_scale] = 1.0
                D[D > max_scale] = max_scale
                E[E > max_scale] = max_scale
                start = boundaries[0]
                for delta in boundaries[1:]:
                    D[start:start+delta] = D[start:start+delta].mean()
                    start += delta
                A = sp.diags(1/D).dot(A).dot(sp.diags(1/E))
                D_all *= D
                E_all *= E

            mean_row_norm = sla.norm(A, 2, axis=1).mean()
            mean_col_norm = sla.norm(A, 2, axis=0).mean()
            A *= scale
            updated_A.append(A)
            all_instance_D.append(D_all)
            all_instance_E.append(E_all)
            all_instance_row_norm.append(mean_row_norm)
            all_instance_col_norm.append(mean_col_norm)

        D_final = np.stack(all_instance_D)
        E_final = np.stack(all_instance_E)
        mean_row_norm = np.stack(all_instance_row_norm)
        mean_col_norm = np.stack(all_instance_col_norm)

        D_tensor, E_tensor = torch.from_numpy(D_final), torch.from_numpy(E_final)
        row_norm, col_norm = torch.from_numpy(mean_row_norm), torch.from_numpy(mean_col_norm)
        return updated_A, D_tensor, E_tensor, row_norm, col_norm