in models/spatial/attncnf.py [0:0]
def _cond_logliks(self, event_times, spatial_locations, input_mask=None, aux_state=None):
    """Compute per-event conditional spatial log-likelihoods.

    Each spatial location is concatenated with its (optional) auxiliary state
    and a scaled time embedding. The result is integrated with ``self.cnf``
    from ``event_times + self.time_offset`` to ``self.time_offset``. Its first
    ``self.dim`` features are then integrated with ``self.base_cnf`` from
    ``self.time_offset`` to 0. Finally, the result is scored with
    ``gaussian_loglik``. The mean and log-std for that score are predicted by
    ``self.base_dist_params`` from the (masked) conditioning features.

    Args:
        event_times: (N, T)
        spatial_locations: (N, T, D)
        input_mask: (N, T) or None. Nonzero entries mark positions to score;
            when None, every position is scored.
        aux_state: (N, T, D_a) or None.

    Returns:
        A tensor of shape (N, T) containing the conditional log probabilities,
        with masked-out positions set to exactly zero.
    """
    if input_mask is None:
        input_mask = torch.ones_like(event_times)

    assert event_times.shape == input_mask.shape
    assert event_times.shape[:2] == spatial_locations.shape[:2]
    if aux_state is not None:
        assert event_times.shape[:2] == aux_state.shape[:2]

    N, T, D = spatial_locations.shape
    # Fresh copy with grad enabled. Presumably the CNF needs gradients with
    # respect to its input state (e.g. for divergence terms) -- TODO confirm.
    spatial_locations = spatial_locations.clone().requires_grad_(True)

    t_embed = self.t_embedding(event_times) / math.sqrt(self.t_embedding_dim)

    # Spatial coordinates always come first in the concatenation, so the first
    # D feature columns of the CNF state are the spatial part.
    if aux_state is not None:
        inputs = [spatial_locations, aux_state, t_embed]
    else:
        inputs = [spatial_locations, t_embed]

    # attention layer uses (T, N, D) ordering.
    inputs = [inp.transpose(0, 1) for inp in inputs]
    norm_fn = max_rms_norm([a.shape for a in inputs])
    x = torch.cat(inputs, dim=-1)

    # Hand the ODE function the unflattened (T, N, C) layout before the
    # sequence and batch axes are merged below.
    self.odefunc.set_shape(x.shape)

    x = x.reshape(T * N, -1)
    event_times = event_times.transpose(0, 1).reshape(T * N)

    t0 = event_times + self.time_offset
    t1 = torch.zeros_like(event_times) + self.time_offset

    z, delta_logp = self.cnf.integrate(t0, t1, x, torch.zeros_like(event_times), norm=norm_fn)
    # Keep only the spatial features; assumes self.dim == D -- TODO confirm.
    z = z[:, :self.dim]  # (T * N, D)

    base_t = torch.zeros_like(event_times)
    z, delta_logp = self.base_cnf.integrate(t1, base_t, z, delta_logp)

    if aux_state is not None:
        cond_inputs = [aux_state[:, :, -self.aux_dim:], t_embed]
    else:
        cond_inputs = [t_embed]
    cond = torch.cat(cond_inputs, dim=-1)  # (N, T, -1)
    # Zero the conditioning features at masked-out positions.
    cond = torch.where(input_mask[..., None].expand_as(cond).bool(), cond, torch.zeros_like(cond))
    cond = cond.transpose(0, 1).reshape(T * N, -1)

    # Split into mean / log-std halves of width D each; assumes
    # base_dist_params emits 2 * D features -- TODO confirm.
    z_params = self.base_dist_params(cond)
    z_mean, z_logstd = torch.split(z_params, D, dim=-1)
    logpz = gaussian_loglik(z, z_mean, z_logstd).sum(-1)
    logpx = logpz - delta_logp  # (T * N)

    logpx = logpx.reshape(T, N).transpose(0, 1)  # (N, T)
    # Use torch.where rather than multiplying by the mask. That way any
    # non-finite value at a masked position is replaced by 0 instead of
    # propagating (NaN * 0 == NaN).
    return torch.where(input_mask.bool(), logpx, torch.zeros_like(logpx))