def forward()

in models/spatial/cnf.py [0:0]


    def forward(self, s, state):
        """Solves the same dynamics but uses a dummy variable that always integrates [0, 1]."""
        self.nfe += 1
        t0, t1, e, x, logpx, _, _ = state

        # Affine change of variables: map the dummy time s in
        # [self.start_time, self.end_time] onto each sample's interval [t0, t1].
        ratio = (t1 - t0) / (self.end_time - self.start_time)
        t = (s - self.start_time) * ratio + t0

        vjp = None
        with torch.enable_grad():
            x = x.requires_grad_(True)

            # Evaluate the dynamics, then rescale by dt/ds for the change of variables.
            dx = self.func(t, x)
            dx = dx * ratio.reshape(-1, *([1] * (x.ndim - 1)))

            if self.nonself_connections:
                # Re-evaluate with non-self gradient paths removed so the divergence
                # only reflects each event's dependence on its own state.
                dx_div = self.func(t, x, rm_nonself_grads=True)
                dx_div = dx_div * ratio.reshape(-1, *([1] * (x.ndim - 1)))
            else:
                dx_div = dx

            if not self.training:
                # Exact divergence via a brute-force trace of the Jacobian (evaluation only).
                div = divergence_bf(dx_div[:, :self.dim], x, self.training)
            else:
                # Hutchinson estimator of the divergence: e^T (dF/dx) e, computed
                # with a single vector-Jacobian product against the noise vector e.
                vjp = torch.autograd.grad(
                    dx_div[:, :self.dim], x, e,
                    create_graph=self.training, retain_graph=self.training,
                )[0]
                vjp = vjp[:, :self.dim]
                div = torch.sum(vjp * e, dim=1)

            # Debugging code for checking gradient connections.
            # Need to send T and N to self from attncnf.
            # if self.training and hasattr(self, "T"):
            #     grads = torch.autograd.grad(dx_div.reshape(self.T, self.N, -1)[5, 0, 0], x, retain_graph=True)[0]
            #     print(grads.reshape(self.T, self.N, -1)[4:6, 0, :])

        if not self.training:
            # No gradients are needed at evaluation time; drop the graph.
            dx = dx.detach()
            div = div.detach()

        # Kinetic-energy regularizer: batch-averaged squared norm of the vector field.
        d_energy = torch.sum(dx * dx).reshape(1) / x.shape[0]

        if self.training:
            # Jacobian Frobenius-norm regularizer, reusing the vector-Jacobian product.
            d_jacnorm = torch.sum(vjp * vjp).reshape(1) / x.shape[0]
        else:
            d_jacnorm = torch.zeros(1).to(x)

        # t0, t1, and e are constant along s; the log-density evolves with -div,
        # while d_energy and d_jacnorm accumulate the regularizers.
        return (
            torch.zeros_like(t0), torch.zeros_like(t1), torch.zeros_like(e),
            dx, -div, d_energy, d_jacnorm,
        )
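
Because of this change of variables, a single solve of the dummy variable s over [start_time, end_time] transports every sample across its own interval [t0, t1]. Below is a minimal usage sketch, assuming the torchdiffeq odeint solver and illustrative names such as integrate_cnf; the repository may wire the solve up differently.

    import torch
    from torchdiffeq import odeint  # assumed dependency; the repo may use its own solver

    def integrate_cnf(cnf, x, t0, t1):
        """Transport x from per-sample times t0 to t1 by solving over the dummy variable s."""
        e = torch.randn_like(x)                           # noise vector for the Hutchinson estimator
        logpx = torch.zeros(x.shape[0], device=x.device)  # accumulated change in log-density
        zero = torch.zeros(1, device=x.device)            # energy / Jacobian-norm accumulators
        state = (t0, t1, e, x, logpx, zero, zero)
        s_span = torch.tensor([cnf.start_time, cnf.end_time], device=x.device)
        out = odeint(cnf, state, s_span)
        # odeint stacks each state component along the s dimension; keep the endpoints.
        _, _, _, xT, dlogp, energy, jacnorm = [o[-1] for o in out]
        return xT, dlogp, energy, jacnorm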
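
The training branch relies on the Hutchinson trace identity tr(J) = E[e^T J e], evaluated with a single vector-Jacobian product per call. The following self-contained check of that identity runs on a toy vector field; everything in it (toy_func, the shapes, the sample count) is illustrative and not repository code.

    import torch

    def toy_func(x):
        # A simple nonlinear vector field R^d -> R^d, applied per sample.
        return torch.tanh(x) + 0.1 * x.sum(dim=1, keepdim=True)

    x = torch.randn(4, 3, requires_grad=True)
    dx = toy_func(x)

    # Exact per-sample divergence: sum of the Jacobian's diagonal, by brute force.
    exact = torch.stack([
        sum(torch.autograd.grad(dx[i, j], x, retain_graph=True)[0][i, j]
            for j in range(x.shape[1]))
        for i in range(x.shape[0])
    ])

    # Hutchinson estimate: average e^T J e over Gaussian noise vectors e, each
    # computed with a single vector-Jacobian product as in the training branch.
    n_samples, est = 2000, torch.zeros(x.shape[0])
    for _ in range(n_samples):
        e = torch.randn_like(x)
        vjp = torch.autograd.grad(dx, x, e, retain_graph=True)[0]
        est += torch.sum(vjp * e, dim=1)
    est /= n_samples

    print(exact)
    print(est)  # matches `exact` up to Monte Carlo error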