in models/spatial/jumpcnf.py [0:0]
def _cond_logliks(self, event_times, spatial_locations, input_mask=None, aux_state=None):
    """Compute the conditional spatial log-likelihood of every event.

    Each event location is transported backward in time one inter-event
    interval at a time. The CNF (``self.cnf.integrate``, dopri5) is integrated
    from ``event_times[:, j]`` to ``event_times[:, j - 1]``. Unless the time
    origin has been reached, ``self.inst_flow`` is then applied, conditioned on
    the earlier event. All trajectories still in flight are processed as one
    batch, so step ``i`` carries ``N * (i + 1)`` trajectories. Times are first
    shifted by ``self.time_offset``, and a time-0 origin is prepended so that
    the earliest event also has an interval to integrate back over.
    The result is ``gaussian_loglik`` of the transported points (with
    ``self.z_mean`` / ``self.z_logstd``, summed over D) minus the accumulated
    log-density changes.

    Side effect: resets ``self.cnf.nfe`` to 0.

    Args:
        event_times: (N, T)
        spatial_locations: (N, T, D)
        input_mask: (N, T) or None. Interpreted as boolean; None means all
            events are valid.
        aux_state: (N, T, D_a) or None. Concatenated to the spatial state
            during ODE integration.
    Returns:
        A tensor of shape (N, T) containing the conditional log probabilities,
        with zeros at masked positions. When T == 0 an empty (N, 0) tensor is
        returned.
    """
    if input_mask is None:
        input_mask = torch.ones_like(event_times)
    assert event_times.shape == input_mask.shape
    assert event_times.shape[:2] == spatial_locations.shape[:2]
    if aux_state is not None:
        assert event_times.shape[:2] == aux_state.shape[:2]

    N, T, D = spatial_locations.shape
    self.cnf.nfe = 0

    if T == 0:
        # No events, so nothing to integrate. Without this guard the loop below
        # never runs and `xs` / `dlogps` would be referenced before assignment.
        return spatial_locations.new_zeros(N, 0)

    input_mask = input_mask.bool()
    event_times = self.time_offset + event_times
    event_times = torch.cat([torch.zeros(N, 1).to(event_times), event_times], dim=1)  # (N, 1 + T)
    input_mask = torch.cat([torch.ones(N, 1).to(input_mask), input_mask], dim=1)  # (N, 1 + T)

    for i in range(T):
        # Integration interval for this step, shared by the i + 1 active
        # trajectories of each sequence. If either endpoint event is masked,
        # both times become 0, which makes the integration a no-op.
        pair_valid = input_mask[:, -i - 1] & input_mask[:, -i - 2]
        t0 = event_times[:, -i - 1].mul(pair_valid).reshape(N, 1).expand(N, i + 1).reshape(-1)
        t1 = event_times[:, -i - 2].mul(pair_valid).reshape(N, 1).expand(N, i + 1).reshape(-1)

        # Start a new trajectory at event -i - 1 with zero accumulated
        # log-density change. It is prepended, so columns stay in event order.
        new_x = spatial_locations[:, -i - 1].reshape(N, 1, D)
        new_dlogp = torch.zeros(N, 1).to(new_x)
        if i == 0:
            xs, dlogps = new_x, new_dlogp
        else:
            xs = torch.cat([new_x, xs], dim=1)
            dlogps = torch.cat([new_dlogp, dlogps], dim=1)
        xs = xs.reshape(-1, D)
        dlogps = dlogps.reshape(-1)

        norm_fn = None
        if aux_state is not None:
            D_a = aux_state.shape[-1]
            # Aux states belonging to the events whose trajectories are active.
            auxs = aux_state[:, -i - 1:, :].expand(N, i + 1, D_a).reshape(-1, D_a)
            inputs = [xs, auxs]
            norm_fn = max_rms_norm([a.shape for a in inputs])
            xs = torch.cat(inputs, dim=1)

        xs, dlogps = self.cnf.integrate(t0, t1, xs, dlogps, method="dopri5", norm=norm_fn)
        # Split the integrated state back into its spatial and auxiliary parts
        # (presumably `auxs` has zero columns when aux_state is None).
        xs, auxs = xs[:, :D], xs[:, D:]

        # Apply the instantaneous flow at the earlier event -i - 2, unless the
        # prepended time-0 origin has been reached.
        if i < T - 1:
            obs_x = spatial_locations[:, -i - 2].reshape(N, 1, D).expand(N, i + 1, D).reshape(-1, D)
            obs_t = event_times[:, -i - 2].reshape(N, 1).expand(N, i + 1).reshape(-1, 1)
            # NOTE(review): if self.aux_dim == 0, `-0:` selects *all* columns
            # of auxs. That is harmless when aux_state is None (auxs is empty);
            # confirm aux_dim matches D_a otherwise.
            cond = torch.cat([obs_t, obs_x, auxs[:, -self.aux_dim:]], dim=1)  # (N * (i + 1), 1 + D + D_a)
            # NOTE(review): unlike the ODE step above, this flow is not gated by
            # input_mask. Confirm that masked events never precede valid ones
            # (e.g. padding is at the end).
            xs, dlogps = self.inst_flow(xs, logpx=dlogps, cond=cond)

        xs = xs.reshape(N, i + 1, D)
        dlogps = dlogps.reshape(N, i + 1)
        # Zero the log-density change of trajectories that start at masked events.
        dlogps = torch.where(input_mask[:, -i - 1:], dlogps, torch.zeros_like(dlogps))

    logpz = gaussian_loglik(xs, self.z_mean.expand_as(xs), self.z_logstd.expand_as(xs)).sum(-1)  # (N, T)
    logpx = logpz - dlogps
    return torch.where(input_mask[:, 1:], logpx, torch.zeros_like(logpx))