complex_shift_autoencoder.py [149:189]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        optimizer = torch.optim.Adam(params, lr=learning_rate)
        train_losses = torch.FloatTensor(n_epochs)
        valid_losses = torch.FloatTensor(n_epochs)
        best_mse = np.inf
        N_pairs = len(self.data.train_loader.dataset)

        for epoch in range(n_epochs):
            epoch_loss = 0
            for i, (x1, x2, angles) in enumerate(self.data.train_loader):
                x1 = x1.to(device=self.device)
                x2 = x2.to(device=self.device)

                optimizer.zero_grad()
                loss = loss_func(x1, x2, angles)

                loss.backward()
                optimizer.step()
                epoch_loss += loss.item() * x1.size(0)
            epoch_loss = epoch_loss / N_pairs
            print(f"Epoch {epoch} Train loss: {epoch_loss:0.3e}")

            valid_mse = (
                self.compute_mean_loss(loss_func, self.data.valid_loader)
                .detach()
                .item()
            )
            train_losses[epoch] = epoch_loss

            if valid_mse < best_mse:
                self.update_state(mse=valid_mse, epoch=epoch)
                best_mse = valid_mse
                file_name = "checkpoint_{}.pth.tar".format(self.save_name)
                self.save_best_checkpoint(
                    out_dir=self.output_dir,
                    file_name=file_name,
                    optimizer_state_dict=optimizer.state_dict(),
                )

            print(f"Epoch {epoch} validation loss: {valid_mse:0.3e}")
            valid_losses[epoch] = valid_mse
        return train_losses.detach().numpy(), valid_losses.detach().numpy()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



weakly_complex_shift_autoencoder.py [116:162]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        optimizer = torch.optim.Adam(params, lr=learning_rate)
        train_losses = torch.FloatTensor(n_epochs)
        valid_losses = torch.FloatTensor(n_epochs)
        best_mse = np.inf
        N_pairs = len(self.data.train_loader.dataset)
        for epoch in range(n_epochs):
            epoch_loss = 0
            for i, (x1, x2, angles) in enumerate(self.data.train_loader):
                x1 = x1.to(device=self.device)
                x2 = x2.to(device=self.device)

                optimizer.zero_grad()
                loss = loss_func(x1, x2, angles)

                loss.backward()
                optimizer.step()
                epoch_loss += loss.item() * x1.size(0)
            epoch_loss = epoch_loss / N_pairs
            print(f"Epoch {epoch} Train loss: {epoch_loss:0.3e}")

            valid_mse = (
                self.compute_mean_loss(loss_func, self.data.valid_loader)
                .detach()
                .item()
            )

            # train_mse = (
            #     self.compute_mean_loss(loss_func, self.data.train_loader)
            #     .detach()
            #     .item()
            # )
            # train_losses[epoch] = train_mse
            train_losses[epoch] = epoch_loss

            if valid_mse < best_mse:
                self.update_state(mse=valid_mse, epoch=epoch)
                best_mse = valid_mse
                file_name = "checkpoint_{}.pth.tar".format(self.save_name)
                self.save_best_checkpoint(
                    out_dir=self.output_dir,
                    file_name=file_name,
                    optimizer_state_dict=optimizer.state_dict(),
                )

            print(f"Epoch {epoch} validation loss: {valid_mse:0.3e}")
            valid_losses[epoch] = valid_mse
        return train_losses.detach().numpy(), valid_losses.detach().numpy()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



