def train()

in ssl/real-dataset/simclr_trainer.py [0:0]


    def train(self, train_dataset):
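        # The effective batch size scales with the number of visible GPUs.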
        train_loader = DataLoader(train_dataset, batch_size=self.params["batch_size"] * torch.cuda.device_count(),
                                  num_workers=self.params["num_workers"], drop_last=True, shuffle=False)

        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')
        if not os.path.exists(model_checkpoints_folder):
            os.mkdir(model_checkpoints_folder)

        self.save_model(os.path.join(model_checkpoints_folder, 'model_000.pth'))

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=len(train_loader), eta_min=0,
                                                               last_epoch=-1)
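        # Note: T_max covers one epoch's worth of iterations, but the scheduler is
        # only stepped once per epoch (after the 10-epoch hold below).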

        margin = self.params["grad_combination_margin"]
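        # When a margin is given, each parameter gets a per-layer ratio that controls
        # how its gradients from the inter-sample and intra-sample losses are blended below.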
        if margin is not None:
            matcher = re.compile(r"encoder.(\d+)")
            layers = dict()
            for name, _ in self.model.named_parameters():
                # print(f"{name}: {params.size()}")
                m = matcher.match(name)
                if m is None:
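                    # Parameters outside `encoder.<idx>` (e.g. the projection head) get a
                    # sentinel index of 10 so they sort after the encoder layers
                    # (assumes encoder indices stay below 10).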
                    l = 10
                else:
                    l = int(m.group(1))
                layers[name] = l
            unique_entries = sorted(list(set(layers.values())))
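            # Evenly spaced ratios from `margin` (lowest encoder layer) to `1 - margin`
            # (topmost / non-encoder parameters).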
            series = np.linspace(margin, 1 - margin, len(unique_entries))
            l2ratio = dict(zip(unique_entries, series))
            layer2ratio = { name : l2ratio[l] for name, l in layers.items() }

            log.info(f"Gradient margin: {margin}")
            for name, r in layer2ratio.items():
                log.info(f"  {name}: {r}")
        else:
            log.info("No gradient margin")

        n_iter = 0
        alpha = self.params["noise_blend"]

        for epoch_counter in range(self.params['max_epochs']):
            loss_record = []
            suffix = str(epoch_counter).zfill(3)

            # Blend a fraction `alpha` of freshly Xavier-initialized weights into every
            # weight matrix at the start of each epoch.
            if alpha > 0:
                for name, p in self.model.named_parameters():
                    with torch.no_grad():
                        if len(p.size()) < 2:
                            continue
                        w = torch.zeros_like(p)  # zeros_like already matches p's device and dtype
                        torch.nn.init.xavier_uniform_(w)
                        p[:] = (1 - alpha) * p[:] + alpha * w

            for (xis, xjs, xs), _ in train_loader:
                xis = xis.to(self.device)
                xjs = xjs.to(self.device)

                # Only move the un-augmented images to the device when the criterion needs them.
                if self.nt_xent_criterion.need_unaug_data():
                    xs = xs.to(self.device)
                else:
                    xs = None

                loss, loss_intra = self._step(self.model, xis, xjs, xs, n_iter)

                # if n_iter % self.params['log_every_n_steps'] == 0:
                #     self.writer.add_scalar('train_loss', loss, global_step=n_iter)

                all_loss = loss + loss_intra
                loss_record.append(all_loss.item())

                if margin is not None:
                    # Backpropagate the two losses separately and blend their per-parameter
                    # gradients with a layer-dependent ratio.
                    self.optimizer.zero_grad()
                    loss.backward(retain_graph=True)

                    inter_grads = dict()
                    for name, p in self.model.named_parameters():
                        # print(f"{name}: {p.size()}")
                        inter_grads[name] = p.grad.clone()

                    self.optimizer.zero_grad()
                    loss_intra.backward()
                    for name, p in self.model.named_parameters():
                        r = layer2ratio[name]
                        # Lower layer -> high ratio of loss_intra
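                        # Net effect: p.grad = (1 - r) * grad(loss_intra) + r * grad(loss).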
                        p.grad *= (1 - r)
                        p.grad += inter_grads[name] * r
                else:
                    self.optimizer.zero_grad()
                    all_loss.backward()

                self.optimizer.step()
                n_iter += 1

            # Hold the learning rate constant for the first 10 epochs, then apply cosine decay.
            if epoch_counter >= 10:
                scheduler.step()
            self.writer.add_scalar('cosine_lr_decay', scheduler.get_last_lr()[0], global_step=n_iter)

            log.info(f"Epoch {epoch_counter}: numIter: {n_iter} Loss: {np.mean(loss_record)}")
            if self.evaluator is not None:
                best_acc = self.evaluator.eval_model(deepcopy(self.model))
                log.info(f"Epoch {epoch_counter}: best_acc: {best_acc}")

            if epoch_counter % self.params["save_per_epoch"] == 0:
                # save checkpoints
                self.save_model(os.path.join(model_checkpoints_folder, f'model_{suffix}.pth'))
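
For reference, below is a minimal, self-contained sketch of the per-layer gradient blending that runs when `grad_combination_margin` is set. The toy model, parameter names, and the two stand-in losses are illustrative assumptions, not code from the repository; only the ratio construction and the double backward pass mirror the method above.

    import re

    import numpy as np
    import torch
    import torch.nn as nn

    # Toy stand-in for the trainer's model: an "encoder" Sequential plus a head.
    model = nn.Sequential()
    model.add_module("encoder", nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8)))
    model.add_module("head", nn.Linear(8, 4))

    margin = 0.2
    matcher = re.compile(r"encoder\.(\d+)")

    # Map every parameter to the index of the encoder layer it belongs to;
    # anything outside the encoder gets a sentinel index that sorts last.
    layers = {}
    for name, _ in model.named_parameters():
        m = matcher.match(name)
        layers[name] = int(m.group(1)) if m is not None else 10

    unique_entries = sorted(set(layers.values()))
    series = np.linspace(margin, 1 - margin, len(unique_entries))
    l2ratio = dict(zip(unique_entries, series))
    layer2ratio = {name: l2ratio[l] for name, l in layers.items()}

    # Two stand-in losses playing the roles of `loss` (inter-sample) and `loss_intra`.
    out = model(torch.randn(16, 8))
    loss_inter = out.pow(2).mean()
    loss_intra = out.abs().mean()

    # Backward pass 1: cache per-parameter gradients of the inter-sample loss.
    model.zero_grad()
    loss_inter.backward(retain_graph=True)
    inter_grads = {name: p.grad.clone() for name, p in model.named_parameters()}

    # Backward pass 2: gradients of the intra-sample loss, then blend per layer so that
    # p.grad = (1 - r) * grad(loss_intra) + r * grad(loss_inter).
    model.zero_grad()
    loss_intra.backward()
    for name, p in model.named_parameters():
        r = layer2ratio[name]
        p.grad *= (1 - r)
        p.grad += inter_grads[name] * r

    # An optimizer.step() would now apply the blended gradients.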