in ssl/real-dataset/byol_trainer.py [0:0]
def regulate_predictor(self, predictor, niter, epoch_start=False):
    """Constrain or re-initialise the predictor's linear weights in place.

    The behaviour is selected by ``self.predictor_reg``; the call is a no-op
    when it is ``None``.

    * One linear layer: the weight is cloned, transformed according to the
      mode ("diagonal", "symmetric", "symmetric_norm", "solve", "corr",
      "directcopy", "minimal_space", "onehalfeig", "partition") and copied
      back into the layer.
    * Two linear layers: both weights are replaced by their average so that
      the second is the transpose of the first ("symmetric"), optionally
      with every row scaled to unit norm ("symmetric_row_norm").

    Afterwards, if ``self.predictor_wd`` is not ``None``, every layer weight
    is multiplied by ``1 - self.predictor_wd``.  Finally
    ``self.predictor_signaling_2`` is set to ``True``; the informational log
    lines below are guarded by that flag.

    Args:
        predictor: Not used by this method; the layers are read from
            ``self.linear_layers``.
        niter: Iteration counter.  Periodic modes act when
            ``niter % self.predictor_freq == 0``; it also scales the noise
            term as ``self.skew / (niter + 1)``.
        epoch_start: When true, forces the "onehalfeig" re-initialisation
            regardless of ``niter``.

    Raises:
        AssertionError: If ``self.predictor_reg`` is not valid for the number
            of linear layers.
        RuntimeError: If a single-layer mode has no handler (only reachable
            when assertions are disabled).
    """
    if self.predictor_reg is None:
        return

    def _is_update_step():
        # Evaluated lazily so that modes which never read predictor_freq
        # do not touch that attribute.
        return self.predictor_freq > 0 and niter % self.predictor_freq == 0

    def _symmetric_eig(mat):
        # Eigen-decomposition of a symmetric matrix -> (eigenvalues, eigenvectors).
        # torch.eig was removed in PyTorch 2.0, so prefer torch.linalg.eigh and
        # fall back to torch.eig only on versions that predate torch.linalg.
        linalg = getattr(torch, "linalg", None)
        if linalg is not None and hasattr(linalg, "eigh"):
            return linalg.eigh(mat)
        vals, vecs = torch.eig(mat, eigenvectors=True)
        return vals[:, 0], vecs

    linear_layers = self.linear_layers
    reg = self.predictor_reg

    if len(linear_layers) == 1:
        valid_modes = [
            "diagonal", "symmetric", "symmetric_norm", "onehalfeig",
            "minimal_space", "solve", "corr", "directcopy", "partition",
        ]
        assert reg in valid_modes, f"predictor_reg: {self.predictor_reg} not valid!"

        with torch.no_grad():
            w = linear_layers[0].weight.clone()
            # NOTE(review): debugging trap kept from the original; it drops
            # into pdb when the incoming weight already contains NaN/Inf.
            if not torch.isfinite(w).all().item():
                import pdb
                pdb.set_trace()

            if reg == "diagonal":
                # Keep only the diagonal entries.
                w = w.diag().diag()
            elif reg == "symmetric":
                # Out-of-place: `w += w.t()` would write into memory that the
                # transposed view is still reading from.
                w = (w + w.t()) / 2
            elif reg == "symmetric_norm":
                if not self.predictor_signaling_2:
                    log.info("Enforce symmetric constraint with unit spectral norm.")
                w = (w + w.t()) / 2
                # Normalize so that the largest eigenvalue is 1 (skipped when
                # that eigenvalue is too close to zero to divide by safely).
                eigvals, _ = _symmetric_eig(w)
                max_eigen = eigvals.max()
                if max_eigen.abs() > 1e-2:
                    w = w / max_eigen
            elif reg == "solve":
                if _is_update_step():
                    corr = self.cum_corr.get()
                    cross_corr = self.cum_cross_corr.get()
                    if corr is not None and cross_corr is not None:
                        if not self.predictor_signaling_2:
                            log.info(f"Reinitialize predictor weight (assymmetric). freq={self.predictor_freq}, reg={self.dyn_reg}, dir={self.solve_direction}")
                        eigvals, Q = _symmetric_eig(corr)
                        # Regularised inverse Q diag(1 / (max(lambda, 0) + reg)) Q^T.
                        # The explicit diag() matters: multiplying by the bare
                        # eigenvalue vector collapses the result to a vector.
                        inv_eigvals = (eigvals.clamp(0) + self.dyn_reg).pow(-1)
                        inv = Q @ torch.diag(inv_eigvals) @ Q.t()
                        if self.solve_direction == "left":
                            w = inv @ cross_corr
                        else:
                            w = cross_corr @ inv
            elif reg == "corr":
                if _is_update_step():
                    corr = self.cum_corr.get()
                    if corr is not None:
                        if not self.predictor_signaling_2:
                            log.info(f"Set predictor to align with input correlation. zero_mean={self.dyn_zero_mean}, freq={self.predictor_freq}, type={self.balance_type}, pow=1/{self.dyn_convert}, eps={self.dyn_eps}, reg={self.dyn_reg}, noise={self.dyn_noise}, eps_inside={self.dyn_eps_inside}")
                        if self.dyn_zero_mean:
                            mean_f = self.cum_mean1.get()
                            # Out-of-place so the tensor handed back by the
                            # accumulator is never modified behind its back.
                            corr = corr - torch.ger(mean_f, mean_f)
                        w = self.compute_w_corr(corr)
                        if self.dyn_noise is not None:
                            w = w + self.skew / (niter + 1)
            elif reg == "directcopy":
                if _is_update_step():
                    corr = self.cum_corr.get()
                    if corr is not None:
                        if not self.predictor_signaling_2:
                            log.info(f"Set predictor to be input correlation. zero_mean={self.dyn_zero_mean}, freq={self.predictor_freq}, eps={self.dyn_eps}")
                        if self.dyn_zero_mean:
                            mean_f = self.cum_mean1.get()
                            corr = corr - torch.ger(mean_f, mean_f)
                        identity = torch.eye(corr.size(0), dtype=corr.dtype, device=corr.device)
                        w = corr + self.dyn_eps * identity
                        if self.dyn_noise is not None:
                            w = w + self.skew / (niter + 1)
            elif reg == "minimal_space":
                if _is_update_step():
                    corr = self.cum_corr.get()
                    cross_corr = self.cum_cross_corr.get()
                    if corr is not None and cross_corr is not None:
                        # Initialize weight to contain eigenvectors that correspond to
                        # lowest few eigenvalues from the input data.
                        if not self.predictor_signaling_2:
                            log.info(f"Reinitialize predictor weight. freq={self.predictor_freq}, dyn: time={self.dyn_time}, zero_mean={self.dyn_zero_mean}, reg={self.dyn_reg}, sym={self.dyn_sym}, lambda={self.dyn_lambda}, make_psd={self.dyn_psd}, diagonize={self.dyn_diagonalize}, before_bn={self.dyn_bn}")
                        if self.dyn_zero_mean:
                            mean_f = self.cum_mean1.get()
                            mean_f_ema = self.cum_mean2.get()
                            corr = corr - torch.ger(mean_f, mean_f)
                            cross_corr = cross_corr - torch.ger(mean_f_ema, mean_f)
                        if self.dyn_sym:
                            cross_corr = (cross_corr + cross_corr.t()) / 2
                        w = self.compute_w_minimal_space(corr, cross_corr, w)
                        # Start accumulating fresh statistics for the next update.
                        self.cum_corr.reset()
                        self.cum_cross_corr.reset()
                        self.cum_mean1.reset()
                        self.cum_mean2.reset()
                elif self.dyn_sym:
                    # Between updates just keep the weight symmetric.
                    w = (w + w.t()) / 2
            elif reg == "onehalfeig":
                # Re-initialise at epoch start or on the periodic schedule.
                if epoch_start or _is_update_step():
                    if not self.predictor_signaling_2 or epoch_start:
                        log.info(f"Reinit predictor weight with {self.predictor_reg}. epoch_start={epoch_start}, freq={self.predictor_freq}")
                    d = w.size(0)
                    Q = torch.as_tensor(
                        random_orthonormal_matrix(dim=d), dtype=w.dtype, device=w.device
                    )
                    # Pick predictor_rank of the vectors and sum their outer
                    # products q q^T, i.e. project onto the span of those columns.
                    r = int(self.predictor_rank * d)
                    kept = Q[:, :r]
                    w = (kept @ kept.t()) * self.predictor_eig
                    # Build the identity directly on w's device; get_device()
                    # returns -1 on CPU and cannot be passed to .to().
                    w = w + torch.eye(d, dtype=w.dtype, device=w.device) * self.predictor_eps
                else:
                    # Between re-initialisations just keep the weight symmetric.
                    w = (w + w.t()) / 2
            elif reg == "partition":
                # Don't need to do anything.
                pass
            else:
                raise RuntimeError(f"Unknown predictor_reg: {self.predictor_reg}")

            # NOTE(review): same debugging trap as above, for the new weight.
            if not torch.isfinite(w).all().item():
                import pdb
                pdb.set_trace()

            linear_layers[0].weight.copy_(w)

    elif len(linear_layers) == 2:
        # Two layer without batch norm. Make it symmetric!
        assert reg in ["symmetric", "symmetric_row_norm"], f"predictor_reg: {self.predictor_reg} not valid!"

        with torch.no_grad():
            # Average the first weight with the transpose of the second so
            # that, after copying back, layer 1 is the transpose of layer 0.
            w0 = linear_layers[0].weight.clone()
            w1 = linear_layers[1].weight.clone()
            w = (w0 + w1.t()) / 2
            if reg == "symmetric_row_norm":
                # Scale every row of the shared weight to unit norm.
                w /= w.norm(dim=1, keepdim=True)

            linear_layers[0].weight.copy_(w)
            linear_layers[1].weight.copy_(w.t())

    if self.predictor_wd is not None:
        if not self.predictor_signaling_2:
            log.info(f"Apply predictor weight decay: {self.predictor_wd}")

        with torch.no_grad():
            for layer in linear_layers:
                layer.weight *= (1 - self.predictor_wd)

    # Suppress the one-off informational log lines on subsequent calls.
    self.predictor_signaling_2 = True