def regulate_predictor()

in ssl/real-dataset/byol_trainer.py


    def regulate_predictor(self, predictor, niter, epoch_start=False):
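        """Constrain or re-solve the linear predictor in place, per self.predictor_reg.

        Called with the current iteration `niter`; several modes only fire every
        `self.predictor_freq` steps (or, for "onehalfeig", at `epoch_start`).
        """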
        if self.predictor_reg is None:
            return

        linear_layers = self.linear_layers
        if len(linear_layers) == 1:
            assert self.predictor_reg in ["diagonal", "symmetric", "symmetric_norm", "onehalfeig", "minimal_space", "solve", "corr", "directcopy", "partition"], f"predictor_reg: {self.predictor_reg} not valid!"
            # Make it symmetric
            with torch.no_grad():
                w = linear_layers[0].weight.clone()
                if torch.any(torch.isnan(w)).item() or torch.any(torch.isinf(w)).item():
                    import pdb
                    pdb.set_trace()
                prev_w = w.clone()

                if self.predictor_reg == "diagonal":
                    # Further make it diagonal
                    w = w.diag().diag()
                elif self.predictor_reg == "symmetric":
                    w += w.t()
                    w /= 2
                elif self.predictor_reg == "symmetric_norm":
                    if not self.predictor_signaling_2:
                        log.info(f"Enforce symmetric constraint with unit spectral norm.")
                    w += w.t()
                    w /= 2
                    # Normalize so that the largest positive eigenvalue is 1
                    D, _ = torch.eig(w, eigenvectors=False)
                    max_eigen = D[:,0].max()
                    if max_eigen.abs() > 1e-2:
                        w /= max_eigen
                    
                elif self.predictor_reg == "solve":
                    if self.predictor_freq > 0 and niter % self.predictor_freq == 0:
                        M = self.cum_corr.get()
                        M2 = self.cum_cross_corr.get()

                        if M is not None and M2 is not None:
                            if not self.predictor_signaling_2:
                                log.info(f"Reinitialize predictor weight (asymmetric). freq={self.predictor_freq}, reg={self.dyn_reg}, dir={self.solve_direction}")

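                            # Regularized inverse of M: eigendecompose, clamp negative
                            # eigenvalues to zero, add dyn_reg, and invert the spectrum.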
                            D, Q = torch.eig(M, eigenvectors=True)
                            inv = Q @ (D[:,0].clamp(min=0) + self.dyn_reg).pow(-1).diag() @ Q.t()
                            if self.solve_direction == "left":
                                w = inv @ M2
                            else:
                                w = M2 @ inv

                elif self.predictor_reg == "corr":
                    if self.predictor_freq > 0 and niter % self.predictor_freq == 0:
                        M = self.cum_corr.get()
                        if M is not None:
                            if not self.predictor_signaling_2:
                                log.info(f"Set predictor to align with input correlation. zero_mean={self.dyn_zero_mean}, freq={self.predictor_freq}, type={self.balance_type}, pow=1/{self.dyn_convert}, eps={self.dyn_eps}, reg={self.dyn_reg}, noise={self.dyn_noise}, eps_inside={self.dyn_eps_inside}")

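                            # With dyn_zero_mean, subtract the outer product of the running
                            # mean so M becomes a covariance rather than a second moment.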
                            if self.dyn_zero_mean:
                                mean_f = self.cum_mean1.get()
                                M -= torch.ger(mean_f, mean_f)

                            w = self.compute_w_corr(M)

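                            # When dyn_noise is set, add the stored perturbation self.skew,
                            # decayed as 1/(niter+1) so it vanishes over training.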
                            if self.dyn_noise is not None:
                                w += self.skew / (niter + 1)

                elif self.predictor_reg == "directcopy":
                    if self.predictor_freq > 0 and niter % self.predictor_freq == 0:
                        M = self.cum_corr.get()
                        if M is not None:
                            if not self.predictor_signaling_2:
                                log.info(f"Set predictor to be input correlation. zero_mean={self.dyn_zero_mean}, freq={self.predictor_freq}, eps={self.dyn_eps}")

                            if self.dyn_zero_mean:
                                mean_f = self.cum_mean1.get()
                                M -= torch.ger(mean_f, mean_f)

                            w = M + self.dyn_eps * torch.eye(M.size(0), dtype=M.dtype, device=M.device)

                            if self.dyn_noise is not None:
                                w += self.skew / (niter + 1)

                elif self.predictor_reg == "minimal_space":
                    if self.predictor_freq > 0 and niter % self.predictor_freq == 0:
                        M = self.cum_corr.get()
                        M2 = self.cum_cross_corr.get()

                        if M is not None and M2 is not None:
                            # Initialize weight to contain eigenvectors that correspond to lowest few eigenvalues from the input data.
                            if not self.predictor_signaling_2:
                                log.info(f"Reinitialize predictor weight. freq={self.predictor_freq}, dyn: time={self.dyn_time}, zero_mean={self.dyn_zero_mean}, reg={self.dyn_reg}, sym={self.dyn_sym}, lambda={self.dyn_lambda}, make_psd={self.dyn_psd}, diagonalize={self.dyn_diagonalize}, before_bn={self.dyn_bn}")

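                            # With dyn_zero_mean, convert the second moments M and M2
                            # into (cross-)covariances before solving.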
                            if self.dyn_zero_mean:
                                mean_f = self.cum_mean1.get()
                                mean_f_ema = self.cum_mean2.get()

                                M -= torch.ger(mean_f, mean_f)
                                M2 -= torch.ger(mean_f_ema, mean_f)

                            if self.dyn_sym:
                                M2 += M2.t()
                                M2 /= 2

                            w = self.compute_w_minimal_space(M, M2, w)

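                            # Flush the accumulators so the next solve uses fresh statistics.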
                            self.cum_corr.reset()
                            self.cum_cross_corr.reset()
                            self.cum_mean1.reset()
                            self.cum_mean2.reset()
                    else:
                        if self.dyn_sym:
                            # Just make it symmetric
                            w += w.t()
                            w /= 2
                    
                elif self.predictor_reg == "onehalfeig":
                    # Run at epoch start, or every self.predictor_freq iterations.
                    if epoch_start or (self.predictor_freq > 0 and niter % self.predictor_freq == 0):
                        if not self.predictor_signaling_2 or epoch_start:
                            log.info(f"Reinit predictor weight with {self.predictor_reg}. epoch_start={epoch_start}, freq={self.predictor_freq}")
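                        # Build w = predictor_eig * P + predictor_eps * I, where P projects
                        # onto a random r-dimensional subspace (r = predictor_rank * d).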
                        d = w.size(0)
                        Q = torch.FloatTensor(random_orthonormal_matrix(dim=d)).to(w.device)
                        # Pick predictor_rank of the vectors.
                        w.fill_(0)
                        r = int(self.predictor_rank * d)
                        for i in range(r):
                            w += torch.ger(Q[:,i], Q[:,i])
                        w *= self.predictor_eig
                        w += torch.eye(d, dtype=w.dtype, device=w.device) * self.predictor_eps
                    else:
                        # Just make it symmetric
                        w += w.t()
                        w /= 2

                elif self.predictor_reg == "partition":
                    # Don't need to do anything. 
                    pass
                else:
                    raise RuntimeError(f"Unknown predictor_reg: {self.predictor_reg}")

                if torch.any(torch.isnan(w)).item() or torch.any(torch.isinf(w)).item():
                    import pdb
                    pdb.set_trace()

                linear_layers[0].weight.copy_(w)

        elif len(linear_layers) == 2:
            # Two layer without batch norm. Make it symmetric!
            assert self.predictor_reg in ["symmetric", "symmetric_row_norm"], f"predictor_reg: {self.predictor_reg} not valid!"  
            with torch.no_grad():
                # hidden_size x input_size
                w0 = linear_layers[0].weight.clone()
                w1 = linear_layers[1].weight.clone()
                w = (w0 + w1.t()) / 2
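                # Tying w1 = w0.t() below makes the two-layer predictor w.t() @ w,
                # i.e. symmetric positive semi-definite.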

                if self.predictor_reg == "symmetric_row_norm":
                    # Give each hidden unit's weight row unit norm.
                    w /= w.norm(dim=1, keepdim=True)
                    
                linear_layers[0].weight.copy_(w)
                linear_layers[1].weight.copy_(w.t())

        if self.predictor_wd is not None:
            if not self.predictor_signaling_2:
               log.info(f"Apply predictor weight decay: {self.predictor_wd}")

            with torch.no_grad():
                for l in linear_layers:
                    l.weight *= (1 - self.predictor_wd)
            
        self.predictor_signaling_2 = True
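
For reference, the "symmetric_norm" branch reduces to the projection sketched below on a toy 4x4 weight. This is a minimal standalone sketch, not the trainer's code path: torch.linalg.eigvals stands in for the deprecated torch.eig call above, and the toy tensor replaces the real predictor weight.

    import torch

    torch.manual_seed(0)
    w = torch.randn(4, 4)

    # Symmetrize, then rescale so the largest (real) eigenvalue is 1.
    w = (w + w.t()) / 2
    max_eigen = torch.linalg.eigvals(w).real.max()
    if max_eigen.abs() > 1e-2:  # skip the rescale when w is nearly singular
        w = w / max_eigen

    print(torch.linalg.eigvals(w).real.max())  # ~1.0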