in complex_shift_autoencoder.py [0:0]
def train(self, loss_func, learning_rate, n_epochs, log_frequency):
    """Jointly train encoder, decoder, W_r and W_i with Adam.

    After every epoch the validation loss is computed with
    ``self.compute_mean_loss``. Whenever it improves on the best value seen
    so far, ``self.update_state`` is called and a checkpoint (including the
    optimizer state) is written through ``self.save_best_checkpoint``.

    Args:
        loss_func: Callable ``loss_func(x1, x2, angles)`` returning a scalar
            loss tensor for one batch. It is weighted by batch size when
            averaging, so it is presumably a per-sample mean — TODO confirm.
        learning_rate: Adam learning rate.
        n_epochs: Number of passes over ``self.data.train_loader``.
        log_frequency: Currently unused; losses are printed every epoch.

    Returns:
        Tuple ``(train_losses, valid_losses)`` of float32 numpy arrays of
        length ``n_epochs``: the per-epoch mean training loss and the
        per-epoch validation loss.
    """
    params = (
        list(self.encoder.parameters())
        + list(self.decoder.parameters())
        + list(self.W_r.parameters())
        + list(self.W_i.parameters())
    )
    optimizer = torch.optim.Adam(params, lr=learning_rate)
    # Explicit zeros rather than the legacy, uninitialised FloatTensor(n).
    train_losses = torch.zeros(n_epochs, dtype=torch.float32)
    valid_losses = torch.zeros(n_epochs, dtype=torch.float32)
    best_mse = np.inf
    for epoch in range(n_epochs):
        # Re-enter training mode every epoch: the validation step below may
        # switch the modules to eval mode, which would otherwise leak into all
        # later epochs (matters for dropout / batch-norm layers, if any).
        self.encoder.train()
        self.decoder.train()
        epoch_loss = 0.0
        n_seen = 0
        for x1, x2, angles in self.data.train_loader:
            x1 = x1.to(device=self.device)
            x2 = x2.to(device=self.device)
            # NOTE(review): angles is passed through without .to(device);
            # presumably loss_func handles its placement — confirm.
            optimizer.zero_grad()
            loss = loss_func(x1, x2, angles)
            loss.backward()
            optimizer.step()
            batch_size = x1.size(0)
            epoch_loss += loss.item() * batch_size
            n_seen += batch_size
        # Average over the samples actually processed, not len(dataset), so the
        # mean stays correct with drop_last=True or a subset sampler; an empty
        # loader yields NaN instead of a ZeroDivisionError.
        epoch_loss = epoch_loss / n_seen if n_seen else float("nan")
        print(f"Epoch {epoch} Train loss: {epoch_loss:0.3e}")
        valid_mse = (
            self.compute_mean_loss(loss_func, self.data.valid_loader)
            .detach()
            .item()
        )
        train_losses[epoch] = epoch_loss
        if valid_mse < best_mse:
            self.update_state(mse=valid_mse, epoch=epoch)
            best_mse = valid_mse
            file_name = "checkpoint_{}.pth.tar".format(self.save_name)
            self.save_best_checkpoint(
                out_dir=self.output_dir,
                file_name=file_name,
                optimizer_state_dict=optimizer.state_dict(),
            )
        print(f"Epoch {epoch} validation loss: {valid_mse:0.3e}")
        valid_losses[epoch] = valid_mse
    return train_losses.detach().numpy(), valid_losses.detach().numpy()