def train()

in cci_variational_autoencoder.py
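
The loop below optimizes the controlled-capacity objective from the paper linked
in the docstring, restated here with symbols matching the code's beta and c (this
is a summary of what the code computes, not a quotation from the paper):

    loss = reconstruction_loss + beta * |KL(q(z|x) || p(z)) - C|

where the capacity C is annealed linearly from its starting value up to c_max over
the n_epochs of training.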


    def train(self, stop_early=False, log_frequency=None, track_losses=True):
        """Trains a controlled-capacity beta-VAE (CCI VAE).
        https://arxiv.org/abs/1804.03599

        The learning rate used in the paper is 5e-4.

        If stop_early is True, training stops after the first logged loss;
        this is useful for testing. If log_frequency is None, it is taken
        from self.set_log_frequency(). If track_losses is True, train and
        validation losses are logged each epoch and a final per-sample test
        loss is computed after training.
        """
        # move the model to the target device and put it in training mode
        self.model.to(self.device)
        self.model.train()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        # anneal the capacity C linearly from its initial value to c_max over n_epochs
        c_step_size = (self.c_max - self.c) / self.n_epochs

        if log_frequency is None:
            log_frequency = self.set_log_frequency()

        for epoch in range(self.n_epochs):
            print(f"Epoch {epoch}")
            if track_losses:
                self.log_train_val_loss()
            # reset running loss accumulators for this epoch
            running_loss = 0.0
            running_reconstruction_loss, running_kl_divergence = 0.0, 0.0
            # update controlled capacity parameter
            self.c += c_step_size
            for i, (x1, _, _) in enumerate(self.data.train_loader):
                x1 = x1.to(device=self.device)

                optimizer.zero_grad()
                reconstruction_loss, kl_divergence = self.compute_loss(x1)

                # CCI objective: reconstruction loss + beta * |KL divergence - capacity C|
                loss = reconstruction_loss + self.beta * (kl_divergence - self.c).abs()

                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                running_reconstruction_loss += (
                    reconstruction_loss.cpu().detach().numpy()
                )
                running_kl_divergence += kl_divergence.cpu().detach().numpy()

                if i % log_frequency == (log_frequency - 1):
                    normalized_loss = running_loss / log_frequency
                    normalized_reconstruction_loss = (
                        running_reconstruction_loss / log_frequency
                    )
                    normalized_kl_divergence = running_kl_divergence / log_frequency
                    print(f"Running Total Loss: {normalized_loss:0.3e}")
                    print(
                        f"Running Reconstruction Loss: {normalized_reconstruction_loss:0.3e}"
                        f" KL Divergence: {normalized_kl_divergence:0.3e}"
                    )
                    self.kl_losses.append(normalized_kl_divergence)
                    self.reconstruction_losses.append(normalized_reconstruction_loss)

                    running_loss = 0.0
                    running_reconstruction_loss = 0.0
                    running_kl_divergence = 0.0
                    if stop_early:
                        # stop after the first logged loss (useful for testing)
                        return None

        if track_losses:
            train_loss, valid_loss = self.log_train_val_loss()
            self.copy_models_validation(valid_loss)
            # compute test loss per sample
            self.final_test_loss = self.compute_total_loss(
                self.data.test_loader_batch_1
            )
            print(f"Test Loss: {self.final_test_loss:0.3e}")
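
A minimal usage sketch, assuming the enclosing class in cci_variational_autoencoder.py
is called CCIVariationalAutoencoder and that its constructor sets up self.data,
self.model, self.beta, self.c, self.c_max, self.n_epochs, and self.learning_rate;
the class name and the constructor call are assumptions for illustration, not taken
from the file:

    from cci_variational_autoencoder import CCIVariationalAutoencoder

    # hypothetical constructor; the real arguments depend on the class definition
    trainer = CCIVariationalAutoencoder()

    # log running losses every 100 mini-batches
    trainer.train(log_frequency=100)

    # per-sample test loss computed at the end of training (track_losses=True)
    print(trainer.final_test_loss)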