in tabular/src/autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py [0:0]
def _train_net(self,
train_dataset,
loss_kwargs,
batch_size,
num_epochs,
epochs_wo_improve,
val_dataset=None,
time_limit=None,
reporter=None,
verbosity=2):
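"""Train self.model on train_dataset for up to num_epochs epochs.
Training stops early when the validation metric has not improved for epochs_wo_improve
consecutive epochs or when time_limit is exceeded. The best model (by stopping_metric on
val_dataset) is checkpointed to a temporary file and restored before returning.
If num_epochs == 0, a single dummy update is performed and the network is saved as-is.
"""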
import torch
start_time = time.time()
logging.debug("initializing neural network...")
self.model.init_params()
logging.debug("initialized")
train_dataloader = train_dataset.build_loader(batch_size, self.num_dataloading_workers, is_test=False)
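# resolve 'auto' to the model's default loss function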
loss_function = loss_kwargs.get('loss_function', 'auto')
if isinstance(loss_function, str) and loss_function == 'auto':
loss_kwargs['loss_function'] = self._get_default_loss_function()
if val_dataset is not None:
y_val = val_dataset.get_labels()
if y_val.ndim == 2 and y_val.shape[1] == 1:
y_val = y_val.flatten()
else:
y_val = None
verbose_eval = verbosity > 1
logger.log(15, "Neural network architecture:")
logger.log(15, str(self.model))
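# temporary file used to checkpoint the best (or untrained) network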
net_filename = self.path + self.temp_file_name
if num_epochs == 0:
# use dummy training loop that stops immediately
# useful for using NN just for data preprocessing / debugging
logger.log(20, "Not training Tabular Neural Network since num_updates == 0")
# for each batch
for batch_idx, data_batch in enumerate(train_dataloader):
if batch_idx > 0:
break
loss = self.model.compute_loss(data_batch, **loss_kwargs)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
os.makedirs(os.path.dirname(self.path), exist_ok=True)
torch.save(self.model, net_filename)
logger.log(15, "Untrained Tabular Neural Network saved to file")
return
# start training loop:
logger.log(15, f"Training tabular neural network for up to {num_epochs} epochs...")
total_updates = 0
num_updates_per_epoch = len(train_dataloader)
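# within the first epoch, check after this many updates whether a full epoch is likely to fit in time_limit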
update_to_check_time = min(10, max(1, int(num_updates_per_epoch/10)))
do_update = True
epoch = 0
best_epoch = 0
best_val_metric = -np.inf # higher = better
best_val_update = 0
val_improve_epoch = 0 # most recent epoch in which the validation score strictly improved
start_fit_time = time.time()
if time_limit is not None:
time_limit = time_limit - (start_fit_time - start_time)
if time_limit <= 0:
raise TimeLimitExceeded
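# each iteration of this loop runs one full pass (epoch) over train_dataloader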
while do_update:
time_start_epoch = time.time()
total_train_loss = 0.0
total_train_size = 0.0
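# total_train_size counts batches, so total_train_loss / total_train_size is the average per-batch loss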
for batch_idx, data_batch in enumerate(train_dataloader):
# forward
loss = self.model.compute_loss(data_batch, **loss_kwargs)
total_train_loss += loss.item()
total_train_size += 1
# update
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
total_updates += 1
# time limit
if time_limit is not None:
time_cur = time.time()
update_cur = batch_idx + 1
if epoch == 0 and update_cur == update_to_check_time:
time_elapsed_epoch = time_cur - time_start_epoch
estimated_time = time_elapsed_epoch / update_cur * num_updates_per_epoch
if estimated_time > time_limit:
logger.log(30, f"\tNot enough time to train first epoch. "
f"(Time Required: {round(estimated_time, 2)}s, Time Left: {round(time_limit, 2)}s)")
raise TimeLimitExceeded
time_elapsed = time_cur - start_fit_time
if time_limit < time_elapsed:
logger.log(15, f"\tRan out of time, stopping training early. (Stopped on Update {total_updates} (Epoch {epoch}))")
do_update = False
break
if not do_update:
break
epoch += 1
# validation
if val_dataset is not None:
# compute validation score
val_metric = self.score(X=val_dataset, y=y_val, metric=self.stopping_metric)
if np.isnan(val_metric):
if best_epoch == 0:
raise RuntimeError(f"NaNs encountered in {self.__class__.__name__} training. "
"Features/labels may be improperly formatted, "
"or NN weights may have diverged.")
else:
logger.warning(f"Warning: NaNs encountered in {self.__class__.__name__} training. "
"Reverting model to last checkpoint without NaNs.")
break
# update best validation
if (val_metric >= best_val_metric) or best_epoch == 0:
if val_metric > best_val_metric:
val_improve_epoch = epoch
best_val_metric = val_metric
os.makedirs(os.path.dirname(self.path), exist_ok=True)
torch.save(self.model, net_filename)
best_epoch = epoch
best_val_update = total_updates
if verbose_eval:
logger.log(15, f"Epoch {epoch} (Update {total_updates}).\t"
f"Train loss: {round(total_train_loss / total_train_size, 4)}, "
f"Val {self.stopping_metric.name}: {round(val_metric, 4)}, "
f"Best Epoch: {best_epoch}")
if reporter is not None:
reporter(epoch=total_updates,
validation_performance=val_metric, # Higher val_metric = better
train_loss=total_train_loss / total_train_size,
eval_metric=self.eval_metric.name,
greater_is_better=self.eval_metric.greater_is_better)
# no improvement
if epoch - val_improve_epoch >= epochs_wo_improve:
break
if epoch >= num_epochs:
break
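# skip starting another epoch if the remaining time budget is smaller than the average epoch time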
if time_limit is not None:
time_elapsed = time.time() - start_fit_time
time_epoch_average = time_elapsed / epoch  # epoch >= 1 here and counts completed epochs
time_left = time_limit - time_elapsed
if time_left < time_epoch_average:
logger.log(20, f"\tRan out of time, stopping training early. (Stopping on epoch {epoch})")
break
if epoch == 0:
raise AssertionError('0 epochs trained!')
# revert back to best model
if val_dataset is not None:
logger.log(15, f"Best model found on Epoch {best_epoch} (Update {best_val_update}). Val {self.stopping_metric.name}: {best_val_metric}")
try:
self.model = torch.load(net_filename)
os.remove(net_filename)
except FileNotFoundError:
pass
else:
logger.log(15, f"Best model found on Epoch {best_epoch} (Update {best_val_update}).")
self.params_trained['batch_size'] = batch_size
self.params_trained['num_epochs'] = best_epoch