in ssl/real-dataset/byol_trainer.py [0:0]
def compute_w_minimal_space(self, M, M2, w):
    """Compute the weight matrix implied by the dynamics ``dW/dt = -MW - WM + 2 * M2``.

    ``M`` is treated as symmetric and decomposed as ``M = Q D Q^T`` with
    orthonormal ``Q``; the computation is done in that eigenbasis and rotated
    back. Eigenvalues of ``M`` are clamped to be non-negative.

    Behaviour is controlled by attributes read from ``self``:

    * ``dyn_time``: if ``None``, return the steady state, i.e. the entries of
      ``Q^T M2 Q`` divided by ``(d_i + d_j + dyn_reg) / 2``. Otherwise return
      ``e^{-M t} M2 e^{-M t}`` with ``t = dyn_time``.
    * ``dyn_diagonalize``: (steady state only) keep just the diagonal of
      ``Q^T M2 Q``, each entry divided by ``d_i + dyn_reg / 2``.
    * ``dyn_reg``: regularizer added to the steady-state denominators. With
      ``dyn_reg == 0`` a zero eigenvalue of ``M`` gives a zero denominator
      (inf/nan in the result).
    * ``dyn_psd``: ``None`` returns the result unchanged. Otherwise the result
      is projected to be PSD (negative eigenvalues clamped to 0). If
      ``dyn_psd > 0``, that many gradient steps (step size 0.01) are taken on
      ``w_half`` with ``w = w_half @ w_half^T``, and the iterate with the
      smallest gradient norm is kept.

    Args:
        M: square matrix, assumed symmetric.
        M2: square matrix of the same size as ``M``.
        w: not read; kept for interface compatibility (the size now comes
            from ``M``).

    Returns:
        The resulting square weight matrix.

    Raises:
        RuntimeError: if the eigendecomposition of ``M`` fails.
    """

    def _sym_eigh(A):
        # Symmetric eigendecomposition A = Q diag(D) Q^T with orthonormal Q
        # (eigenvalues ascending). Everything below relies on Q^T == Q^{-1},
        # which general ``torch.eig`` does not guarantee (and which is removed
        # in recent PyTorch). Symmetrize first: eigh only reads one triangle.
        A = (A + A.t()) / 2
        linalg = getattr(torch, "linalg", None)
        if linalg is not None and hasattr(linalg, "eigh"):
            return linalg.eigh(A)
        # Older PyTorch (< 1.8) without torch.linalg.eigh.
        return torch.symeig(A, eigenvectors=True)

    try:
        eigen_values, Q = _sym_eigh(M)
    except RuntimeError as err:
        raise RuntimeError(
            "compute_w_minimal_space: eigendecomposition of M failed"
        ) from err

    # Solution is e^{-Mt} M2 e^{-Mt} = Q e^{-Dt} Q^T M2 Q e^{-Dt} Q^T.
    eigen_values = eigen_values.clamp(0)
    M2_convert = Q.t() @ M2 @ Q

    M2_diag = None
    already_psd = False
    if self.dyn_time is None:
        # Integrating over t gives entry-wise factors 2 / (d_i + d_j).
        if self.dyn_diagonalize:
            M2_diag = M2_convert.diag() / (eigen_values + self.dyn_reg / 2)
            if self.dyn_psd is not None:
                # Diagonal in the Q basis: clamping makes the result PSD.
                M2_diag = M2_diag.clamp(0)
                already_psd = True
            M2_convert = M2_diag.diag()
        else:
            denom = (
                eigen_values.view(-1, 1) + eigen_values.view(1, -1) + self.dyn_reg
            ) / 2
            M2_convert = M2_convert / denom
    else:
        eD = (eigen_values * (-self.dyn_time)).exp().diag()
        M2_convert = eD @ M2_convert @ eD
    w = Q @ M2_convert @ Q.t()

    if self.dyn_psd is None:
        return w

    # PSD projection, expressed as w = w_half @ w_half^T.
    if already_psd:
        w_half = Q @ M2_diag.sqrt().diag()
    else:
        eigen_values2, Q2 = _sym_eigh(w)
        w_half = Q2 @ eigen_values2.clamp(0).sqrt().diag()

    if self.dyn_psd > 0:
        # A few gradient steps on w_half to reduce the residual of
        # M w + w M = 2 M2; keep the first iterate with the smallest
        # gradient norm (NaN norms are ranked as 1e38).
        alpha = 0.01
        best_score = float("inf")
        best_w_half = w_half
        for _ in range(self.dyn_psd):
            cur_w = w_half @ w_half.t()
            err = M2 * 2 - (M @ cur_w + cur_w @ M)
            grad = err @ w_half
            grad_norm = grad.norm()
            score = 1e38 if torch.isnan(grad_norm) else grad_norm.item()
            if score < best_score:
                best_score = score
                best_w_half = w_half
            # Out-of-place so the stored best iterate is not mutated.
            w_half = w_half + alpha * grad
        w_half = best_w_half

    return w_half @ w_half.t()