def build_fc_odefunc()

in models/spatial/cnf.py [0:0]


def build_fc_odefunc(dim=2, hidden_dims=(64, 64, 64), out_dim=None, nonzero_dim=None, actfn="softplus", layer_type="concatsquash",
                     zero_init=True, actfirst=False):
    """Build a fully connected ODE function as a ``diffeq_layers.SequentialDiffEq``.

    The network alternates ``LAYERTYPES[layer_type]`` layers with
    ``ACTFNS[actfn]`` activations:
    ``layer(dim, h0), act(h0), layer(h0, h1), act(h1), ..., layer(h_last, out_dim)``.
    With no hidden dims it is a single ``layer(dim, out_dim)``.

    Args:
        dim: Input width of the first layer.
        hidden_dims: Widths of the hidden layers. An empty sequence or ``None``
            yields a single layer mapping ``dim`` -> ``out_dim``.
        out_dim: Output width of the last layer. Falls back to ``dim`` when
            ``None`` (or any falsy value).
        nonzero_dim: Number of leading input columns whose first-layer weights
            are kept. Columns ``nonzero_dim:`` are zeroed. Defaults to ``dim``
            (nothing zeroed).
        actfn: Key into ``ACTFNS`` selecting the activation module.
        layer_type: Key into ``LAYERTYPES`` selecting the layer constructor.
        zero_init: If true, zero the weights and biases of every ``nn.Linear``
            inside the final layer.
        actfirst: If true and the network has more than one module, drop the
            leading layer so the sequence starts with an activation.

    Returns:
        A ``diffeq_layers.SequentialDiffEq`` wrapping the constructed modules.

    Raises:
        ValueError: If ``layer_type`` is not a key of ``LAYERTYPES``.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``,
    # which would turn this into an opaque KeyError on the lookup below.
    if layer_type not in LAYERTYPES:
        raise ValueError(f"layer_type must be one of {LAYERTYPES.keys()} but was given {layer_type}")
    layer_fn = LAYERTYPES[layer_type]

    nonzero_dim = dim if nonzero_dim is None else nonzero_dim
    out_dim = out_dim or dim
    if hidden_dims:
        dims = [dim] + list(hidden_dims)
        layers = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers.append(layer_fn(d_in, d_out))
            layers.append(ACTFNS[actfn](d_out))
        # Output projection from the last hidden width.
        layers.append(layer_fn(dims[-1], out_dim))
    else:
        layers = [layer_fn(dim, out_dim)]

    if actfirst and len(layers) > 1:
        # Drops the first layer_fn module, so the sequence begins with an activation.
        # NOTE(review): the next remaining layer expects hidden_dims[0] inputs, so
        # this presumably relies on callers passing dim == hidden_dims[0] — confirm.
        layers = layers[1:]

    if nonzero_dim < dim:
        # zero out weights for auxiliary inputs.
        # Assumes layers[0] exposes a ``_layer`` submodule with a ``weight`` of
        # shape (out, in). NOTE(review): with actfirst=True, layers[0] is an
        # activation module here — confirm it is never combined with nonzero_dim.
        layers[0]._layer.weight.data[:, nonzero_dim:].fill_(0)

    if zero_init:
        # Zero the final layer so the initial dynamics output is zero
        # (for nn.Linear submodules of that layer).
        for m in layers[-1].modules():
            if isinstance(m, nn.Linear):
                m.weight.data.fill_(0)
                if m.bias is not None:
                    m.bias.data.fill_(0)

    return diffeq_layers.SequentialDiffEq(*layers)