def _train()

in crop_yield_prediction/models/deep_gaussian_process/base.py [0:0]

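A minimal sketch of the module-level imports this method assumes, inferred from the names used in the body below; the exact import path for the l1_l2_loss helper is an assumption, not necessarily the repository's actual import:

    from collections import defaultdict
    from copy import deepcopy

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # hypothetical import path: l1_l2_loss is defined elsewhere in the
    # deep_gaussian_process package
    from .loss import l1_l2_loss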

    def _train(self, train_images, train_yields, val_images, val_yields, train_steps,
               batch_size, starter_learning_rate, weight_decay, l1_weight, patience):
        """Defines the training loop for a model
        """

        train_dataset, val_dataset = TensorDataset(train_images, train_yields), TensorDataset(val_images, val_yields)

        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

        optimizer = torch.optim.Adam(self.model.parameters(),
                                     lr=starter_learning_rate,
                                     weight_decay=weight_decay)

        # convert the requested number of optimizer steps into whole epochs
        num_epochs = int(train_steps / (train_images.shape[0] / batch_size))
        print(f'Training for {num_epochs} epochs')

        train_scores = defaultdict(list)
        val_scores = defaultdict(list)

        step_number = 0
        min_loss = np.inf
        # store a copy of the weights (not a reference), so the best state
        # isn't silently overwritten by later optimizer updates
        best_state = deepcopy(self.model.state_dict())

        if patience is not None:
            epochs_without_improvement = 0

        for epoch in range(num_epochs):
            self.model.train()

            # running train and val scores are only for printing out
            # information
            running_train_scores = defaultdict(list)

            for train_x, train_y in train_dataloader:
                optimizer.zero_grad()
                pred_y = self.model(train_x)

                loss, running_train_scores = l1_l2_loss(pred_y, train_y, l1_weight,
                                                        running_train_scores)
                loss.backward()
                optimizer.step()

                train_scores['loss'].append(loss.item())

                step_number += 1

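                # decay the learning rate by a factor of 10 at fixed step milestones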
                if step_number in [4000, 20000]:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] /= 10

            train_output_strings = []
            for key, val in running_train_scores.items():
                train_output_strings.append('{}: {}'.format(key, round(np.array(val).mean(), 5)))

            running_val_scores = defaultdict(list)
            self.model.eval()
            with torch.no_grad():
                for val_x, val_y in val_dataloader:
                    val_pred_y = self.model(val_x)

                    val_loss, running_val_scores = l1_l2_loss(val_pred_y, val_y, l1_weight,
                                                              running_val_scores)

                    val_scores['loss'].append(val_loss.item())

            val_output_strings = []
            for key, val in running_val_scores.items():
                val_output_strings.append('{}: {}'.format(key, round(np.array(val).mean(), 5)))

            print('TRAINING: {}'.format(', '.join(train_output_strings)))
            print('VALIDATION: {}'.format(', '.join(val_output_strings)))

            epoch_val_loss = np.array(running_val_scores['loss']).mean()

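            # snapshot the weights when the validation loss improves; otherwise
            # count the epoch towards the early-stopping patience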
            if epoch_val_loss < min_loss:
                best_state = deepcopy(self.model.state_dict())
                min_loss = epoch_val_loss

                if patience is not None:
                    epochs_without_improvement = 0
            elif patience is not None:
                epochs_without_improvement += 1

                if epochs_without_improvement == patience:
                    # revert to the best state dict
                    self.model.load_state_dict(best_state)
                    print('Early stopping!')
                    break

        self.model.load_state_dict(best_state)
        return train_scores, val_scores
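
The loop delegates the loss computation to l1_l2_loss, whose interface can be inferred from the call sites: it takes predictions, targets, the L1 weight, and a running-scores defaultdict, and returns the loss tensor together with the updated dict (it must append under the 'loss' key, since the epoch-level early-stopping check reads running_val_scores['loss']). A hedged sketch of that interface, under those assumptions rather than the repository's actual implementation:

    import torch
    import torch.nn.functional as F

    def l1_l2_loss(pred, true, l1_weight, scores_dict):
        """MSE loss with an optional weighted L1 term; logs each component
        into scores_dict and returns (loss, scores_dict)."""
        loss = F.mse_loss(pred, true)
        scores_dict['l2'].append(loss.item())

        if l1_weight > 0:
            l1 = F.l1_loss(pred, true)
            loss = loss + l1_weight * l1
            scores_dict['l1'].append(l1.item())

        # the training loop reads scores_dict['loss'] for early stopping
        scores_dict['loss'].append(loss.item())
        return loss, scores_dict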