def compute_w_corr()

in ssl/real-dataset/byol_trainer.py [0:0]


    def compute_w_corr(self, M):
        """Build a predictor weight matrix from the square matrix ``M``.

        ``M`` is split into eigenvalues ``D`` and a basis ``Q``. Either a full
        eigendecomposition is used or, when ``self.corr_eigen_decomp`` is
        false, just the diagonal of ``M`` with an identity basis. ``D`` is
        rebalanced according to ``self.balance_type``. The result is
        reassembled as ``Q @ diag(eigen_values) @ Q.t()``.

        Instance attributes read:
            corr_eigen_decomp: use an eigendecomposition instead of the diagonal.
            predictor_signaling_2: when truthy, suppress the info logs.
            balance_type: one of "shrink2mean", "clamp", "boost_scale", "scale".
            dyn_reg: regularizer used by the "clamp" and "scale" modes.
            dyn_eps, dyn_eps_inside, dyn_convert: parameters of "boost_scale".

        Args:
            M: square 2-D tensor. NOTE(review): presumably a symmetric
                correlation matrix. The reconstruction uses ``Q.t()`` rather
                than ``Q^{-1}``, which is only correct when ``Q`` is
                orthonormal.

        Returns:
            Tensor with the same shape as ``M``.

        Raises:
            RuntimeError: if ``self.balance_type`` is not recognised.
        """
        if self.corr_eigen_decomp:
            if not self.predictor_signaling_2:
                log.info("compute_w_corr: Use eigen_decomp!")
            linalg_eig = getattr(getattr(torch, "linalg", None), "eig", None)
            if linalg_eig is not None:
                # torch.eig was deprecated in 1.9 and later removed. torch.linalg.eig
                # returns complex tensors; only the real part is used, as before.
                D, Q = linalg_eig(M)
                D, Q = D.real, Q.real
            else:
                # Older PyTorch: eigenvalues come back as (n, 2) [real, imag] pairs.
                D, Q = torch.eig(M, eigenvectors=True)
                # Only use the real part.
                D = D[:, 0]
        else:
            # Just use the diagonal elements with an identity basis.
            if not self.predictor_signaling_2:
                log.info("compute_w_corr: No eigen_decomp, just use diagonal elements!")
            D = M.diag()
            # Match dtype as well as device so the final matmul does not mix
            # float32 (torch.eye's default) with e.g. float64 inputs.
            Q = torch.eye(M.size(0), dtype=D.dtype, device=D.device)

        balance_type = self.balance_type
        reg = self.dyn_reg

        if balance_type == "shrink2mean":
            # Halve every eigenvalue's distance to the mean eigenvalue.
            mean_eig = D.mean()
            eigen_values = (D - mean_eig) / 2 + mean_eig
        elif balance_type == "clamp":
            eigen_values = D.clamp(min=0, max=1 - reg)
        elif balance_type == "boost_scale":
            # Normalize the non-negative eigenvalues by the largest eigenvalue.
            max_eig = D.max()
            eigen_values = D.clamp(0) / max_eig
            # Going through a concave function (dyn_convert > 1, e.g., 2 or sqrt function)
            # to boost small eigenvalues (while still keeping very small ones near 0).
            if self.dyn_eps_inside:
                # Note that here dyn_eps is allowed to be negative.
                eigen_values = (eigen_values + self.dyn_eps).clamp(1e-4).pow(1 / self.dyn_convert)
            else:
                # Note that here dyn_eps is allowed to be negative.
                eigen_values = eigen_values.pow(1 / self.dyn_convert) + self.dyn_eps
                eigen_values = eigen_values.clamp(1e-4)
            if not self.predictor_signaling_2:
                sorted_values, _ = eigen_values.sort(descending=True)
                log.info(f"Compute eigenvalues with boost_scale: Top-5: {sorted_values[:5]}, Bottom-5: {sorted_values[-5:]}")
        elif balance_type == "scale":
            # Scale everything down only if the largest eigenvalue exceeds 1 - reg.
            max_eig = D.max()
            if max_eig > 1 - reg:
                eigen_values = D / (max_eig + reg)
            else:
                # Already small enough: keep the eigenvalues unchanged.
                # (This case previously left eigen_values unbound.)
                eigen_values = D
        else:
            raise RuntimeError(f"Unknown balance_type: {balance_type}")

        return Q @ eigen_values.diag() @ Q.t()