in models/spatial/cnf.py [0:0]
def forward(self, s, state):
    """Solves the same dynamics but uses a dummy variable that always integrates over [0, 1]."""
    self.nfe += 1
    t0, t1, e, x, logpx, _, _ = state

    # Map the solver's dummy time s in [start_time, end_time] onto each sample's
    # own interval [t0, t1]; the chain rule introduces the scaling factor `ratio`.
    ratio = (t1 - t0) / (self.end_time - self.start_time)
    t = (s - self.start_time) * ratio + t0

    vjp = None
    with torch.enable_grad():
        x = x.requires_grad_(True)

        # Time-rescaled dynamics: dx/ds = f(t, x) * dt/ds.
        dx = self.func(t, x)
        dx = dx * ratio.reshape(-1, *([1] * (x.ndim - 1)))

        if self.nonself_connections:
            # Recompute the dynamics with self-interaction gradients removed so
            # the divergence only sees non-self connections.
            dx_div = self.func(t, x, rm_nonself_grads=True)
            dx_div = dx_div * ratio.reshape(-1, *([1] * (x.ndim - 1)))
        else:
            dx_div = dx

        if not self.training:
            # Use brute-force (exact) trace for testing.
            div = divergence_bf(dx_div[:, :self.dim], x, self.training)
        else:
            # Hutchinson estimator: E_e[e^T (d dx/d x) e] = tr(d dx/d x).
            vjp = torch.autograd.grad(
                dx_div[:, :self.dim], x, e,
                create_graph=self.training, retain_graph=self.training,
            )[0]
            vjp = vjp[:, :self.dim]
            div = torch.sum(vjp * e, dim=1)

        # Debugging code for checking gradient connections.
        # Need to send T and N to self from attncnf.
        # if self.training and hasattr(self, "T"):
        #     grads = torch.autograd.grad(dx_div.reshape(self.T, self.N, -1)[5, 0, 0], x, retain_graph=True)[0]
        #     print(grads.reshape(self.T, self.N, -1)[4:6, 0, :])

    if not self.training:
        dx = dx.detach()
        div = div.detach()

    # Kinetic-energy regularizer accumulated along the trajectory.
    d_energy = torch.sum(dx * dx).reshape(1) / x.shape[0]

    if self.training:
        # Reuse the vector-Jacobian product above as an unbiased estimate of
        # the squared Frobenius norm of the Jacobian.
        d_jacnorm = torch.sum(vjp * vjp).reshape(1) / x.shape[0]
    else:
        d_jacnorm = torch.zeros(1).to(x)

    # The time bounds and noise vector stay constant; the log-density channel accumulates -div.
    return torch.zeros_like(t0), torch.zeros_like(t1), torch.zeros_like(e), dx, -div, d_energy, d_jacnorm
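
The `ratio`/`t` lines at the top of the body implement a per-sample change of time variable: the ODE solver always integrates a dummy variable s over the fixed interval [self.start_time, self.end_time], while each event carries its own interval [t0, t1]. A small, hypothetical sanity check of the mapping (standalone, not part of the module):

import torch

start_time, end_time = 0.0, 1.0
t0 = torch.tensor([0.3, 1.2])  # per-sample start times
t1 = torch.tensor([0.9, 2.0])  # per-sample end times

ratio = (t1 - t0) / (end_time - start_time)
for s in (start_time, end_time):
    t = (s - start_time) * ratio + t0
    print(s, t)  # s = start_time gives t == t0; s = end_time gives t == t1
# Since dt/ds = ratio, the dynamics become dx/ds = f(t, x) * ratio,
# which is why `dx` is multiplied by `ratio` in the method above.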
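
The evaluation branch calls `divergence_bf`, which is defined elsewhere in the repository and computes the exact trace of the Jacobian. The sketch below is an assumption of what such a brute-force helper typically looks like, matching the `(dx, x, training)` call signature used above; it runs one backward pass per output dimension, which is only practical for low-dimensional spatial states:

import torch

def divergence_bf(dx, x, training=False):
    """Exact divergence: sum of the diagonal entries of the Jacobian d(dx)/d(x)."""
    div = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)
    for i in range(dx.shape[1]):
        # d(dx[:, i]) / d(x[:, i]) for every sample in the batch.
        div = div + torch.autograd.grad(
            dx[:, i].sum(), x, create_graph=training, retain_graph=True
        )[0][:, i]
    return div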
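
The training branch avoids the O(dim) brute-force trace by using Hutchinson's estimator: for a noise vector e with zero mean and identity covariance, E[e^T J e] = tr(J), and the same vector-Jacobian product also gives E[||e^T J||^2] = ||J||_F^2, which is what `d_jacnorm` accumulates. A self-contained illustration of both identities (function and variable names here are made up for the example):

import torch

def hutchinson_estimates(f, x, e):
    """One-sample Hutchinson estimates of tr(J) and ||J||_F^2, with J = df/dx."""
    x = x.requires_grad_(True)
    dx = f(x)
    # vjp[i] = e[i]^T J_i for each element i of the batch.
    vjp = torch.autograd.grad(dx, x, grad_outputs=e, create_graph=True)[0]
    div_est = torch.sum(vjp * e, dim=1)     # unbiased estimate of tr(J)
    frob_est = torch.sum(vjp * vjp, dim=1)  # unbiased estimate of ||J||_F^2
    return div_est, frob_est

# Averaging over many draws of e converges to the exact quantities.
f = torch.nn.Linear(2, 2)
x = torch.randn(64, 2)
e = torch.randn_like(x)
div_est, frob_est = hutchinson_estimates(f, x, e)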