models/spatial/attncnf.py [175:201]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def spatial_conditional_logprob_fn(self, t, event_times, spatial_locations, aux_state=None):
        """Build a closure that scores candidate locations for one extra event at time ``t``.

        The observed history is replicated once per candidate location, and each
        copy is extended with a single additional event at time ``t`` placed at
        that candidate. The extended sequences are then passed to ``self.logprob``.

        Args:
            t: scalar time of the appended event.
            event_times: (T,) times of the observed events.
            spatial_locations: (T, D) locations of the observed events.
            aux_state: optional (T + 1, D_a) auxiliary state, shared by every candidate.

        Returns:
            A function taking candidate locations (N, D) and returning
            ``self.logprob(...)`` on the extended sequences summed over dim 1,
            i.e. an (N,) tensor of log-probabilities.
        """
        num_events, dim = spatial_locations.shape

        def loglikelihood_fn(s):
            num_queries = s.shape[0]

            # Times: the shared history followed by a column filled with t,
            # created with the same dtype/device as event_times.
            history_times = event_times.unsqueeze(0).expand(num_queries, num_events)
            query_time = event_times.new_ones(num_queries, 1) * t
            times = torch.cat([history_times, query_time], dim=1)

            # Locations: the shared history followed by each candidate as the last event.
            history_locations = spatial_locations.unsqueeze(0).expand(num_queries, num_events, dim)
            query_location = s.reshape(num_queries, 1, dim)
            locations = torch.cat([history_locations, query_location], dim=1)

            # The auxiliary state already covers all T + 1 steps; only broadcast it.
            aux = (
                None
                if aux_state is None
                else aux_state.reshape(1, num_events + 1, -1).expand(num_queries, -1, -1)
            )

            per_event = self.logprob(times, locations, input_mask=None, aux_state=aux)
            return per_event.sum(1)

        return loglikelihood_fn
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



models/spatial/jumpcnf.py [140:166]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def spatial_conditional_logprob_fn(self, t, event_times, spatial_locations, aux_state=None):
        """Build a closure that scores candidate locations for one extra event at time ``t``.

        The observed history is replicated once per candidate location, and each
        copy is extended with a single additional event at time ``t`` placed at
        that candidate. The extended sequences are then passed to ``self.logprob``.

        Args:
            t: scalar time of the appended event.
            event_times: (T,) times of the observed events.
            spatial_locations: (T, D) locations of the observed events.
            aux_state: optional (T + 1, D_a) auxiliary state, shared by every candidate.

        Returns:
            A function taking candidate locations (N, D) and returning
            ``self.logprob(...)`` on the extended sequences summed over dim 1,
            i.e. an (N,) tensor of log-probabilities.
        """
        num_events, dim = spatial_locations.shape

        def loglikelihood_fn(s):
            num_queries = s.shape[0]

            # Times: the shared history followed by a column filled with t,
            # created with the same dtype/device as event_times.
            history_times = event_times.unsqueeze(0).expand(num_queries, num_events)
            query_time = event_times.new_ones(num_queries, 1) * t
            times = torch.cat([history_times, query_time], dim=1)

            # Locations: the shared history followed by each candidate as the last event.
            history_locations = spatial_locations.unsqueeze(0).expand(num_queries, num_events, dim)
            query_location = s.reshape(num_queries, 1, dim)
            locations = torch.cat([history_locations, query_location], dim=1)

            # The auxiliary state already covers all T + 1 steps; only broadcast it.
            aux = (
                None
                if aux_state is None
                else aux_state.reshape(1, num_events + 1, -1).expand(num_queries, -1, -1)
            )

            per_event = self.logprob(times, locations, input_mask=None, aux_state=aux)
            return per_event.sum(1)

        return loglikelihood_fn
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



