def integrate()

in models/spatial/cnf.py [0:0]


    def integrate(self, t0, t1, x, logpx, tol=None, method=None, norm=None, intermediate_states=0):
        """Solve the augmented ODE with ``odeint_adjoint`` and return the end state.

        The solver state is the tuple ``(t0, t1, e, x, logpx, energy, jacnorm)``.
        It is integrated over a fixed grid running from ``self.start_time`` to
        ``self.end_time``. ``t0`` and ``t1`` ride along inside the state; they are
        not used as the integration bounds here.

        Args:
            t0: (N,) -- carried in the ODE state.
            t1: (N,) -- carried in the ODE state.
            x: (N, ...) -- initial sample(s).
            logpx: (N,) -- initial log-density.
            tol: used as both rtol and atol; falls back to ``self.tol`` when falsy.
            method: solver name; falls back to ``self.method`` when falsy.
            norm: accepted for interface compatibility; not used by this method.
            intermediate_states: when > 1, the solver is evaluated on that many
                evenly spaced time points and the whole ``x`` trajectory is
                returned instead of only its final value.

        Returns:
            A pair ``(y, logpy_reg)``:

            - ``y`` is the final ``x``, or the stacked ``x`` trajectory when
              ``intermediate_states > 1``.
            - ``logpy_reg`` is the final log-density plus a regularization term.
              That term is zero in value (for finite inputs) but carries the
              gradient of the weighted energy and Jacobian-norm accumulators.

        Side effects:
            Resets ``self.nfe`` to 0 and draws from torch's global RNG.
        """
        self.nfe = 0

        tol = tol or self.tol
        method = method or self.method

        # Random probe restricted to the first self.dim features. This is
        # presumably a Hutchinson-style noise vector consumed by the dynamics;
        # TODO confirm.
        noise = torch.randn_like(x)[:, :self.dim]
        energy0 = torch.zeros(1).to(x)
        jacnorm0 = torch.zeros(1).to(x)
        state0 = (t0, t1, noise, x, logpx, energy0, jacnorm0)

        want_trajectory = intermediate_states > 1
        if want_trajectory:
            times = torch.linspace(self.start_time, self.end_time, intermediate_states)
        else:
            times = torch.tensor([self.start_time, self.end_time])
        times = times.to(t0)

        trajectory = odeint_adjoint(
            self,
            state0,
            times,
            rtol=tol,
            atol=tol,
            method=method,
        )

        # Every solution component is indexed by time first; [-1] is the end time.
        _, _, _, x_final, logpy, energy, jacnorm = (component[-1] for component in trajectory)
        y = trajectory[3] if want_trajectory else x_final

        # (v - v.detach()) is zero in value but keeps v's gradient. The penalties
        # therefore shape training without shifting the returned log-density.
        energy_term = self.energy_regularization * (energy - energy.detach())
        jacnorm_term = self.jacnorm_regularization * (jacnorm - jacnorm.detach())

        return y, logpy + (energy_term + jacnorm_term)