def train()

in ssl/real-dataset/byol_trainer.py [0:0]


    def train(self, train_dataset):
        """Run the training loop over ``train_dataset`` and write checkpoints.

        For every batch this calls ``self.update`` to obtain the loss,
        back-propagates it, applies ``self.online_network.projetion.adjust_grad``,
        steps both ``self.optimizer`` and ``self.predictor_optimizer``, calls
        ``self.regulate_predictor`` and finally updates the target network via
        ``self._update_target_network_parameters``.

        The predictor may be re-initialised:
          * once before training, via ``random_init_predictor`` when
            ``self.rand_pred_n_epoch`` > 0, and/or via ``restart_pred_save``
            when ``self.init_rand_pred`` is truthy;
          * at the start of every epoch divisible by ``self.rand_pred_n_epoch``;
          * before every iteration divisible by ``self.rand_pred_n_iter``
            (this includes iteration 0), saving the fresh predictor weights.

        After each epoch the mean loss is logged, ``self.evaluator`` (if any)
        is run on a deep copy of the online network, and every
        ``self.save_per_epoch`` epochs a checkpoint is saved together with any
        cumulated correlation statistics. All files go to
        ``<self.writer.log_dir>/checkpoints``: ``model_000.pth`` before
        training, ``model_<epoch>.pth`` periodically, ``model_final.pth`` at the
        end.

        Args:
            train_dataset: dataset whose batches unpack as
                ``((view_1, view_2, _), _)``; only the two views are used here
                (presumably two augmentations of the same images - confirm).
        """
        # self.batch_size is scaled by the number of visible GPUs. On a CPU-only
        # machine torch.cuda.device_count() is 0, which would hand DataLoader an
        # invalid batch_size of 0, so count at least one device.
        n_devices = max(1, torch.cuda.device_count())
        # NOTE(review): with drop_last=False the final batch can have a single
        # sample; BatchNorm layers in training mode reject that - confirm the
        # dataset size / batch size combination makes this impossible.
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size * n_devices,
                                  num_workers=self.num_workers, drop_last=False, shuffle=True)

        niter = 0
        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')
        # torch.save (used directly below) does not create missing directories.
        os.makedirs(model_checkpoints_folder, exist_ok=True)

        self.get_pred_linear_layers()
        if self.rand_pred_n_epoch is not None and self.rand_pred_n_epoch > 0:
            self.random_init_predictor(self.predictor)

        if self.init_rand_pred:
            log.info("init_rand_pred=True, reinit the predictor before training starts.")
            self.restart_pred_save(model_checkpoints_folder, "000", 0)

        # Add another BN right before predictor (and right after the target network.)
        # These layers are only created here; presumably they are consumed inside
        # self.update - confirm.
        self.bn_before_online = nn.BatchNorm1d(self.linear_layers[0].weight.size(1), affine=False).to(self.device)
        self.bn_before_target = nn.BatchNorm1d(self.linear_layers[0].weight.size(1), affine=False).to(self.device)

        # self.initializes_target_network()
        # Save initial network for analysis
        self.save_model(os.path.join(model_checkpoints_folder, 'model_000.pth'))

        for epoch_counter in range(1, 1 + self.max_epochs):
            loss_record = []
            suffix = str(epoch_counter).zfill(3)

            if self.rand_pred_n_epoch is not None and self.rand_pred_n_epoch > 0 and epoch_counter % self.rand_pred_n_epoch == 0:
                self.restart_pred_save(model_checkpoints_folder, suffix, niter)

            for (batch_view_1, batch_view_2, _), _ in train_loader:
                # Fires at niter == 0 as well, i.e. right before the very first step.
                if self.rand_pred_n_iter is not None and self.rand_pred_n_iter > 0 and niter % self.rand_pred_n_iter == 0:
                    self.restart_predictor(False, niter)
                    predictor_path = os.path.join(model_checkpoints_folder, f'reset_predictor_{suffix}_iter{niter}.pth')
                    torch.save({'predictor_state_dict': self.predictor.state_dict()}, predictor_path)

                batch_view_1 = batch_view_1.to(self.device)
                batch_view_2 = batch_view_2.to(self.device)

                loss = self.update(batch_view_1, batch_view_2)
                # Convert to a Python float once; reused for TensorBoard and the epoch mean.
                loss_value = loss.item()
                self.writer.add_scalar('loss', loss_value, global_step=niter)

                self.optimizer.zero_grad()
                self.predictor_optimizer.zero_grad()
                loss.backward()
                # Add additional grad for regularization, if there is any.
                self.online_network.projetion.adjust_grad()
                self.optimizer.step()
                self.predictor_optimizer.step()

                # self.online_network.projetion.normalize()

                self.regulate_predictor(self.predictor, niter, epoch_start=False)

                self._update_target_network_parameters()  # update the key encoder
                loss_record.append(loss_value)
                niter += 1

            # Reset the signal so that we can print out some statistics.
            self.predictor_signaling_2 = False

            # Avoid numpy's "mean of empty slice" warning when the loader was empty.
            mean_loss = np.mean(loss_record) if loss_record else float('nan')
            log.info(f"Epoch {epoch_counter}: numIter: {niter} Loss: {mean_loss}")
            if self.evaluator is not None:
                # Evaluate a copy so evaluation cannot mutate the network being trained.
                best_acc = self.evaluator.eval_model(deepcopy(self.online_network))
                log.info(f"Epoch {epoch_counter}: best_acc: {best_acc}")

            stats = self.online_network.projetion.get_stats()
            if stats is not None:
                log.info(f"New normalization stats: {stats}")

            if epoch_counter % self.save_per_epoch == 0:
                # save checkpoints
                self.save_model(os.path.join(model_checkpoints_folder, f'model_{suffix}.pth'))
                if self.cum_corr.cumulated is not None:
                    # Save the cumulated correlation / mean statistics alongside the checkpoint.
                    corr_path = os.path.join(model_checkpoints_folder, f'corr_{suffix}_iter{niter}.pth')
                    torch.save(
                        {
                            "cumulated_corr": self.cum_corr.cumulated,
                            "cumulated_corr_counter": self.cum_corr.counter,
                            "cumulated_cross_corr": self.cum_cross_corr.cumulated,
                            "cumulated_mean": self.cum_mean1.cumulated,
                            "cumulated_mean_ema": self.cum_mean2.cumulated
                        },
                        corr_path
                    )

        # save checkpoints
        self.save_model(os.path.join(model_checkpoints_folder, 'model_final.pth'))