def __init__()

in models/spatial/attncnf.py [0:0]


    def __init__(self, dim, hidden_dims, aux_dim, actfn, time_offset, nblocks=2, l2_attn=False, layer_type="concat",
                 num_heads=4):
        """Build the embedding network, the stacked attention/feed-forward blocks, and the output projection.

        Args:
            dim: Size of the state; the output projection produces ``dim`` features.
            hidden_dims: Hidden widths, split at ``mid_idx = ceil(len(hidden_dims) / 2)``:
                ``hidden_dims[:mid_idx]`` are the embedding network's hidden layers,
                ``hidden_dims[mid_idx]`` is the attention embedding width, and
                ``hidden_dims[mid_idx:]`` are the output projection's hidden layers.
                Must contain at least two entries.
            aux_dim: Number of auxiliary input features; the embedding network takes
                ``dim + aux_dim`` inputs.
            actfn: Activation passed through to ``build_fc_odefunc``.
            time_offset: Stored as ``self.time_offset``; not used during construction.
            nblocks: Number of attention + feed-forward blocks.
            l2_attn: If True, use ``L2MultiheadAttention``; otherwise ``MultiheadAttention``.
            layer_type: Layer type passed through to ``build_fc_odefunc``.
            num_heads: Attention heads per block (default 4, the previously hard-coded value).
                NOTE(review): the attention module presumably requires the embedding width
                to be divisible by this — confirm against its implementation.

        Raises:
            ValueError: If ``hidden_dims`` has fewer than two entries (there would be no
                width at index ``mid_idx`` to use as the embedding size).
        """
        super().__init__()
        if len(hidden_dims) < 2:
            raise ValueError(
                f"hidden_dims needs at least two entries (embedding width + output layers), got {list(hidden_dims)}")

        self.dim = dim
        self.aux_dim = aux_dim
        self.time_offset = time_offset

        # Split hidden_dims around its midpoint: the first half feeds the embedding network,
        # the midpoint width is the attention embedding size, and the remainder (starting at
        # the midpoint) feeds the output projection.
        mid_idx = int(math.ceil(len(hidden_dims) / 2))

        self.embed_dim = hidden_dims[mid_idx]
        self.embedding = build_fc_odefunc(
            self.dim + self.aux_dim, hidden_dims[:mid_idx],
            out_dim=self.embed_dim, layer_type=layer_type, actfn=actfn, zero_init=False)

        mha = L2MultiheadAttention if l2_attn else MultiheadAttention

        # Each block has an attention sublayer and a feed-forward sublayer, each paired with
        # its own ActNorm and TanhGate. Construction order matches the original so parameter
        # initialization and state_dict layout are unchanged.
        self.self_attns = nn.ModuleList([mha(self.embed_dim, num_heads=num_heads) for _ in range(nblocks)])
        self.attn_actnorms = nn.ModuleList([ActNorm(self.embed_dim) for _ in range(nblocks)])
        self.fcs = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.embed_dim, self.embed_dim * 4),
                nn.Softplus(),
                nn.Linear(self.embed_dim * 4, self.embed_dim),
            )
            for _ in range(nblocks)
        ])
        self.fc_actnorms = nn.ModuleList([ActNorm(self.embed_dim) for _ in range(nblocks)])
        self.attn_gates = nn.ModuleList([TanhGate() for _ in range(nblocks)])
        self.fc_gates = nn.ModuleList([TanhGate() for _ in range(nblocks)])

        # zero_init=True here (unlike the embedding) — presumably so the initial output is zero;
        # confirm against build_fc_odefunc.
        self.output_proj = build_fc_odefunc(
            self.embed_dim, hidden_dims[mid_idx:], out_dim=self.dim,
            layer_type=layer_type, actfn=actfn, zero_init=True)