def integrate_lambda()

in models/temporal/neural.py [0:0]


    def integrate_lambda(self, event_times, spatial_location, input_mask, t0, t1, nlinspace=1):
        """Integrate the hidden-state ODE piecewise across a batch of event sequences.

        Starting from ``self._init_state`` (with ``Lambda`` set to zero) at ``t0``,
        the ODE is solved between consecutive event times via
        ``self.ode_solver.integrate``. After each event the hidden state is
        updated with ``self.hidden_state_dynamics.update_state``. The update is
        skipped for masked-out events, and after the final event when ``t1`` is
        None. If ``t1`` is given, a final segment from the last valid event time
        to ``t1`` is integrated as well.

        Args:
            event_times: (N, T) event times.
            spatial_location: (N, T, D) values passed as ``cond`` to the state
                update. It is ignored (treated as None) when ``self.cond`` is
                falsy.
            input_mask: (N, T) mask of valid events, or None to treat every
                event as valid. A masked-out event integrates over a zero-length
                interval and does not update the hidden state.
            t0: (N,), (1,), a scalar tensor, or a Python number; the start time.
            t1: (N,), (1,), a scalar tensor, a Python number, or None; the
                optional end time of a trailing segment.
            nlinspace: forwarded unchanged to ``self.ode_solver.integrate``.

        Returns:
            intensities: (N, T) ``self.get_intensity`` evaluated on the state
                reached at each event, before that event's update. Masked-out
                positions are not zeroed.
            Lambda: (N,) the ``Lambda`` component of the state (zero at ``t0``)
                at the end of integration.
            prejump_hidden_states: (N, P, H) concatenated hidden trajectories.
                Points belonging to masked-out events are zeroed. Assuming the
                solver returns ``1 + nlinspace`` points per segment,
                ``P = T * nlinspace + 1``, plus ``nlinspace`` when ``t1`` is given.
        """
        N, T = event_times.shape
        # The same solver is used for training and evaluation.
        method = "dopri5"

        def _as_batch_times(t):
            # Accept Python numbers as well as tensors; broadcast to (N,) and
            # match event_times' dtype/device.
            t = t if torch.is_tensor(t) else torch.tensor(t)
            return t.expand(N).to(event_times)

        if not self.cond:
            # Disable dependence on the spatial sample.
            spatial_location = None

        if input_mask is None:
            input_mask = torch.ones_like(event_times)
        input_mask = input_mask.bool()

        init_state = self._init_state[None].expand(N, -1)
        state = (
            init_state.new_zeros(N),  # Lambda(t_0)
            init_state,
        )

        t0 = _as_batch_times(t0)

        self.ode_solver.nfe = 0
        intensities = []
        prejump_hidden_states = []
        for i in range(T):
            valid = input_mask[:, i]  # (N,)
            event_t = event_times[:, i]

            # A masked-out event integrates over the zero-length interval [t0, t0].
            t1_i = torch.where(valid, event_t, t0)
            state_traj = self.ode_solver.integrate(t0, t1_i, state, nlinspace=nlinspace, method=method)

            hiddens = state_traj[1]  # (1 + nlinspace, N, H)
            if i > 0:
                # Only the first segment keeps its starting point; later
                # segments drop theirs.
                hiddens = hiddens[1:]
            # Zero the hidden states of sequences whose event i is masked out.
            hiddens = torch.where(valid[None, :, None], hiddens, torch.zeros_like(hiddens))
            prejump_hidden_states.append(hiddens)

            state = tuple(s[-1] for s in state_traj)
            Lambda, tpp_state = state
            intensities.append(self.get_intensity(tpp_state).reshape(-1))

            # Apply the event update, except after the final event when no
            # trailing segment to t1 follows.
            if i < T - 1 or t1 is not None:
                cond = spatial_location[:, i] if spatial_location is not None else None
                updated_tpp_state = self.hidden_state_dynamics.update_state(event_t, tpp_state, cond=cond)
                tpp_state = torch.where(valid[:, None], updated_tpp_state, tpp_state)
                state = (Lambda, tpp_state)

            # Track t0 as the last valid event time.
            t0 = torch.where(valid, event_t, t0)

        if t1 is not None:
            # Integrate from the last valid event time to t1.
            t1 = _as_batch_times(t1)
            state_traj = self.ode_solver.integrate(t0, t1, state, nlinspace=nlinspace, method=method)
            prejump_hidden_states.append(state_traj[1][1:])
            state = tuple(s[-1] for s in state_traj)

        Lambda, _ = state  # (N,)
        intensities = torch.stack(intensities, dim=1)  # (N, T)
        # (P, N, H) -> (N, P, H)
        prejump_hidden_states = torch.cat(prejump_hidden_states, dim=0).transpose(0, 1)
        return intensities, Lambda, prejump_hidden_states