def train()

in autoencoder.py

    def train(self, loss_func, stop_early=False, log_frequency=None):
        # Put both halves of the autoencoder in training mode and move them
        # to the target device (.train() returns the module, so chaining works).
        self.encoder.train().to(self.device)
        self.decoder.train().to(self.device)

        # Optimize the encoder and decoder parameters jointly. Named
        # model_params so the dataset's transformation params in the batch
        # loop below don't shadow it.
        model_params = list(self.encoder.parameters()) + list(self.decoder.parameters())
        optimizer = torch.optim.Adam(model_params, lr=self.learning_rate)

        # Fall back to the class's default logging interval when none is given.
        if log_frequency is None:
            log_frequency = self.set_log_frequency()

        for epoch in range(self.n_epochs):

            running_loss = 0.0
            print(f"Epoch {epoch}")
            # Log train/validation loss at the start of every epoch.
            self.log_train_val_loss(loss_func)
            for i, (x1, x2, params) in enumerate(self.data.train_loader):
                print(f"Training batch {i}", end="\r")
                # Move the image pair and its transformation angles to the device.
                x1 = x1.to(device=self.device)
                x2 = x2.to(device=self.device)
                angles = self.get_angles(params).to(device=self.device)

                # Standard update: reset gradients, compute the batch loss,
                # backpropagate, and step the optimizer.
                optimizer.zero_grad()
                loss = loss_func(x1, x2, angles)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                # Report the mean loss over the last log_frequency batches.
                if i % log_frequency == (log_frequency - 1):
                    print(f"Running loss: {running_loss / log_frequency:0.3e}")
                    running_loss = 0.0
                    if stop_early:
                        # Abort after the first logged interval
                        # (useful as a quick smoke test).
                        return None
        # Record final losses and snapshot the models based on validation loss.
        train_loss, valid_loss = self.log_train_val_loss(loss_func)
        self.copy_models_validation(valid_loss)
        # test loss per sample (using batch size 1)
        self.final_test_loss = self.compute_total_loss(
            self.data.test_loader_batch_1, loss_func
        )
        print(f"Test Loss: {self.final_test_loss:0.3e}")