def __init__()

in models/temporal/neural.py [0:0]


    def __init__(self, cond_dim=0, hidden_dims=None, cond=False, style="split", actfn="softplus", hdim=None, separate=1, tol=1e-6, otreg_strength=0.1):
        """Build the hidden-state dynamics and the intensity ODE solver.

        Args:
            cond_dim: Width of the conditioning input added to each update
                net's input. Forced to 0 when ``cond`` is False.
            hidden_dims: Layer widths. ``hidden_dims[0]`` is the size of the
                learned initial hidden state; ``hidden_dims[1:]`` are the
                hidden-layer widths passed to ``construct_diffeqnet``.
                Defaults to ``[64, 64, 64]``.
            cond: Whether the update nets receive a conditioning input.
            style: Hidden-state update style. ``"split"`` / ``"simple"`` build
                a gated MLP update net via ``construct_diffeqnet``; ``"gru"``
                uses ``nn.GRUCell``. Also used as the key into
                ``self.dynamics_dict``.
            actfn: Activation passed to ``construct_diffeqnet`` for the
                continuous state-derivative nets.
            hdim: Width passed to ``IntensityODEFunc`` and used as the input
                width of the intensity net. Defaults to ``hidden_dims[0]``;
                must be even.
            separate: Number of independent dynamics modules; each is built
                for ``hidden_dims[0] // separate`` state dimensions.
            tol: Used as both ``atol`` and ``rtol`` of ``TimeVariableODE``.
            otreg_strength: Passed to ``TimeVariableODE`` as
                ``energy_regularization``.

        Raises:
            ValueError: If ``style`` is not one of ``"split"``, ``"simple"``,
                ``"gru"``, or if the resolved ``hdim`` is odd.
        """
        super().__init__()
        # Build the default here rather than in the signature so instances
        # never share one mutable default list.
        if hidden_dims is None:
            hidden_dims = [64, 64, 64]
        if style not in ("split", "simple", "gru"):
            # Fail fast: otherwise `update_net` below would be unbound (or the
            # dynamics_dict lookup would raise an opaque KeyError).
            raise ValueError(f"Unknown style {style!r}; expected 'split', 'simple' or 'gru'.")

        if not cond:
            cond_dim = 0
        self.cond = cond
        self.cond_dim = cond_dim
        self.hdim = hidden_dims[0] if hdim is None else hdim
        if self.hdim % 2 != 0:
            # Explicit raise instead of `assert`, which is stripped under `python -O`.
            raise ValueError(f"hdim must be even, got {self.hdim}.")

        state_dim = hidden_dims[0]
        self._init_state = nn.Parameter(torch.randn(state_dim) / math.sqrt(state_dim))

        # NOTE(review): if state_dim is not divisible by `separate`, the
        # per-module widths sum to less than state_dim — confirm that
        # HiddenStateODEFuncList tolerates this.
        split_dim = state_dim // separate
        dynamics = []
        for _ in range(separate):
            dstate_net = construct_diffeqnet(split_dim, hidden_dims[1:], split_dim, time_dependent=False, actfn=actfn, zero_init=True)
            if style == "gru":
                update_net = nn.GRUCell(cond_dim, split_dim)
            else:  # "split" or "simple"
                update_net = construct_diffeqnet(split_dim + cond_dim, hidden_dims[1:], split_dim, time_dependent=False, actfn="celu", gated=True, zero_init=False)
            dynamics.append(self.dynamics_dict[style](dstate_net, update_net))

        self.hidden_state_dynamics = HiddenStateODEFuncList(*dynamics)

        # Maps an hdim-wide vector to a single scalar output.
        intensity_net = nn.Sequential(
            nn.Linear(self.hdim, self.hdim * 4),
            nn.Softplus(),
            nn.Linear(self.hdim * 4, 1),
        )
        intensity_odefunc = IntensityODEFunc(self.hdim, self.hidden_state_dynamics, intensity_net)
        self.ode_solver = TimeVariableODE(intensity_odefunc, atol=tol, rtol=tol, energy_regularization=otreg_strength)