in models/temporal/neural.py [0:0]
def __init__(self, cond_dim=0, hidden_dims=None, cond=False, style="split", actfn="softplus", hdim=None, separate=1, tol=1e-6, otreg_strength=0.1):
    """Build the hidden-state dynamics, the intensity head and the ODE solver.

    Args:
        cond_dim: Width of the conditioning input. For the "split" and
            "simple" styles it is added to the input width of each update
            network; for "gru" it is the input size of the ``nn.GRUCell``.
            Forced to 0 when ``cond`` is false.
        hidden_dims: Sequence of layer sizes. ``hidden_dims[0]`` is the size
            of the learnable initial state and is divided evenly
            (integer division) across the ``separate`` dynamics blocks;
            ``hidden_dims[1:]`` is forwarded to ``construct_diffeqnet``
            (presumably as its hidden layer widths -- confirm against that
            helper). Defaults to ``[64, 64, 64]``.
        cond: Whether conditioning is enabled; stored as ``self.cond``.
        style: Key into ``self.dynamics_dict`` selecting the dynamics
            wrapper. "split"/"simple" pair it with a gated
            ``construct_diffeqnet`` update network, "gru" with an
            ``nn.GRUCell``.
        actfn: Activation name forwarded to ``construct_diffeqnet`` for the
            state-derivative networks (the update networks always use
            "celu").
        hdim: Width given to the intensity network and ``IntensityODEFunc``.
            Defaults to ``hidden_dims[0]``. Must be even.
        separate: Number of independent dynamics blocks to build.
        tol: Used as both ``atol`` and ``rtol`` of ``TimeVariableODE``.
        otreg_strength: Forwarded to ``TimeVariableODE`` as
            ``energy_regularization``.

    Raises:
        ValueError: If the resolved ``hdim`` is odd, or if ``style`` is
            registered in ``self.dynamics_dict`` but has no update-network
            recipe here.
        KeyError: If ``style`` is not a key of ``self.dynamics_dict``.
    """
    super().__init__()

    # A None sentinel avoids sharing one mutable default list across calls.
    if hidden_dims is None:
        hidden_dims = [64, 64, 64]
    if not cond:
        cond_dim = 0

    self.cond = cond
    self.cond_dim = cond_dim
    self.hdim = hidden_dims[0] if hdim is None else hdim

    # Explicit raise instead of `assert`, which is removed under `python -O`.
    # NOTE(review): the evenness requirement presumably comes from how
    # IntensityODEFunc partitions the state -- confirm.
    if self.hdim % 2 != 0:
        raise ValueError("hdim must be even, got {}".format(self.hdim))

    state_dim = hidden_dims[0]
    # Learnable initial hidden state, scaled by 1/sqrt(state_dim).
    self._init_state = nn.Parameter(torch.randn(state_dim) / math.sqrt(state_dim))

    dynamics = []
    for _ in range(separate):
        # Resolved inside the loop so that separate=0 still yields an empty
        # dynamics list without dividing by zero or looking up `style`.
        dynamics_cls = self.dynamics_dict[style]
        block_dim = state_dim // separate

        dstate_net = construct_diffeqnet(block_dim, hidden_dims[1:], block_dim, time_dependent=False, actfn=actfn, zero_init=True)
        if style in ("split", "simple"):
            update_net = construct_diffeqnet(block_dim + cond_dim, hidden_dims[1:], block_dim, time_dependent=False, actfn="celu", gated=True, zero_init=False)
        elif style == "gru":
            update_net = nn.GRUCell(cond_dim, block_dim)
        else:
            # Previously this path left `update_net` unbound.
            raise ValueError(
                "style {!r} has no update network defined; "
                "expected 'split', 'simple' or 'gru'".format(style)
            )
        dynamics.append(dynamics_cls(dstate_net, update_net))
    self.hidden_state_dynamics = HiddenStateODEFuncList(*dynamics)

    # Two-layer MLP mapping an hdim-wide input to a single output.
    intensity_net = nn.Sequential(
        nn.Linear(self.hdim, self.hdim * 4),
        nn.Softplus(),
        nn.Linear(self.hdim * 4, 1),
    )
    intensity_odefunc = IntensityODEFunc(self.hdim, self.hidden_state_dynamics, intensity_net)
    self.ode_solver = TimeVariableODE(intensity_odefunc, atol=tol, rtol=tol, energy_regularization=otreg_strength)