in cci_variational_autoencoder.py [0:0]
def train(self, stop_early=False, log_frequency=None, track_losses=True):
"""Trains controlled capacity beta vae (CCI VAE)
https://arxiv.org/abs/1804.03599
Learning rate used in the paper is 5e-4
If verbose is False, previous loss print is overridden
If stop_early is True, training stops after first logged loss.
This is useful for testing.
"""
    self.model.to(self.device)
    self.model.train()
    optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
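    # Linearly anneal the KL capacity C from its initial value up to
    # self.c_max over training; one step is applied per epoch below.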
    c_step_size = (self.c_max - self.c) / self.n_epochs
    if log_frequency is None:
        log_frequency = self.set_log_frequency()
    for epoch in range(self.n_epochs):
        print(f"Epoch {epoch}")
        if track_losses:
            self.log_train_val_loss()
        running_loss = 0.0
        running_reconstruction_loss, running_kl_divergence = 0.0, 0.0
        # update the controlled capacity parameter once per epoch
        self.c += c_step_size
        for i, (x1, _, _) in enumerate(self.data.train_loader):
            x1 = x1.to(device=self.device)
            optimizer.zero_grad()
            reconstruction_loss, kl_divergence = self.compute_loss(x1)
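            # CCI-VAE objective (Burgess et al., 2018): reconstruction loss
            # plus a beta-weighted penalty pulling the KL toward capacity C.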
            loss = reconstruction_loss + self.beta * (kl_divergence - self.c).abs()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            running_reconstruction_loss += reconstruction_loss.item()
            running_kl_divergence += kl_divergence.item()
            if i % log_frequency == (log_frequency - 1):
                normalized_loss = running_loss / log_frequency
                normalized_reconstruction_loss = (
                    running_reconstruction_loss / log_frequency
                )
                normalized_kl_divergence = running_kl_divergence / log_frequency
                print(f"Running Total Loss: {normalized_loss:0.3e}")
                print(
                    f"Running Reconstruction Loss: {normalized_reconstruction_loss:0.3e}"
                    f" KL Divergence: {normalized_kl_divergence:0.3e}"
                )
                self.kl_losses.append(normalized_kl_divergence)
                self.reconstruction_losses.append(normalized_reconstruction_loss)
                running_loss = 0.0
                running_reconstruction_loss = 0.0
                running_kl_divergence = 0.0
                if stop_early:
                    return None
    if track_losses:
        train_loss, valid_loss = self.log_train_val_loss()
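        # copy_models_validation presumably snapshots the model when the
        # validation loss improves (an assumption; it is defined elsewhere).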
        self.copy_models_validation(valid_loss)
        # compute the test loss per sample
        self.final_test_loss = self.compute_total_loss(
            self.data.test_loader_batch_1
        )
        print(f"Test Loss: {self.final_test_loss:0.3e}")