# ssl/real-dataset/byol_trainer.py
def train(self, train_dataset):
    """Run the full BYOL training loop over `train_dataset`.

    Per epoch: iterate two augmented views per sample, compute the BYOL
    loss via `self.update`, step both the main and the predictor
    optimizers, EMA-update the target network, and log/checkpoint.
    The predictor may be randomly re-initialized on an epoch schedule
    (`rand_pred_n_epoch`) and/or an iteration schedule (`rand_pred_n_iter`).

    Args:
        train_dataset: dataset yielding ((view1, view2, extra), label)
            tuples, consumed through a shuffling DataLoader.

    Side effects: writes scalars to `self.writer`, saves model /
    predictor / correlation-statistics checkpoints under
    `<log_dir>/checkpoints`, and mutates optimizer and network state.
    """
    loader = DataLoader(
        train_dataset,
        batch_size=self.batch_size * torch.cuda.device_count(),
        num_workers=self.num_workers,
        drop_last=False,
        shuffle=True,
    )

    step = 0
    ckpt_dir = os.path.join(self.writer.log_dir, 'checkpoints')

    self.get_pred_linear_layers()

    if self.rand_pred_n_epoch is not None and self.rand_pred_n_epoch > 0:
        self.random_init_predictor(self.predictor)

    if self.init_rand_pred:
        log.info("init_rand_pred=True, reinit the predictor before training starts.")
        self.restart_pred_save(ckpt_dir, "000", 0)

    # Extra affine-free BatchNorm placed right before the predictor input
    # (and right after the target network output).
    feat_dim = self.linear_layers[0].weight.size(1)
    self.bn_before_online = nn.BatchNorm1d(feat_dim, affine=False).to(self.device)
    self.bn_before_target = nn.BatchNorm1d(feat_dim, affine=False).to(self.device)

    # self.initializes_target_network()

    # Snapshot the untrained network for later analysis.
    self.save_model(os.path.join(ckpt_dir, 'model_000.pth'))

    for epoch in range(1, 1 + self.max_epochs):
        epoch_losses = []
        tag = str(epoch).zfill(3)

        # Scheduled predictor restart at epoch granularity.
        if self.rand_pred_n_epoch is not None and self.rand_pred_n_epoch > 0 and epoch % self.rand_pred_n_epoch == 0:
            self.restart_pred_save(ckpt_dir, tag, step)

        for (view_a, view_b, _), _ in loader:
            # Scheduled predictor restart at iteration granularity;
            # the freshly reset predictor is saved for analysis.
            if self.rand_pred_n_iter is not None and self.rand_pred_n_iter > 0 and step % self.rand_pred_n_iter == 0:
                self.restart_predictor(False, step)
                reset_path = os.path.join(ckpt_dir, f'reset_predictor_{tag}_iter{step}.pth')
                torch.save({'predictor_state_dict': self.predictor.state_dict()}, reset_path)

            view_a = view_a.to(self.device)
            view_b = view_b.to(self.device)

            loss = self.update(view_a, view_b)
            self.writer.add_scalar('loss', loss, global_step=step)

            self.optimizer.zero_grad()
            self.predictor_optimizer.zero_grad()
            loss.backward()
            # Inject any extra regularization gradient before stepping.
            # NOTE(review): 'projetion' looks like a typo of 'projection',
            # but it must match the attribute name defined on the network
            # class — do not rename here alone.
            self.online_network.projetion.adjust_grad()
            self.optimizer.step()
            self.predictor_optimizer.step()
            # self.online_network.projetion.normalize()
            self.regulate_predictor(self.predictor, step, epoch_start=False)

            self._update_target_network_parameters()  # update the key encoder
            epoch_losses.append(loss.item())
            step += 1

        # Reset the signal so that we can print out some statistics.
        self.predictor_signaling_2 = False

        log.info(f"Epoch {epoch}: numIter: {step} Loss: {np.mean(epoch_losses)}")

        if self.evaluator is not None:
            # Evaluate a copy so the eval pass cannot perturb training state.
            best_acc = self.evaluator.eval_model(deepcopy(self.online_network))
            log.info(f"Epoch {epoch}: best_acc: {best_acc}")

        stats = self.online_network.projetion.get_stats()
        if stats is not None:
            log.info(f"New normalization stats: {stats}")

        if epoch % self.save_per_epoch == 0:
            # Periodic checkpoint.
            self.save_model(os.path.join(ckpt_dir, f'model_{tag}.pth'))

            if self.cum_corr.cumulated is not None:
                # Dump the accumulated correlation/mean statistics
                # alongside the checkpoint.
                corr_path = os.path.join(ckpt_dir, f'corr_{tag}_iter{step}.pth')
                torch.save(
                    {
                        "cumulated_corr": self.cum_corr.cumulated,
                        "cumulated_corr_counter": self.cum_corr.counter,
                        "cumulated_cross_corr": self.cum_cross_corr.cumulated,
                        "cumulated_mean": self.cum_mean1.cumulated,
                        "cumulated_mean_ema": self.cum_mean2.cumulated
                    },
                    corr_path
                )

    # Final snapshot after the last epoch.
    self.save_model(os.path.join(ckpt_dir, 'model_final.pth'))