in vihds/training.py [0:0]
def run(self):
    """Run the training loop and return the most recent validation output.

    Trains for at most ``self.args.epochs`` epochs. Every batch of
    ``self.train_loader`` is passed to ``self._run_batch``; a falsy return
    value from that call stops training. Every ``self.args.test_epoch``
    epochs the model is switched to eval mode and
    ``self._evaluate_elbo_and_plot`` is called. ``self.scheduler.step()`` is
    called once per epoch.

    Returns:
        The object returned by the last ``_evaluate_elbo_and_plot`` call,
        after calling its ``load()`` method and setting its ``elbo_list``
        attribute to ``log_data.validation_elbo_list``. Returns ``None`` when
        ``self.empty_cache`` is truthy or when no evaluation ran during this
        call.
    """
    log_data = TrainingLogData()
    # Stays None when no evaluation epoch is reached, so the code after the
    # loop can detect that case instead of raising UnboundLocalError.
    valid_output = None
    train_writer = None
    valid_writer = None
    try:
        # Tensorboard writers (only created when a trainer is configured)
        if self.settings.trainer is not None:
            train_writer = SummaryWriter(self.train_path)
            valid_writer = SummaryWriter(self.valid_path)

        print("---------------------------")
        if self.args.heldout:
            split_name = "heldout device = %s" % self.args.heldout
        else:
            split_name = "split %d of %d" % (self.args.split, self.args.folds)
        print("Training: %s" % split_name)

        iterating = True
        epoch = 1
        # NOTE(review): the identity check means training only continues to
        # the next epoch when _run_batch returned the literal True; this is
        # kept as in the original - confirm _run_batch returns a plain bool.
        while iterating is True and (epoch < self.args.epochs + 1):
            self.model.train()
            epoch_start = time.time()
            for batch in self.train_loader:
                iterating = self._run_batch(epoch_start, batch, log_data)
                if not iterating:
                    # Stop requested: do not keep pulling batches from the
                    # loader only to skip them.
                    break
            log_data.total_train_time += time.time() - epoch_start
            # Occasionally evaluate ELBO on train and val, using more IW samples
            if iterating and (np.mod(epoch, self.args.test_epoch) == 0):
                self.model.eval()
                valid_output = self._evaluate_elbo_and_plot(epoch, log_data, train_writer, valid_writer)
            self.scheduler.step()
            epoch += 1
    finally:
        # Close the writers even if training or evaluation raised.
        for writer in (train_writer, valid_writer):
            if writer is not None:
                writer.close()

    # Reload results from best validation elbo score
    if self.empty_cache:
        print("Exiting with no results in cache")
        return None
    if valid_output is None:
        # No evaluation ran in this call, so there is no output to reload.
        print("Exiting with no validation output (no evaluation epoch was reached)")
        return None
    valid_output.load()
    valid_output.elbo_list = log_data.validation_elbo_list
    return valid_output