in models/temporal/neural.py
def integrate_lambda(self, event_times, spatial_location, input_mask, t0, t1, nlinspace=1):
"""
Args:
event_times: (N, T)
spatial_location: (N, T, D)
input_mask: (N, T)
t0: (N,) or (1,)
t1: (N,) or (1,)
"""
N, T = event_times.shape
if not self.cond:
# Disable dependence on the spatial sample.
spatial_location = None
if input_mask is None:
input_mask = torch.ones_like(event_times)
input_mask = input_mask.bool()
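# The integration state is a pair: the running integral of the intensity, Lambda, initialised
# to zero at t0, and the continuous-time hidden state, taken from self._init_state broadcast
# over the batch.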
init_state = self._init_state[None].expand(N, -1)
state = (
torch.zeros(N).to(init_state), # Lambda(t_0)
init_state,
)
t0 = t0 if torch.is_tensor(t0) else torch.tensor(t0)
t0 = t0.expand(N).to(event_times)
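# Reset the solver's function-evaluation counter before the sweep over events.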
self.ode_solver.nfe = 0
intensities = []
prejump_hidden_states = []
for i in range(T):
# If event i is masked out, integrate over an empty interval by setting t1_i = t0.
t1_i = torch.where(input_mask[:, i], event_times[:, i], t0)
state_traj = self.ode_solver.integrate(t0, t1_i, state, nlinspace=nlinspace, method="dopri5")
hiddens = state_traj[1] # (1 + nlinspace, N, D)
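# For i > 0, drop the first trajectory point (the post-jump state at the previous event time)
# so that only new points are kept.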
if i > 0:
hiddens = hiddens[1:]
# Zero out the hidden states for events that are masked out at index i (padding).
hiddens = torch.where(input_mask[:, i].reshape(1, -1, 1).expand_as(hiddens), hiddens, torch.zeros_like(hiddens))
prejump_hidden_states.append(hiddens)
state = tuple(s[-1] for s in state_traj)
Lambda, tpp_state = state
intensities.append(self.get_intensity(tpp_state).reshape(-1))
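# Apply the instantaneous jump update at event i (conditioned on its spatial location when
# self.cond is True), but only if there is a further interval to integrate; masked-out
# events leave the state unchanged.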
if i < T - 1 or t1 is not None:
cond = spatial_location[:, i] if spatial_location is not None else None
updated_tpp_state = self.hidden_state_dynamics.update_state(event_times[:, i], tpp_state, cond=cond)
tpp_state = torch.where(input_mask[:, i].reshape(-1, 1).expand_as(tpp_state), updated_tpp_state, tpp_state)
state = (Lambda, tpp_state)
# Track t0 as the last valid event time.
t0 = torch.where(input_mask[:, i], event_times[:, i], t0)
if t1 is not None:
# Integrate from the last valid event time to t1.
t1 = t1 if torch.is_tensor(t1) else torch.tensor(t1)
t1 = t1.expand(N).to(event_times)
state_traj = self.ode_solver.integrate(t0, t1, state, nlinspace=nlinspace, method="dopri5")
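# Drop the first trajectory point (the state at the start of this final interval); keep only
# the new points up to t1.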
hiddens = state_traj[1][1:]
prejump_hidden_states.append(hiddens)
state = tuple(s[-1] for s in state_traj)
Lambda, _ = state # (N,)
intensities = torch.stack(intensities, dim=1) # (N, T)
prejump_hidden_states = torch.cat(prejump_hidden_states, dim=0).transpose(0, 1) # (N, T * nlinspace, D)
return intensities, Lambda, prejump_hidden_states
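

# --- Illustrative usage sketch (not part of the original file) ---
# The three outputs of integrate_lambda combine into the standard temporal point-process
# log-likelihood: sum_i log lambda(t_i) - Lambda, with Lambda the integral of lambda over [t0, t1].
# The helper below is a sketch only, assuming `model` is an instance of the class that defines
# integrate_lambda and that `input_mask` is provided; `torch` is already imported in this module.
def example_tpp_loglik(model, event_times, spatial_location, input_mask, t0, t1):
    """Per-sequence log-likelihood from the outputs of integrate_lambda; shapes as in the docstring above."""
    intensities, Lambda, _ = model.integrate_lambda(
        event_times, spatial_location, input_mask, t0, t1
    )
    # Sum log-intensities over observed events only, then subtract the compensator.
    log_intensities = torch.log(intensities.clamp_min(1e-20)) * input_mask.float()
    return log_intensities.sum(dim=1) - Lambda  # (N,)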