def __init__()

in ssl/real-dataset/byol_trainer.py [0:0]


    def __init__(self, log_dir, online_network, target_network, predictor, optimizer, predictor_optimizer, device, **params):
        """Store the trainer's collaborators, hyperparameters and running statistics.

        Args:
            log_dir: Passed to ``SummaryWriter`` to create ``self.writer``.
            online_network: Stored as ``self.online_network``.
            target_network: Stored as ``self.target_network``.
            predictor: Stored as ``self.predictor``.
            optimizer: Stored as ``self.optimizer``.
            predictor_optimizer: Stored as ``self.predictor_optimizer``.
            device: Stored as ``self.device``; the random tensors created
                here are moved to it.
            **params: Hyperparameters. The whole mapping is kept as
                ``self.params``; in addition, every name listed in
                ``hyperparameter_names`` below is copied onto an attribute
                of the same name.

        Raises:
            KeyError: If ``params`` lacks any name in ``hyperparameter_names``.
        """
        # Collaborating objects handed in by the caller.
        self.device = device
        self.online_network = online_network
        self.target_network = target_network
        self.predictor = predictor
        self.optimizer = optimizer
        self.predictor_optimizer = predictor_optimizer
        self.params = params
        self.writer = SummaryWriter(log_dir)

        # Every entry is required: it is read with params[name] and exposed
        # as self.<name>. Order is kept so a missing key surfaces the same way.
        hyperparameter_names = (
            # random-predictor settings
            "rand_pred_n_epoch", "rand_pred_n_iter", "rand_pred_reg",
            # general training settings
            "max_epochs", "m", "use_order_of_variance", "corr_eigen_decomp",
            "noise_blend", "save_per_epoch", "batch_size", "num_workers",
            "checkpoint_interval", "target_noise",
            # predictor settings
            "predictor_init", "predictor_reg", "predictor_eig",
            "predictor_freq", "predictor_rank", "predictor_eps",
            # "dyn_*" settings
            "dyn_time", "dyn_zero_mean", "dyn_reg", "dyn_noise", "dyn_lambda",
            "dyn_sym", "dyn_psd", "dyn_eps", "dyn_eps_inside", "dyn_bn",
            "dyn_convert", "dyn_diagonalize",
            # remaining settings
            "balance_type", "evaluator", "solve_direction", "corr_collect",
            "n_corr", "use_l2_normalization", "predictor_wd", "init_rand_pred",
        )
        for name in hyperparameter_names:
            setattr(self, name, params[name])

        # Hand the listed source/config files to the folder-creation helper.
        _create_model_training_folder(
            self.writer,
            files_to_same=[
                "./config/byol_config.yaml",
                "main.py",
                'byol_trainer.py',
                "./models/mlp_head.py",
            ],
        )

        self.predictor_signaling = False
        self.predictor_signaling_2 = False

        # Running statistics; all accumulators share the same dyn_lambda.
        decay = self.dyn_lambda
        self.cum_corr = Accumulator(dyn_lambda=decay)
        self.cum_cross_corr = Accumulator(dyn_lambda=decay)
        self.cum_mean1 = Accumulator(dyn_lambda=decay)
        self.cum_mean2 = Accumulator(dyn_lambda=decay)

        # NOTE(review): 128 is hard-coded here and below; presumably it is the
        # width of the features being correlated -- confirm against the model.
        if self.dyn_noise is not None:
            # Antisymmetric (A - A^T) random matrix scaled by dyn_noise.
            # NOTE(review): self.skew is left undefined when dyn_noise is None.
            gaussian = torch.randn(128, 128).to(device=device)
            self.skew = (gaussian - gaussian.t()) * self.dyn_noise

        if self.predictor_reg == "partition":
            n_parts = self.n_corr
            # One random 128-dimensional column per partition.
            self.partition_w = torch.randn(128, n_parts).to(device=device)
            # Separate "pos" / "neg" accumulators and counters per partition.
            self.cum_corrs_pos = [Accumulator(dyn_lambda=decay) for _ in range(n_parts)]
            self.cum_corrs_neg = [Accumulator(dyn_lambda=decay) for _ in range(n_parts)]

            self.counts_pos = [0 for _ in range(n_parts)]
            self.counts_neg = [0 for _ in range(n_parts)]