def _cond_logliks()

in models/spatial/attncnf.py


    def _cond_logliks(self, event_times, spatial_locations, input_mask=None, aux_state=None):
        """
        Args:
            event_times: (N, T)
            spatial_locations: (N, T, D)
            input_mask: (N, T) or None
            aux_state: (N, T, D_a) or None

        Returns:
            A tensor of shape (N, T) containing the conditional log probabilities.
            Entries corresponding to masked-out events are zero.
        """

        if input_mask is None:
            input_mask = torch.ones_like(event_times)

        assert event_times.shape == input_mask.shape
        assert event_times.shape[:2] == spatial_locations.shape[:2]
        if aux_state is not None:
            assert event_times.shape[:2] == aux_state.shape[:2]

        N, T, D = spatial_locations.shape
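        # The CNF's divergence (delta_logp) computation differentiates through
        # the locations, so gradients must be enabled on them.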
        spatial_locations = spatial_locations.clone().requires_grad_(True)

        # Time embedding, scaled by 1/sqrt(embedding dim) to normalize its magnitude.
        t_embed = self.t_embedding(event_times) / math.sqrt(self.t_embedding_dim)

        if aux_state is not None:
            inputs = [spatial_locations, aux_state, t_embed]
        else:
            inputs = [spatial_locations, t_embed]

        # attention layer uses (T, N, D) ordering.
        inputs = [inp.transpose(0, 1) for inp in inputs]
        norm_fn = max_rms_norm([a.shape for a in inputs])
        x = torch.cat(inputs, dim=-1)

        # Tell the ODE function the pre-flattening (T, N, -1) shape; x is
        # collapsed to (T * N, -1) just below.
        self.odefunc.set_shape(x.shape)

        x = x.reshape(T * N, -1)
        event_times = event_times.transpose(0, 1).reshape(T * N)

        # Integrate each event's CNF from its (offset) event time back to the
        # common reference time `time_offset`.
        t0 = event_times + self.time_offset
        t1 = torch.zeros_like(event_times) + self.time_offset

        z, delta_logp = self.cnf.integrate(t0, t1, x, torch.zeros_like(event_times), norm=norm_fn)
        z = z[:, :self.dim]  # (T * N, D)

        # A second CNF carries the state from the reference time to t=0, where
        # the base distribution is evaluated, accumulating delta_logp along the way.
        base_t = torch.zeros_like(event_times)
        z, delta_logp = self.base_cnf.integrate(t1, base_t, z, delta_logp)

        # Condition the base-distribution parameters on the auxiliary state
        # (when present) and the time embedding, zeroing masked positions.
        if aux_state is not None:
            cond_inputs = [aux_state[:, :, -self.aux_dim:], t_embed]
        else:
            cond_inputs = [t_embed]
        cond = torch.cat(cond_inputs, dim=-1)  # (N, T, -1)
        cond = torch.where(input_mask[..., None].expand_as(cond).bool(), cond, torch.zeros_like(cond))
        cond = cond.transpose(0, 1).reshape(T * N, -1)

        z_params = self.base_dist_params(cond)

        z_mean, z_logstd = torch.split(z_params, D, dim=-1)

        # Change of variables: log p(x) = log p(z) - delta_logp, with a
        # diagonal-Gaussian base density.
        logpz = gaussian_loglik(z, z_mean, z_logstd).sum(-1)
        logpx = logpz - delta_logp  # (T * N)

        logpx = logpx.reshape(T, N).transpose(0, 1)  # (N, T)

        return torch.where(input_mask.bool(), logpx, torch.zeros_like(logpx))
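
The helper `gaussian_loglik` is not shown in this excerpt. Assuming the usual diagonal-Gaussian convention suggested by the call above (elementwise log-density, summed over the last dimension by the caller), a minimal sketch of such a helper could look like this; this is an assumption about the helper, not its actual definition:

    import math
    import torch

    def gaussian_loglik(z, mean, logstd):
        # Elementwise log N(z; mean, exp(logstd)^2); the caller sums over the last dim.
        c = math.log(2 * math.pi)
        return -0.5 * (c + 2.0 * logstd + (z - mean) ** 2 * torch.exp(-2.0 * logstd))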
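
For reference, a minimal sketch of how this method might be exercised, assuming `model` is an instance of the surrounding class; the shapes follow the docstring, and the inputs here are random placeholders:

    import torch

    N, T, D = 4, 10, 2  # batch size, max sequence length, spatial dim
    event_times = torch.sort(torch.rand(N, T), dim=1).values  # sorted per sequence
    spatial_locations = torch.randn(N, T, D)

    # Mask out padding past each sequence's true length.
    lengths = torch.tensor([10, 7, 9, 5])
    input_mask = (torch.arange(T)[None, :] < lengths[:, None]).float()

    logliks = model._cond_logliks(event_times, spatial_locations, input_mask)
    assert logliks.shape == (N, T)

    # Masked entries are zero, so a per-sequence log-likelihood is a plain sum.
    per_seq_loglik = logliks.sum(-1)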