in models/spatial/attncnf.py [0:0]
def __init__(self, dim, hidden_dims, aux_dim, actfn, time_offset, nblocks=2, l2_attn=False, layer_type="concat"):
    """Build an attention-based ODE function: an input embedding MLP, `nblocks`
    of (self-attention + feed-forward) residual-style sub-layers with ActNorm
    and tanh gating, and an output projection back to `dim`.

    Args:
        dim: dimensionality of the state the ODE function operates on.
        hidden_dims: sequence of hidden widths; the first half parameterizes the
            embedding MLP, the second half the output projection, and the width
            at the midpoint becomes the attention embedding dimension. Must
            contain more than ceil(len/2) entries (i.e. at least 2).
        aux_dim: dimensionality of auxiliary features concatenated to the input.
        actfn: activation-function identifier forwarded to `build_fc_odefunc`.
        time_offset: stored offset applied to time (used elsewhere in the class).
        nblocks: number of attention + feed-forward sub-layer pairs.
        l2_attn: if True, use `L2MultiheadAttention` instead of the standard
            `MultiheadAttention`.
        layer_type: layer construction mode forwarded to `build_fc_odefunc`.

    Raises:
        ValueError: if `hidden_dims` is too short to provide a midpoint width.
    """
    super().__init__()
    self.dim = dim
    self.aux_dim = aux_dim
    self.time_offset = time_offset

    # Split hidden_dims at the (ceil) midpoint: first half feeds the embedding
    # network, the midpoint entry defines the attention width, the rest feeds
    # the output projection.
    mid_idx = int(math.ceil(len(hidden_dims) / 2))
    if mid_idx >= len(hidden_dims):
        # Fail fast with a clear message instead of an opaque IndexError below.
        raise ValueError(
            "hidden_dims must have at least 2 entries to define an embedding width; "
            "got {}".format(list(hidden_dims))
        )
    self.embed_dim = hidden_dims[mid_idx]

    # Embedding MLP maps (state ++ aux features) up to the attention width.
    # zero_init=False: only the final output projection is zero-initialized.
    self.embedding = build_fc_odefunc(
        self.dim + self.aux_dim, hidden_dims[:mid_idx],
        out_dim=self.embed_dim, layer_type=layer_type, actfn=actfn, zero_init=False)

    # Optionally use the L2-distance attention variant.
    mha = L2MultiheadAttention if l2_attn else MultiheadAttention

    self.self_attns = nn.ModuleList([mha(self.embed_dim, num_heads=4) for _ in range(nblocks)])
    self.attn_actnorms = nn.ModuleList([ActNorm(self.embed_dim) for _ in range(nblocks)])

    # Position-wise feed-forward blocks with a 4x inner expansion (Transformer-style).
    self.fcs = nn.ModuleList([
        nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim * 4),
            nn.Softplus(),
            nn.Linear(self.embed_dim * 4, self.embed_dim),
        )
        for _ in range(nblocks)
    ])
    self.fc_actnorms = nn.ModuleList([ActNorm(self.embed_dim) for _ in range(nblocks)])

    # Tanh gates on both sub-layer outputs. (Was a bare generator expression for
    # fc_gates; made a list comprehension for consistency with the siblings above.)
    self.attn_gates = nn.ModuleList([TanhGate() for _ in range(nblocks)])
    self.fc_gates = nn.ModuleList([TanhGate() for _ in range(nblocks)])

    # Output projection back to the state dimension; zero-initialized so the
    # ODE function starts near the identity flow.
    self.output_proj = build_fc_odefunc(
        self.embed_dim, hidden_dims[mid_idx:],
        out_dim=self.dim, layer_type=layer_type, actfn=actfn, zero_init=True)