# models/compressive.py
def forward(self, x, h_prev, target=None):
    """Run one forward step of the compressive-memory transformer.

    Args:
        x: token-id tensor of shape ``B x M``.
        h_prev: combined cache list; the first ``nlayers`` entries are the
            per-layer hidden-state memories, the last ``nlayers`` entries
            are the per-layer compressed memories.
        target: optional targets, forwarded unchanged to the output head.

    Returns:
        Tuple ``(out, h_cache, aux_loss)`` where ``h_cache`` uses the same
        layout as ``h_prev`` (hidden memories first, then compressed
        memories) and ``aux_loss`` is the sum of the per-layer losses.
    """
    # x : B x M
    B, M = x.size()
    H = self.args.hid_sz
    c = self.args.compress_rate                              # steps pooled into one slot
    C = self.args.compress_size // self.args.compress_rate   # compressed slots kept
    h = self.in_emb(x)  # B x M x H

    # Split the combined cache into its hidden and compressed halves.
    c_prev = h_prev[-self.args.nlayers:]
    h_prev = h_prev[:-self.args.nlayers]

    h_cache = []
    c_cache = []
    aux_loss = 0
    for layer_idx in range(self.args.nlayers):
        h_memory = torch.cat([h_prev[layer_idx], h], dim=1)  # B x L+M x H
        # Compress the oldest M positions by mean-pooling every ``c``
        # consecutive steps (note! there is overlap between two memories).
        # Assumes M is divisible by c — TODO confirm callers guarantee this.
        new_compress = h_memory[:, :M, :]  # B x M x H
        new_compress = new_compress.view(B, M // c, c, H)
        new_compress = new_compress.mean(2)  # B x M/c x H
        c_memory = torch.cat([c_prev[layer_idx], new_compress], dim=1)  # B x C+M/c x H

        # Fix: the per-layer loss was previously assigned to ``l``,
        # shadowing the loop index; use a distinct name.
        h, layer_loss = self.get_layer(layer_idx)(h, h_memory, c_memory)  # B x M x H
        aux_loss = aux_loss + layer_loss

        # Keep only the most recent entries of each memory for the next step.
        h_cache.append(h_memory[:, -self.args.attn_lim:, :])
        c_cache.append(c_memory[:, -C:, :])

    if self.args.pre_norm:
        h = self.out_norm(h)
    out = self.out(h, target)

    # Return caches in the same layout they arrived: hidden, then compressed.
    h_cache.extend(c_cache)
    return out, h_cache, aux_loss