in models/spatial/cnf.py [0:0]
def integrate(self, t0, t1, x, logpx, tol=None, method=None, norm=None, intermediate_states=0):
    """Integrate the augmented state from self.start_time to self.end_time.

    The solver state is the tuple (t0, t1, e, x, logpx, energy, jacnorm),
    where e is fresh Gaussian noise and energy / jacnorm start at zero.

    Args:
        t0: (N,)
        t1: (N,)
        x: (N, ...)
        logpx: (N,)
        tol: value passed as both rtol and atol to the solver; when falsy,
            self.tol is used instead.
        method: solver method; when falsy, self.method is used instead.
        norm: accepted for interface compatibility; not used in this method.
        intermediate_states: when greater than 1, the solver is evaluated on
            that many evenly spaced time points instead of just the two
            endpoints.

    Returns:
        A pair (y, logpy_with_reg). y is the solver output for x: the value
        at the last time point, or, when intermediate_states > 1, the
        output at every requested time point. The second element is the
        log-density at the last time point plus a regularization term whose
        value is zero but whose gradient is not.
    """
    # Reset the function-evaluation counter for this solve.
    self.nfe = 0

    if not tol:
        tol = self.tol
    if not method:
        method = self.method

    keep_trajectory = intermediate_states > 1

    # Noise is sampled at the full shape of x, then cut to the first
    # self.dim entries along dimension 1.
    noise = torch.randn_like(x)
    noise = noise[:, :self.dim]

    # Accumulators integrated alongside the flow; matched to x via .to(x).
    energy_init = torch.zeros(1).to(x)
    jacnorm_init = torch.zeros(1).to(x)

    state0 = (t0, t1, noise, x, logpx, energy_init, jacnorm_init)

    if keep_trajectory:
        time_grid = torch.linspace(self.start_time, self.end_time, intermediate_states)
    else:
        time_grid = torch.tensor([self.start_time, self.end_time])
    time_grid = time_grid.to(t0)

    solution = odeint_adjoint(self, state0, time_grid, rtol=tol, atol=tol, method=method)

    # Take the last time point of every state component; positions follow
    # the layout of state0.
    _, _, _, y_final, logpy, energy, jacnorm = [component[-1] for component in solution]

    # With intermediate states requested, hand back the full solver output
    # for x rather than only its last time point.
    y = solution[3] if keep_trajectory else y_final

    # Hacky way to introduce regularization: (v - v.detach()) is numerically
    # zero, so logpy's value is unchanged, but gradients still flow through v.
    energy_term = self.energy_regularization * (energy - energy.detach())
    jacnorm_term = self.jacnorm_regularization * (jacnorm - jacnorm.detach())
    return y, logpy + (energy_term + jacnorm_term)