def _evaluate_elbo_and_plot()

in vihds/training.py [0:0]


    def _evaluate_elbo_and_plot(self, epoch, log_data, train_writer, valid_writer):
        """Evaluate the IWAE-ELBO on the training and validation sets, report it and optionally plot.

        Runs the model and cost on ``self.train_data`` (with
        ``self.args.train_samples`` samples) and on ``self.valid_data`` (with
        ``self.args.test_samples`` samples), and prints a one-line progress
        report. When ``self.args.plot_epoch > 0`` and ``epoch`` is a multiple of
        it, plot summaries are also written to the corresponding writer.

        Args:
            epoch: Current epoch number. Used in the report, for the mean
                training time per epoch, and for the plotting schedule.
            log_data: Mutable logging record. Reads ``total_train_time`` and
                ``max_val_elbo``. Updates ``n_test``, ``total_test_time``,
                ``max_val_elbo``, ``training_elbo_list`` and
                ``validation_elbo_list``.
            train_writer: Writer for training-split summaries, or None to skip
                training-split plotting and flushing.
            valid_writer: Writer for validation-split summaries, or None to skip
                validation-split plotting and flushing.

        Returns:
            The full cost output for the validation split.
        """
        print("epoch %4d" % epoch, end="", flush=True)
        log_data.n_test += 1
        test_start = time.time()
        plot = self.args.plot_epoch > 0 and epoch % self.args.plot_epoch == 0

        # Training split. The report is printed before plotting, so progress is
        # visible while the (possibly slow) plots are produced.
        train_output = self._compute_split_output(
            self.train_data, self.args.train_samples, train_writer, epoch
        )
        # max(epoch, 1) avoids ZeroDivisionError if evaluation runs at epoch 0.
        mean_train_time = log_data.total_train_time / max(epoch, 1)
        print(
            " | train (iwae-elbo = %0.4f, time = %0.2f, total = %0.2f)"
            % (train_output.elbo, mean_train_time, log_data.total_train_time),
            end="",
            flush=True,
        )
        self._write_split_summaries(self.train_data, train_output, epoch, train_writer, plot)

        # Validation split.
        valid_output = self._compute_split_output(
            self.valid_data, self.args.test_samples, valid_writer, epoch
        )
        self._write_split_summaries(self.valid_data, valid_output, epoch, valid_writer, plot)

        # NOTE: test_start is taken before the training-split evaluation, so the
        # reported "val" time covers evaluating and plotting BOTH splits.
        log_data.total_test_time += time.time() - test_start
        print(
            " | val (iwae-elbo = %0.4f, time = %0.2f, total = %0.2f)"
            % (valid_output.elbo, log_data.total_test_time / log_data.n_test, log_data.total_test_time)
        )

        # New best validation ELBO: record it and call valid_output.dump()
        # (presumably persists the output; confirm), then clear self.empty_cache.
        if valid_output.elbo > log_data.max_val_elbo:
            log_data.max_val_elbo = valid_output.elbo
            valid_output.dump()
            self.empty_cache = False

        log_data.training_elbo_list.append(train_output.elbo)
        log_data.validation_elbo_list.append(valid_output.elbo)

        return valid_output

    def _compute_split_output(self, data, n_samples, writer, epoch):
        """Run the model on ``data`` with ``n_samples`` samples and return the full cost output."""
        results, theta, q, p = self.model(data, n_samples, writer=writer, epoch=epoch)
        return self.cost(data, results, theta, q, p, full_output=True, writer=writer, epoch=epoch)

    def _write_split_summaries(self, data, output, epoch, writer, plot):
        """Optionally plot summaries for one data split, then flush ``writer``. No-op if writer is None."""
        if writer is None:
            return
        if plot:
            self._plot_prediction_summary(data, output, epoch, writer)
            # NOTE: species plotting (self._plot_species) is currently disabled.
            # Variance plots are only produced when the ODE model's precisions are dynamic.
            if self.model.decoder.ode_model.precisions.dynamic:
                self._plot_variance(data, output, epoch, writer)
        writer.flush()