in models/temporal/neural.py [0:0]
def forward(self, s, state):
    """Evaluate the augmented ODE right-hand side in reparameterized time.

    The solver integrates a dummy variable ``s`` over the fixed interval
    ``[self.start_time, self.end_time]``. Each call maps ``s`` affinely onto
    the real interval ``[t0, t1]`` carried inside ``state``, so ``self.func``
    is evaluated at real time. The derivatives it returns are then rescaled
    by ``dt/ds`` (chain rule).

    Args:
        s: Current value of the dummy integration variable.
        state: Tuple ``(t0, t1, _, *x)``. ``t0`` and ``t1`` are tensors with
            the real interval endpoints. The third entry (the quantity whose
            derivative is returned as ``d_energy``) is not read here. ``x``
            holds the state tensors.

    Returns:
        Tuple ``(zeros_like(t0), zeros_like(t1), d_energy, *dx)``:
            - the endpoints get zero derivative, so they stay constant;
            - ``d_energy`` is the sum of squared rescaled derivatives divided
              by the total element count of ``x``;
            - ``dx`` are the rescaled derivatives, detached from the autograd
              graph when ``self.training`` is false.
    """
    self.nfe += 1  # number of right-hand-side evaluations
    t0, t1, _, *states = state

    # dt/ds for the affine map [start_time, end_time] -> [t0, t1].
    ratio = (t1 - t0) / (self.end_time - self.start_time)
    t = (s - self.start_time) * ratio + t0

    with torch.enable_grad():
        # Flag every state tensor (in place) as requiring grad. Together with
        # enable_grad, this records the graph through self.func even if the
        # caller is running under no_grad.
        states = tuple(v.requires_grad_(True) for v in states)
        raw = self.func(t, states)

        # Chain rule: dx/ds = dx/dt * dt/ds. The ratio is reshaped to
        # (-1, 1, ..., 1) so it broadcasts along each derivative's leading
        # dimension.
        # NOTE(review): presumably t0/t1 are scalars or per-sample (batch,)
        # tensors -- confirm against the caller.
        dx = tuple(
            d * ratio.reshape((-1,) + (1,) * (d.ndim - 1)) for d in raw
        )

        total_sq = sum(torch.sum(d * d) for d in dx)
        n_elems = sum(v.numel() for v in states)
        d_energy = total_sq / n_elems

    if not self.training:
        # Outside training, hand back derivatives without autograd history.
        dx = tuple(d.detach() for d in dx)

    return (torch.zeros_like(t0), torch.zeros_like(t1), d_energy, *dx)