in models/feedback.py [0:0]
def forward(self, x, h_cache, target=None):
    """Process one block of ``mem_sz`` tokens one time step at a time.

    Every time step is pushed through all layers in order. The per-layer
    inputs collected at that step are then handed to ``cache_postprocess``.
    In feedback mode, the top-layer output is appended to them first and the
    list is merged with ``merge_single_memory``. The updated cache feeds
    ``cache_preprocess`` on the next step.

    Args:
        x: input tokens; ``x.size(1)`` must equal ``self.args.mem_sz``
            (presumably shaped B x M — confirm with callers).
        h_cache: cache state threaded through ``cache_initprocess``,
            ``cache_preprocess``, ``cache_postprocess`` and
            ``cache_finalprocess``.
        target: forwarded unchanged to ``self.out``.

    Returns:
        ``(out, h_cache, 0)``: the result of ``self.out``, the finalized
        cache, and a constant ``0`` (presumably a placeholder auxiliary
        loss — TODO confirm).

    Raises:
        AssertionError: if ``x.size(1) != self.args.mem_sz``.
    """
    args = self.args
    assert x.size(1) == args.mem_sz

    embedded = self.in_emb(x)  # presumably B x M x H
    h_cache = self.cache_initprocess(h_cache)

    for layer_idx in range(args.nlayers):
        self.get_layer(layer_idx).attn.attn.prepare_adapt_span()

    # Move time to the front, then add a length-1 time axis so each step
    # slice handed to the layers has the shape they expect (B x 1 x H).
    steps = embedded.transpose(0, 1).contiguous().unsqueeze(2)

    step_outputs = []
    for t in range(args.mem_sz):
        key_cache, val_cache = self.cache_preprocess(h_cache, t)

        hidden = steps[t]
        layer_states = []
        for layer_idx in range(args.nlayers):
            layer_states.append(hidden)  # input seen by this layer
            hidden = self.get_layer(layer_idx)(
                hidden, key_cache[layer_idx], val_cache[layer_idx]
            )  # presumably B x 1 x H (one time step)
        step_outputs.append(hidden)

        if args.feedback:
            # Feedback mode: include the top-layer output, then collapse all
            # collected states into a single memory before caching.
            layer_states.append(hidden)
            layer_states = self.merge_single_memory(layer_states)
        h_cache = self.cache_postprocess(h_cache, layer_states, t)

    # Re-join the per-step outputs along the time axis.
    h_out = torch.cat(step_outputs, dim=1)
    if args.pre_norm:
        h_out = self.out_norm(h_out)
    out = self.out(h_out, target)

    h_cache = self.cache_finalprocess(h_cache)
    return out, h_cache, 0