in vihds/training.py [0:0]
def _evaluate_elbo_and_plot(self, epoch, log_data, train_writer, valid_writer):
print("epoch %4d" % epoch, end="", flush=True)
log_data.n_test += 1
test_start = time.time()
plot = (self.args.plot_epoch > 0) and (np.mod(epoch, self.args.plot_epoch) == 0)
# Training
train_results, theta, q, p = self.model(
self.train_data, self.args.train_samples, writer=train_writer, epoch=epoch
)
train_output = self.cost(
self.train_data, train_results, theta, q, p, full_output=True, writer=train_writer, epoch=epoch,
)
print(
" | train (iwae-elbo = %0.4f, time = %0.2f, total = %0.2f)"
% (train_output.elbo, log_data.total_train_time / epoch, log_data.total_train_time,),
end="",
flush=True,
)
if train_writer is not None:
if plot:
self._plot_prediction_summary(self.train_data, train_output, epoch, train_writer)
# self._plot_species(self.train_data, train_output, epoch, train_writer)
if self.model.decoder.ode_model.precisions.dynamic:
self._plot_variance(self.train_data, train_output, epoch, train_writer)
train_writer.flush()
# Validation
valid_results, theta, q, p = self.model(
self.valid_data, self.args.test_samples, writer=valid_writer, epoch=epoch
)
valid_output = self.cost(
self.valid_data, valid_results, theta, q, p, full_output=True, writer=valid_writer, epoch=epoch,
)
if valid_writer is not None:
if plot:
self._plot_prediction_summary(self.valid_data, valid_output, epoch, valid_writer)
# self._plot_species(self.valid_data, valid_output, epoch, valid_writer)
if self.model.decoder.ode_model.precisions.dynamic:
self._plot_variance(self.valid_data, valid_output, epoch, valid_writer)
valid_writer.flush()
log_data.total_test_time += time.time() - test_start
print(
" | val (iwae-elbo = %0.4f, time = %0.2f, total = %0.2f)"
% (valid_output.elbo, log_data.total_test_time / log_data.n_test, log_data.total_test_time,)
)
if valid_output.elbo > log_data.max_val_elbo:
log_data.max_val_elbo = valid_output.elbo
valid_output.dump()
self.empty_cache = False
log_data.training_elbo_list.append(train_output.elbo)
log_data.validation_elbo_list.append(valid_output.elbo)
return valid_output