in ssl/real-dataset/byol_trainer.py [0:0]
def compute_w_corr(self, M):
    """Build a spectrum-balanced matrix ``Q diag(f(D)) Q^T`` from ``M``.

    ``M`` is decomposed as ``Q diag(D) Q^T``. When eigen-decomposition is
    disabled, its diagonal is used as ``D`` and the identity as ``Q``. The
    spectrum ``D`` is then transformed according to ``self.balance_type``
    and the matrix is reassembled.

    Attributes read from ``self``:
        corr_eigen_decomp: if truthy, eigen-decompose ``M``; otherwise use
            ``M.diag()`` as the spectrum and the identity as ``Q``.
        predictor_signaling_2: when falsy, progress is logged via ``log``.
        balance_type: one of ``"shrink2mean"``, ``"clamp"``,
            ``"boost_scale"`` or ``"scale"``.
        dyn_reg: regularizer used by ``"clamp"`` and ``"scale"``.
        dyn_eps, dyn_eps_inside, dyn_convert: parameters of
            ``"boost_scale"`` (``dyn_eps`` may be negative).

    Args:
        M: square 2-D floating-point tensor. NOTE(review): assumed
            symmetric (e.g. a correlation matrix). The ``Q diag Q^T``
            reconstruction is only meaningful for orthonormal ``Q``;
            confirm against callers.

    Returns:
        An ``n x n`` tensor with the dtype and device of ``M``.

    Raises:
        RuntimeError: if ``self.balance_type`` is not recognised.
    """
    if self.corr_eigen_decomp:
        if not self.predictor_signaling_2:
            log.info("compute_w_corr: Use eigen_decomp!")
        linalg = getattr(torch, "linalg", None)
        if linalg is not None and hasattr(linalg, "eigh"):
            # torch.eig was deprecated in PyTorch 1.9 and later removed.
            # eigh returns real eigenvalues and orthonormal eigenvectors for
            # a symmetric input, which the Q diag Q^T reconstruction needs.
            D, Q = linalg.eigh(M)
        else:
            # Legacy PyTorch without torch.linalg.eigh: torch.eig returns
            # eigenvalues as (real, imag) pairs. Only use the real part.
            D, Q = torch.eig(M, eigenvectors=True)
            D = D[:, 0]
    else:
        # Just use diagonal element.
        if not self.predictor_signaling_2:
            log.info("compute_w_corr: No eigen_decomp, just use diagonal elements!")
        D = M.diag()
        # Match D's dtype/device so the final matmul also works for
        # non-float32 input (a bare torch.eye is always float32).
        Q = torch.eye(M.size(0), dtype=D.dtype, device=D.device)

    balance_type = self.balance_type
    reg = self.dyn_reg

    if balance_type == "shrink2mean":
        # Halve every eigenvalue's distance to the mean eigenvalue.
        mean_eig = D.mean()
        eigen_values = (D - mean_eig) / 2 + mean_eig
    elif balance_type == "clamp":
        eigen_values = D.clamp(min=0, max=1 - reg)
    elif balance_type == "boost_scale":
        # Normalize the non-negative part of the spectrum so its max is 1.
        max_eig = D.max()
        if max_eig > 0:
            eigen_values = D.clamp(0) / max_eig
        else:
            # Every eigenvalue is <= 0, so D.clamp(0) is all zeros;
            # dividing by max_eig == 0 would turn them into NaNs.
            eigen_values = torch.zeros_like(D)
        # Going through a concave function (dyn_convert > 1, e.g., 2 or sqrt function) to boost small eigenvalues (while still keep very small one to be 0)
        if self.dyn_eps_inside:
            # Note that here dyn_eps is allowed to be negative.
            eigen_values = (eigen_values + self.dyn_eps).clamp(1e-4).pow(1 / self.dyn_convert)
        else:
            # Note that here dyn_eps is allowed to be negative.
            eigen_values = eigen_values.pow(1 / self.dyn_convert) + self.dyn_eps
            eigen_values = eigen_values.clamp(1e-4)

        if not self.predictor_signaling_2:
            sorted_values, _ = eigen_values.sort(descending=True)
            log.info(f"Compute eigenvalues with boost_scale: Top-5: {sorted_values[:5]}, Bottom-5: {sorted_values[-5:]}")
    elif balance_type == "scale":
        # If the largest eigenvalue exceeds 1 - reg, scale everything down.
        max_eig = D.max()
        if max_eig > 1 - reg:
            eigen_values = D / (max_eig + reg)
        else:
            # Spectrum is already below the threshold: keep it unchanged.
            # Without this branch eigen_values would be unbound at the return.
            eigen_values = D
    else:
        raise RuntimeError(f"Unknown balance_type: {balance_type}")

    return Q @ eigen_values.diag() @ Q.t()