def _train_net()

in tabular/src/autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py


    def _train_net(self,
                   train_dataset,
                   loss_kwargs,
                   batch_size,
                   num_epochs,
                   epochs_wo_improve,
                   val_dataset=None,
                   time_limit=None,
                   reporter=None,
                   verbosity=2):
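        """Train the neural network on `train_dataset`, with optional early stopping on `val_dataset`.

        Runs up to `num_epochs` epochs of minibatch updates. When a validation set is
        provided, the best model so far (according to `self.stopping_metric`) is
        checkpointed to a temporary file and restored at the end; training stops early
        if the validation score has not strictly improved for `epochs_wo_improve`
        epochs. Training also stops early once `time_limit` (seconds) is exhausted.
        """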
        import torch
        start_time = time.time()
        logging.debug("initializing neural network...")
        self.model.init_params()
        logging.debug("initialized")
        train_dataloader = train_dataset.build_loader(batch_size, self.num_dataloading_workers, is_test=False)

        loss_function = loss_kwargs.get('loss_function', 'auto')
        if isinstance(loss_function, str) and loss_function == 'auto':
            loss_kwargs['loss_function'] = self._get_default_loss_function()

        if val_dataset is not None:
            y_val = val_dataset.get_labels()
            if y_val.ndim == 2 and y_val.shape[1] == 1:
                y_val = y_val.flatten()
        else:
            y_val = None

        verbose_eval = verbosity > 1

        logger.log(15, "Neural network architecture:")
        logger.log(15, str(self.model))

        net_filename = self.path + self.temp_file_name  # temporary checkpoint path for the best model seen so far
        if num_epochs == 0:
            # run a dummy training loop that stops after a single batch;
            # useful when the NN is only needed for data preprocessing / debugging
            logger.log(20, "Not training Tabular Neural Network since num_epochs == 0")

            # process only the first batch (dummy update) before saving the untrained network
            for batch_idx, data_batch in enumerate(train_dataloader):
                if batch_idx > 0:
                    break
                loss = self.model.compute_loss(data_batch, **loss_kwargs)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            os.makedirs(os.path.dirname(self.path), exist_ok=True)
            torch.save(self.model, net_filename)
            logger.log(15, "Untrained Tabular Neural Network saved to file")
            return

        # start training loop:
        logger.log(15, f"Training tabular neural network for up to {num_epochs} epochs...")
        total_updates = 0
        num_updates_per_epoch = len(train_dataloader)
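        # number of updates into the first epoch after which we estimate whether a full epoch fits in the time budget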
        update_to_check_time = min(10, max(1, int(num_updates_per_epoch/10)))
        do_update = True
        epoch = 0
        best_epoch = 0
        best_val_metric = -np.inf  # higher = better
        best_val_update = 0
        val_improve_epoch = 0  # most recent epoch where the validation score strictly improved
        start_fit_time = time.time()
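        # shrink the remaining time budget by the setup time already spent (initialization, dataloader construction)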
        if time_limit is not None:
            time_limit = time_limit - (start_fit_time - start_time)
            if time_limit <= 0:
                raise TimeLimitExceeded
        while do_update:
            time_start_epoch = time.time()
            total_train_loss = 0.0
            total_train_size = 0.0
            for batch_idx, data_batch in enumerate(train_dataloader):
                # forward
                loss = self.model.compute_loss(data_batch, **loss_kwargs)
                total_train_loss += loss.item()
                total_train_size += 1

                # update
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_updates += 1

                # time limit
                if time_limit is not None:
                    time_cur = time.time()
                    update_cur = batch_idx + 1
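                    # on the first epoch, extrapolate the full-epoch time from the first few updates
                    # and abort if even a single epoch cannot finish within the budget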
                    if epoch == 0 and update_cur == update_to_check_time:
                        time_elapsed_epoch = time_cur - time_start_epoch
                        estimated_time = time_elapsed_epoch / update_cur * num_updates_per_epoch
                        if estimated_time > time_limit:
                            logger.log(30, f"\tNot enough time to train first epoch. "
                                           f"(Time Required: {round(estimated_time, 2)}s, Time Left: {round(time_limit, 2)}s)")
                            raise TimeLimitExceeded
                    time_elapsed = time_cur - start_fit_time
                    if time_limit < time_elapsed:
                        logger.log(15, f"\tRan out of time, stopping training early. (Stopped on Update {total_updates} (Epoch {epoch}))")
                        do_update = False
                        break

            if not do_update:
                break

            epoch += 1

            # validation
            if val_dataset is not None:
                # compute validation score
                val_metric = self.score(X=val_dataset, y=y_val, metric=self.stopping_metric)
                if np.isnan(val_metric):
                    if best_epoch == 0:
                        raise RuntimeError(f"NaNs encountered in {self.__class__.__name__} training. "
                                           "Features/labels may be improperly formatted, "
                                           "or NN weights may have diverged.")
                    else:
                        logger.warning(f"Warning: NaNs encountered in {self.__class__.__name__} training. "
                                       "Reverting model to last checkpoint without NaNs.")
                        break

                # update best validation
                if (val_metric >= best_val_metric) or best_epoch == 0:
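                    # ties still refresh the checkpoint, but only a strict improvement resets the early-stopping counter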
                    if val_metric > best_val_metric:
                        val_improve_epoch = epoch
                    best_val_metric = val_metric
                    os.makedirs(os.path.dirname(self.path), exist_ok=True)
                    torch.save(self.model, net_filename)
                    best_epoch = epoch
                    best_val_update = total_updates
                if verbose_eval:
                    logger.log(15, f"Epoch {epoch} (Update {total_updates}).\t"
                                   f"Train loss: {round(total_train_loss / total_train_size, 4)}, "
                                   f"Val {self.stopping_metric.name}: {round(val_metric, 4)}, "
                                   f"Best Epoch: {best_epoch}")

                if reporter is not None:
                    reporter(epoch=total_updates,
                             validation_performance=val_metric,  # Higher val_metric = better
                             train_loss=total_train_loss / total_train_size,
                             eval_metric=self.eval_metric.name,
                             greater_is_better=self.eval_metric.greater_is_better)

                # early stopping: no strict improvement for `epochs_wo_improve` consecutive epochs
                if epoch - val_improve_epoch >= epochs_wo_improve:
                    break

            if epoch >= num_epochs:
                break

            if time_limit is not None:
                time_elapsed = time.time() - start_fit_time
                time_epoch_average = time_elapsed / epoch  # `epoch` was already incremented for the epoch just completed
                time_left = time_limit - time_elapsed
                if time_left < time_epoch_average:
                    logger.log(20, f"\tRan out of time, stopping training early. (Stopping on epoch {epoch})")
                    break

        if epoch == 0:
            raise AssertionError('0 epochs trained!')

        # revert back to best model
        if val_dataset is not None:
            logger.log(15, f"Best model found on Epoch {best_epoch} (Update {best_val_update}). Val {self.stopping_metric.name}: {best_val_metric}")
            try:
                self.model = torch.load(net_filename)
                os.remove(net_filename)
            except FileNotFoundError:
                pass
        else:
            logger.log(15, f"Best model found on Epoch {best_epoch} (Update {best_val_update}).")
        self.params_trained['batch_size'] = batch_size
        self.params_trained['num_epochs'] = best_epoch
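
The core of the loop above is a standard checkpoint-and-revert early-stopping pattern: after every epoch, score on the validation set, checkpoint the weights whenever the score is at least as good as the best seen, and stop once no strict improvement has been observed for `epochs_wo_improve` epochs (or the time budget runs out). The sketch below distills that pattern into a minimal, self-contained PyTorch script; `train_with_early_stopping`, the toy model, the randomly generated data, and the `patience` argument are illustrative stand-ins, not part of the AutoGluon API. Unlike the method above, which saves the full model to a temporary file on disk, the sketch simply keeps the best `state_dict` in memory.

import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_with_early_stopping(model, train_loader, val_loader, max_epochs=50, patience=5):
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = nn.MSELoss()
    best_val = float('inf')  # lower = better for a loss-style metric
    best_state = copy.deepcopy(model.state_dict())
    best_epoch = 0
    for epoch in range(1, max_epochs + 1):
        model.train()
        for X, y in train_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(X), y)
            loss.backward()
            optimizer.step()
        # validation pass
        model.eval()
        with torch.no_grad():
            val_loss = sum(loss_fn(model(X), y).item() for X, y in val_loader) / max(1, len(val_loader))
        if val_loss < best_val:
            best_val, best_epoch = val_loss, epoch
            best_state = copy.deepcopy(model.state_dict())  # checkpoint the best weights
        if epoch - best_epoch >= patience:
            break  # no improvement for `patience` consecutive epochs
    model.load_state_dict(best_state)  # revert to the best checkpoint
    return best_epoch, best_val


if __name__ == "__main__":
    X, y = torch.randn(256, 10), torch.randn(256, 1)
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
    train_loader = DataLoader(TensorDataset(X[:200], y[:200]), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(X[200:], y[200:]), batch_size=32)
    print(train_with_early_stopping(model, train_loader, val_loader))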