def forward()

in models/compressive.py [0:0]


    def forward(self, x, h_prev, target=None):
        # x : B x M token indices for the current segment
        # h_prev : list of 2*nlayers cached tensors -- the per-layer regular
        #          memories followed by the per-layer compressed memories
        # target : optional labels forwarded to the output module
        B, M = x.size()
        H = self.args.hid_sz  # hidden size
        c = self.args.compress_rate  # raw memory steps averaged into one compressed slot
        C = self.args.compress_size // self.args.compress_rate  # compressed-memory length

        h = self.in_emb(x)  # B x M x H

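        # the incoming cache holds the regular memories for all layers followed
        # by the compressed memories; split it into those two parts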
        c_prev = h_prev[-self.args.nlayers:]
        h_prev = h_prev[:-self.args.nlayers]

        h_cache = []
        c_cache = []
        aux_loss = 0
        for l in range(self.args.nlayers):
            h_memory = torch.cat([h_prev[l], h], dim=1)  # B x L+M x H

            # compress the oldest M steps of the regular memory; these steps are
            # still part of h_memory this step, so the two memories overlap
            new_compress = h_memory[:, :M, :]  # B x M x H
            new_compress = new_compress.view(B, M // c, c, H)
            new_compress = new_compress.mean(2)  # B x M/c x H
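            # e.g. with c=2, hidden states [t0, t1, t2, t3] are pooled into
            # [(t0+t1)/2, (t2+t3)/2]: block-wise mean pooling along the time axis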
            c_memory = torch.cat([c_prev[l], new_compress], dim=1)  # B x C+M/c x H

            # the layer returns updated hidden states and its auxiliary loss
            h, layer_loss = self.get_layer(l)(h, h_memory, c_memory)  # B x M x H
            aux_loss = aux_loss + layer_loss

            h_cache.append(h_memory[:, -self.args.attn_lim:, :])  # keep the last attn_lim steps
            c_cache.append(c_memory[:, -C:, :])  # keep the last C compressed slots

        if self.args.pre_norm:
            h = self.out_norm(h)
        out = self.out(h, target)

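        # return the caches in the same layout this method expects as input:
        # regular memories for all layers first, then the compressed memories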
        h_cache.extend(c_cache)

        return out, h_cache, aux_loss
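
A minimal sketch of how a caller might drive this method across segments, assuming illustrative hyper-parameter values (nlayers=2, hid_sz=128, attn_lim=64, compress_rate=4, compress_size=32). The names `model` and `x` and the zero-initialised caches below are hypothetical stand-ins for this illustration, not part of the file.

    import torch

    B, M, H = 8, 32, 128   # batch size, segment length, hidden size
    nlayers, L = 2, 64     # number of layers and regular-memory length (attn_lim)
    C = 32 // 4            # compressed-memory length (compress_size // compress_rate)

    # `model` is assumed to be an instance of the compressive model defined in this file
    x = torch.randint(0, 1000, (B, M))                        # dummy token indices
    h_prev = [torch.zeros(B, L, H) for _ in range(nlayers)]   # per-layer regular memories
    h_prev += [torch.zeros(B, C, H) for _ in range(nlayers)]  # per-layer compressed memories

    out, h_prev, aux_loss = model(x, h_prev)  # returned cache keeps the same layout,
                                              # so it can be fed back on the next segment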