def _cond_logliks()

in models/spatial/jumpcnf.py [0:0]


    def _cond_logliks(self, event_times, spatial_locations, input_mask=None, aux_state=None):
        """
        Args:
            event_times: (N, T)
            spatial_locations: (N, T, D)
            input_mask: (N, T) or None
            aux_state: (N, T, D_a)

        Returns:
            A tensor of shape (N, T) containing the conditional log probabilities.
        """

        if input_mask is None:
            input_mask = torch.ones_like(event_times)

        assert event_times.shape == input_mask.shape
        assert event_times.shape[:2] == spatial_locations.shape[:2]
        if aux_state is not None:
            assert event_times.shape[:2] == aux_state.shape[:2]

        N, T, D = spatial_locations.shape
        self.cnf.nfe = 0

        input_mask = input_mask.bool()

        if aux_state is not None:
            aux_state = aux_state

        event_times = self.time_offset + event_times
        event_times = torch.cat([torch.zeros(N, 1).to(event_times), event_times], dim=1)  # (N, 1 + T)

        input_mask = torch.cat([torch.ones(N, 1).to(input_mask), input_mask], dim=1)

        for i in range(T):

            # Mask out the integration if either t0 or t1 has input_mask == 0.
            t0 = event_times[:, -i - 1].mul(input_mask[:, -i - 1]).mul(input_mask[:, -i - 2]).reshape(N, 1).expand(N, i + 1).reshape(-1)
            t1 = event_times[:, -i - 2].mul(input_mask[:, -i - 1]).mul(input_mask[:, -i - 2]).reshape(N, 1).expand(N, i + 1).reshape(-1)

            if i == 0:
                xs = spatial_locations[:, -1].reshape(N, 1, D)
                dlogps = torch.zeros(N, 1).to(xs)
            else:
                xs = torch.cat([
                    spatial_locations[:, -i - 1].reshape(N, 1, D),
                    xs,
                ], dim=1)
                dlogps = torch.cat([
                    torch.zeros(N, 1).to(xs),
                    dlogps,
                ], dim=1)

            xs = xs.reshape(-1, D)
            dlogps = dlogps.reshape(-1)

            norm_fn = None
            if aux_state is not None:
                D_a = aux_state.shape[-1]
                auxs = aux_state[:, -i - 1:, :].expand(N, i + 1, D_a).reshape(-1, D_a)
                inputs = [xs, auxs]
                norm_fn = max_rms_norm([a.shape for a in inputs])
                xs = torch.cat(inputs, dim=1)

            xs, dlogps = self.cnf.integrate(t0, t1, xs, dlogps, method="dopri5" if i < T - 1 and self.training else "dopri5", norm=norm_fn)

            xs, auxs = xs[:, :D], xs[:, D:]

            # Apply instantaneous flow
            if i < T - 1:
                obs_x = spatial_locations[:, -i - 2].reshape(N, 1, D).expand(N, i + 1, D).reshape(-1, D)
                obs_t = event_times[:, -i - 2].reshape(N, 1).expand(N, i + 1).reshape(-1, 1)
                cond = torch.cat([obs_t, obs_x, auxs[:, -self.aux_dim:]], dim=1)  # (N * (i + 1), 1 + D + D_a)
                xs, dlogps = self.inst_flow(xs, logpx=dlogps, cond=cond)

            xs = xs.reshape(N, i + 1, D)
            dlogps = dlogps.reshape(N, i + 1)
            dlogps = torch.where(input_mask[:, -i - 1:], dlogps, torch.zeros_like(dlogps))

        logpz = gaussian_loglik(xs, self.z_mean.expand_as(xs), self.z_logstd.expand_as(xs)).sum(-1)  # (N, T)
        logpx = logpz - dlogps

        return torch.where(input_mask[:, 1:], logpx, torch.zeros_like(logpx))