in models/spatial/cnf.py [0:0]
def build_fc_odefunc(dim=2, hidden_dims=(64, 64, 64), out_dim=None, nonzero_dim=None, actfn="softplus",
                     layer_type="concatsquash", zero_init=True, actfirst=False):
    """Build a fully connected ODE function wrapped in ``diffeq_layers.SequentialDiffEq``.

    Layers alternate between ``LAYERTYPES[layer_type]`` and ``ACTFNS[actfn]``:
    ``layer(dim, h0), act(h0), layer(h0, h1), act(h1), ..., layer(h_last, out_dim)``.

    Args:
        dim: Input width passed to the first layer constructor.
        hidden_dims: Iterable of hidden widths (any iterable is accepted). If it is
            empty or ``None``, a single ``layer_fn(dim, out_dim)`` is built.
            Defaults to ``(64, 64, 64)``.
        out_dim: Width of the final layer. Falls back to ``dim`` when falsy.
        nonzero_dim: If smaller than ``dim``, weight columns ``nonzero_dim:`` of the
            first layer's ``_layer`` are zeroed ("auxiliary inputs"). Defaults to ``dim``,
            which zeroes nothing. NOTE(review): this assumes ``_layer`` is an
            ``nn.Linear``-like module whose weight is (out, in) and whose input
            columns line up with the raw inputs — confirm for each layer type.
        actfn: Key into ``ACTFNS``.
        layer_type: Key into ``LAYERTYPES``.
        zero_init: If true, zero the weight and bias of every ``nn.Linear`` found
            inside the final layer.
        actfirst: If true and more than one layer was built, drop the first layer so
            the network begins with an activation (see the note in the body).

    Returns:
        ``diffeq_layers.SequentialDiffEq`` containing the constructed layers in order.

    Raises:
        AssertionError: If ``layer_type`` is not a key of ``LAYERTYPES``. This check is
            skipped under ``python -O``, in which case the lookup raises ``KeyError``.
    """
    assert layer_type in LAYERTYPES.keys(), f"layer_type must be one of {LAYERTYPES.keys()} but was given {layer_type}"
    layer_fn = LAYERTYPES[layer_type]
    nonzero_dim = dim if nonzero_dim is None else nonzero_dim
    out_dim = out_dim or dim

    if hidden_dims:
        # Materialise once. Every later reference uses `dims`, so generators and
        # other one-shot iterables work too.
        dims = [dim] + list(hidden_dims)
        layers = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers.append(layer_fn(d_in, d_out))
            layers.append(ACTFNS[actfn](d_out))
        # dims[-1] is the last hidden width, or `dim` if the iterable was empty.
        layers.append(layer_fn(dims[-1], out_dim))
    else:
        layers = [layer_fn(dim, out_dim)]

    if actfirst and len(layers) > 1:
        # Removes layer_fn(dim, hidden_dims[0]). The net now starts with
        # ACTFNS[actfn](hidden_dims[0]), followed by a layer constructed with input
        # width hidden_dims[0].
        # NOTE(review): this only lines up when dim == hidden_dims[0]. Also, with
        # nonzero_dim < dim, the `layers[0]._layer` access below targets the
        # activation module. Confirm this is intended.
        layers = layers[1:]

    if nonzero_dim < dim:
        # zero out weights for auxiliary inputs.
        layers[0]._layer.weight.data[:, nonzero_dim:].fill_(0)

    if zero_init:
        # Zero every Linear in the last layer, presumably so the initial dynamics
        # start at (or near) zero.
        for m in layers[-1].modules():
            if isinstance(m, nn.Linear):
                nn.init.zeros_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    return diffeq_layers.SequentialDiffEq(*layers)