def compute_w_minimal_space()

in ssl/real-dataset/byol_trainer.py [0:0]


    def compute_w_minimal_space(self, M, M2, w):
        """Compute W from the dynamics dW/dt = -MW - WM + 2 * M2.

        The computation is done in the eigenbasis of ``M`` (``M = Q D Q^T``),
        with the eigenvalues of ``M`` clamped to be non-negative.

        Behavior is selected by attributes read from ``self``:

        * ``dyn_time``: if ``None``, the time-integrated form is used
          (entry-wise division by ``(d_i + d_j + dyn_reg) / 2``); otherwise
          the result is ``e^{-D t} (Q^T M2 Q) e^{-D t}`` with ``t = dyn_time``,
          mapped back with ``Q``.
        * ``dyn_diagonalize``: only used when ``dyn_time`` is ``None``; keeps
          just the diagonal of ``Q^T M2 Q``, divided by ``d_i + dyn_reg / 2``.
        * ``dyn_reg``: regularizer added to the eigenvalue denominators.
        * ``dyn_psd``: if ``None`` the result is returned as is. Otherwise the
          result is projected to be PSD, and when ``dyn_psd > 0`` that many
          refinement steps are run on its square-root factor, keeping the
          iterate whose gradient norm was smallest.

        Args:
            M: square matrix. It is assumed to be symmetric: the code relies
                on ``Q^T`` being the inverse of ``Q``.
                NOTE(review): ``torch.linalg.eigh`` reads only one triangle
                of ``M`` -- confirm callers always pass a symmetric matrix.
            M2: matrix with the same shape as ``M``.
            w: only ``w.size(1)`` is read (the dimension ``d``); its values
                are not used.

        Returns:
            The computed matrix W.

        Raises:
            RuntimeError: if the eigendecomposition fails.
        """
        try:
            # QDQ^t = M. `torch.eig` has been removed from PyTorch; `eigh` is
            # the symmetric solver and returns real eigenvalues as a 1-D tensor.
            D, Q = torch.linalg.eigh(M)
        except RuntimeError as err:
            # Fail loudly instead of dropping into an interactive debugger.
            raise RuntimeError(
                "compute_w_minimal_space: eigendecomposition of M failed"
            ) from err

        # The equation is \dot W = -MW - WM + 2 * M2 (when W keeps symmetric)
        # The solution is e^{-Mt} M2 e^{-Mt}
        # If M can be decomposed: M = QDQ^T, then the solution is Qe^{-Dt}Q^T M2 Q e^{-Dt} Q^T
        d = w.size(1)
        eigen_values = D.clamp(0)
        M2_convert = Q.t() @ M2 @ Q
        M2_diag = None
        already_psd = False
        if self.dyn_time is None:
            # integrate things out and you get a matrix which is 2 / (d_i + d_j), where d_i is the eigenvalues.
            if self.dyn_diagonalize:
                M2_diag = M2_convert.diag() / (eigen_values + self.dyn_reg / 2)
                if self.dyn_psd is not None:
                    # A diagonal matrix with non-negative entries is PSD, so
                    # the eigendecomposition below can be skipped.
                    M2_diag = M2_diag.clamp(0)
                    already_psd = True
                M2_convert = M2_diag.diag()
            else:
                pair_sums = eigen_values.view(d, 1) + eigen_values.view(1, d)
                M2_convert = M2_convert / ((pair_sums + self.dyn_reg) / 2)
        else:
            eD = (eigen_values * (-self.dyn_time)).exp().diag()
            M2_convert = eD @ M2_convert @ eD

        w = Q @ M2_convert @ Q.t()

        # Project the weight to be PSD?
        # If we choose to make it PSD (or diagonalized), it will be symmetric by default.
        if self.dyn_psd is None:
            return w

        if already_psd:
            w_half = Q @ M2_diag.sqrt().diag()
        else:
            # Symmetrize first so that `eigh` does not silently ignore one
            # triangle of w; this is a no-op when w is already symmetric.
            w_sym = (w + w.t()) / 2
            D2, Q2 = torch.linalg.eigh(w_sym)
            eigen_values2 = D2.clamp(0)

            # w = w_half @ w_half.t()
            w_half = Q2 @ eigen_values2.sqrt().diag()

        if self.dyn_psd > 0:
            # then we do a few iterations to make w more precise.
            alpha = 0.01
            err_magnitudes = torch.zeros(self.dyn_psd)
            w_half_hist = []
            for kk in range(self.dyn_psd):
                w = w_half @ w_half.t()
                err = - (M @ w + w @ M) + M2 * 2
                grad = err @ w_half
                err_magnitudes[kk] = grad.norm()
                # Snapshot before the step so the best iterate can be restored.
                w_half_hist.append(w_half.clone())
                w_half = w_half + alpha * grad

            # NaN never wins the argmin: replace it with a huge magnitude.
            err_magnitudes[torch.isnan(err_magnitudes)] = 1e38
            best_kk = err_magnitudes.argmin().item()
            w_half = w_half_hist[best_kk]
        w = w_half @ w_half.t()
        return w