in ssl/real-dataset/byol_trainer.py [0:0]
def __init__(self, log_dir, online_network, target_network, predictor, optimizer, predictor_optimizer, device, **params):
    """Store the models/optimizers, copy the config onto ``self`` and build accumulators.

    Args:
        log_dir: directory passed to ``SummaryWriter`` for this run.
        online_network: stored as ``self.online_network``.
        target_network: stored as ``self.target_network``.
        predictor: stored as ``self.predictor``.
        optimizer: stored as ``self.optimizer``.
        predictor_optimizer: stored as ``self.predictor_optimizer``.
        device: stored as ``self.device``; the random matrices created below
            are moved to it.
        **params: configuration dict, kept as ``self.params``. Every key in
            ``required_keys`` must be present and is copied onto ``self``
            under the identical attribute name.

    Raises:
        KeyError: if a required config key is missing. The writer has already
            been created at that point.
    """
    self.online_network, self.target_network = online_network, target_network
    self.predictor = predictor
    self.optimizer, self.predictor_optimizer = optimizer, predictor_optimizer
    self.device = device
    self.params = params
    self.writer = SummaryWriter(log_dir)

    # Each entry becomes ``self.<key> = params[key]``; grep for the key name to
    # find its uses. The order determines which KeyError is raised first.
    required_keys = (
        "rand_pred_n_epoch", "rand_pred_n_iter", "rand_pred_reg",
        "max_epochs", "m",
        "use_order_of_variance", "corr_eigen_decomp", "noise_blend",
        "save_per_epoch", "batch_size", "num_workers", "checkpoint_interval",
        "target_noise",
        "predictor_init", "predictor_reg", "predictor_eig", "predictor_freq",
        "predictor_rank", "predictor_eps",
        "dyn_time", "dyn_zero_mean", "dyn_reg", "dyn_noise", "dyn_lambda",
        "dyn_sym", "dyn_psd", "dyn_eps", "dyn_eps_inside", "dyn_bn",
        "dyn_convert", "dyn_diagonalize",
        "balance_type", "evaluator", "solve_direction", "corr_collect",
        "n_corr", "use_l2_normalization", "predictor_wd", "init_rand_pred",
    )
    for key in required_keys:
        setattr(self, key, params[key])

    _create_model_training_folder(
        self.writer,
        files_to_same=["./config/byol_config.yaml", "main.py", "byol_trainer.py", "./models/mlp_head.py"],
    )

    self.predictor_signaling = False
    self.predictor_signaling_2 = False

    def new_accumulator():
        # Every accumulator in this trainer shares the configured dyn_lambda.
        return Accumulator(dyn_lambda=self.dyn_lambda)

    self.cum_corr = new_accumulator()
    self.cum_cross_corr = new_accumulator()
    self.cum_mean1 = new_accumulator()
    self.cum_mean2 = new_accumulator()

    # NOTE(review): hard-coded width of the random matrices below; presumably
    # the projection/feature dimension — confirm against the model config.
    dim = 128

    if self.dyn_noise is not None:
        noise = torch.randn(dim, dim).to(device=device)
        # (A - A^T) is antisymmetric; it is then scaled by dyn_noise.
        self.skew = (noise - noise.t()) * self.dyn_noise

    if self.predictor_reg == "partition":
        # Random (dim x n_corr) weights defining the random partition.
        self.partition_w = torch.randn(dim, self.n_corr).to(device=device)
        # Per-partition accumulators and counters; the pos/neg split is
        # presumably decided elsewhere via partition_w — verify at use site.
        self.cum_corrs_pos = [new_accumulator() for _ in range(self.n_corr)]
        self.cum_corrs_neg = [new_accumulator() for _ in range(self.n_corr)]
        self.counts_pos = [0] * self.n_corr
        self.counts_neg = [0] * self.n_corr