in autoencoder.py [0:0]
def train(self, loss_func, stop_early=False, log_frequency=None):
    """Jointly train the encoder and decoder with Adam, then evaluate.

    Each epoch first calls ``self.log_train_val_loss(loss_func)`` and then
    iterates over ``self.data.train_loader``, taking one optimizer step per
    batch.

    After all epochs:
    - the train/validation losses are logged once more;
    - ``self.copy_models_validation`` is called with the validation loss;
    - the per-sample test loss (computed on ``self.data.test_loader_batch_1``)
      is stored in ``self.final_test_loss`` and printed.

    Args:
        loss_func: Callable ``loss_func(x1, x2, angles)`` returning a scalar
            loss tensor (``.backward()`` and ``.item()`` are called on it).
        stop_early: If True, return right after the first optimizer step,
            skipping the final validation/test evaluation.
        log_frequency: Print the mean running loss every ``log_frequency``
            batches. Defaults to ``self.set_log_frequency()``.

    Returns:
        None.
    """
    self.encoder.train().to(self.device)
    self.decoder.train().to(self.device)
    model_params = list(self.encoder.parameters()) + list(self.decoder.parameters())
    optimizer = torch.optim.Adam(model_params, lr=self.learning_rate)
    if log_frequency is None:
        log_frequency = self.set_log_frequency()

    for epoch in range(self.n_epochs):
        running_loss = 0.0
        print(f"Epoch {epoch}")
        self.log_train_val_loss(loss_func)
        # The evaluation above may leave the networks in eval mode
        # (NOTE(review): depends on log_train_val_loss — confirm). Re-enable
        # training mode so any dropout/batch-norm layers behave as in
        # training. This is harmless if the mode was never changed.
        self.encoder.train()
        self.decoder.train()
        # Per-batch metadata gets its own name so it does not shadow the
        # optimizer's parameter list above.
        for i, (x1, x2, batch_params) in enumerate(self.data.train_loader):
            print(f"Training batch {i}", end="\r")
            x1 = x1.to(device=self.device)
            x2 = x2.to(device=self.device)
            angles = self.get_angles(batch_params).to(device=self.device)

            optimizer.zero_grad()
            loss = loss_func(x1, x2, angles)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Report the mean loss over each full window of log_frequency batches.
            if i % log_frequency == (log_frequency - 1):
                print(f"Running loss: {running_loss / log_frequency:0.3e}")
                running_loss = 0.0
            if stop_early:
                return None

    _, valid_loss = self.log_train_val_loss(loss_func)
    self.copy_models_validation(valid_loss)
    # test loss per sample (using batch size 1)
    self.final_test_loss = self.compute_total_loss(
        self.data.test_loader_batch_1, loss_func
    )
    print(f"Test Loss: {self.final_test_loss:0.3e}")