def train()

in complex_shift_autoencoder.py
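Runs the full training loop for the complex shift autoencoder: the encoder, decoder, and the real/imaginary shift matrices `W_r` and `W_i` are optimized jointly with a single Adam optimizer, per-epoch training and validation losses are recorded, and the best model by validation MSE is checkpointed to `output_dir`.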


    def train(self, loss_func, learning_rate, n_epochs, log_frequency):
        """Jointly train the encoder, decoder, and complex shift matrices.

        `loss_func` takes a batch (x1, x2, angles) and returns a scalar loss.
        Losses are printed every `log_frequency` epochs, and the model with
        the best validation MSE is checkpointed.
        """
        self.encoder.train()
        self.decoder.train()
        # Optimize all four components jointly with one Adam optimizer.
        params = (
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.W_r.parameters())
            + list(self.W_i.parameters())
        )
        optimizer = torch.optim.Adam(params, lr=learning_rate)
        train_losses = torch.zeros(n_epochs)
        valid_losses = torch.zeros(n_epochs)
        best_mse = float("inf")  # avoids depending on an unimported numpy
        N_pairs = len(self.data.train_loader.dataset)

        for epoch in range(n_epochs):
            epoch_loss = 0.0
            for x1, x2, angles in self.data.train_loader:
                x1 = x1.to(device=self.device)
                x2 = x2.to(device=self.device)

                optimizer.zero_grad()
                loss = loss_func(x1, x2, angles)
                loss.backward()
                optimizer.step()
                # Weight by batch size so the epoch average is exact even
                # when the last batch is smaller.
                epoch_loss += loss.item() * x1.size(0)
            epoch_loss = epoch_loss / N_pairs
            if epoch % log_frequency == 0:
                print(f"Epoch {epoch} train loss: {epoch_loss:0.3e}")

            # Average validation loss over the held-out set.
            valid_mse = self.compute_mean_loss(loss_func, self.data.valid_loader).item()
            train_losses[epoch] = epoch_loss

            # Checkpoint whenever validation MSE improves on the best so far.
            if valid_mse < best_mse:
                self.update_state(mse=valid_mse, epoch=epoch)
                best_mse = valid_mse
                file_name = f"checkpoint_{self.save_name}.pth.tar"
                self.save_best_checkpoint(
                    out_dir=self.output_dir,
                    file_name=file_name,
                    optimizer_state_dict=optimizer.state_dict(),
                )

            print(f"Epoch {epoch} validation loss: {valid_mse:0.3e}")
            valid_losses[epoch] = valid_mse
        return train_losses.detach().numpy(), valid_losses.detach().numpy()
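
A minimal usage sketch. The loss below is a hypothetical placeholder: the only contract the loop imposes is that `loss_func(x1, x2, angles)` returns a scalar tensor, and `model(x1, angles)` stands in for however the class maps an input and a shift to a prediction of `x2`.

    import torch.nn.functional as F

    # Hypothetical reconstruction loss with the (x1, x2, angles) signature
    # the training loop expects; `model` is an assumed, already-constructed
    # ComplexShiftAutoencoder instance.
    def mse_shift_loss(x1, x2, angles):
        return F.mse_loss(model(x1, angles), x2)

    train_curve, valid_curve = model.train(
        loss_func=mse_shift_loss,
        learning_rate=1e-3,   # Adam step size
        n_epochs=50,
        log_frequency=5,      # print train/validation losses every 5 epochs
    )

Note the design choice of one Adam instance over all four parameter groups: it keeps the encoder, decoder, and shift matrices on a shared update schedule, whereas per-module learning rates would require separate parameter groups.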