def forward()

in models/feedback.py


    def forward(self, x, h_cache, target=None):
        # x: one block of input token indices, shape B x M (batch x mem_sz)
        assert x.size(1) == self.args.mem_sz

        h0_block = self.in_emb(x)  # B x M x H

        h_cache = self.cache_initprocess(h_cache)
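        # h_cache carries per-layer memories of hidden states from earlier blocks;
        # cache_initprocess presumably reshapes/readies them for this block.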

        # Let each layer's adaptive-span attention prepare its span for this block.
        for l in range(self.args.nlayers):
            self.get_layer(l).attn.attn.prepare_adapt_span()

        # Bring the temporal dimension to the front: M x B x H.
        h0_block = h0_block.transpose(0, 1).contiguous()
        # Add a dummy time dimension so each step is a length-1 sequence: M x B x 1 x H.
        h0_block = h0_block.unsqueeze(2)
        h_out_block = []
        for t in range(self.args.mem_sz):
            key_cache, val_cache = self.cache_preprocess(h_cache, t)
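            # Per-layer key/value memories that step t is allowed to attend to
            # (drawn from the cache, which also fills up as this block proceeds).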

            h_t = h0_block[t]
            h_t_all = []
            for l in range(self.args.nlayers):
                h_t_all.append(h_t)
                h_t = self.get_layer(l)(h_t, key_cache[l], val_cache[l])  # B x 1 x H
            h_t_out = h_t
            h_out_block.append(h_t_out)

            if self.args.feedback:
                h_t_all.append(h_t_out)
                h_t_all = self.merge_single_memory(h_t_all)
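                # Feedback mode: collapse the per-layer states (plus the top output)
                # into a single memory representation (presumably a learned mix, as
                # in the Feedback Transformer) that later steps will attend to.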

            h_cache = self.cache_postprocess(h_cache, h_t_all, t)
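            # The step's states are folded back into the cache (presumably at
            # position t) so subsequent steps, and the next block, can attend to them.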

        h_out = torch.cat(h_out_block, dim=1)  # B x M x H
        if self.args.pre_norm:
            h_out = self.out_norm(h_out)
        out = self.out(h_out, target)  # output head; with a target this presumably also yields the loss

        h_cache = self.cache_finalprocess(h_cache)

        return out, h_cache, 0  # constant third value (presumably an unused auxiliary-loss slot)
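
For orientation, here is a minimal sketch of how a block-wise forward like this is typically driven. It is an assumption-laden illustration: `model`, `init_cache`, and the data layout are hypothetical stand-ins (the real model construction and initial `h_cache` setup live elsewhere in the repo), and `h_cache` is assumed to be a list of per-layer tensors.

    # Hypothetical driver loop; `model`, `init_cache`, and the data layout are
    # illustrative assumptions, not part of models/feedback.py.
    import torch

    B, M, vocab = 8, 64, 256                     # batch, block length (mem_sz), vocab size
    tokens = torch.randint(vocab, (B, 10 * M))   # a longer token stream
    targets = torch.roll(tokens, -1, dims=1)     # next-token targets

    h_cache = init_cache(model)                  # hypothetical: empty per-layer memories
    for i in range(0, tokens.size(1), M):
        x = tokens[:, i : i + M]                 # one block of exactly mem_sz tokens
        y = targets[:, i : i + M]
        out, h_cache, _ = model(x, h_cache, y)   # cache carries memory across blocks
        h_cache = [h.detach() for h in h_cache]  # assumed list of tensors; truncate backprop here

The point of the sketch is the cache threading: each call consumes one block of mem_sz tokens and returns an updated h_cache, which is passed straight into the next call so attention can reach back beyond the current block.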