in models.py
def forward(self, x, h_cache, target=None):
    # x size = B x M
    block_size = x.size(1)
    h = self.in_emb(x)  # B x M x H
    if self.emb_dropout is not None:
        h = self.emb_dropout(h)

    # Build the cache for the next block while running the layers.
    # h_cache_next[l] stores the inputs to layer l from the current block,
    # plus the tail of the old cache when the cache is longer than a block.
    h_cache_next = []
    for l, layer in enumerate(self.layers):
        cache_size = layer.attn.attn.get_cache_size()
        if cache_size > block_size:
            # keep the last (cache_size - block_size) cached states
            # and append the whole current block
            h_cache_next_l = torch.cat(
                [h_cache[l][:, -cache_size + block_size:, :], h],
                dim=1).detach()
        else:
            # the current block alone is enough to fill the cache
            h_cache_next_l = h[:, -cache_size:, :].detach()
        h_cache_next.append(h_cache_next_l)
        h = layer(h, h_cache[l], self.key_pe)  # B x M x H

    if self.emb_dropout is not None:
        h = self.emb_dropout(h)

    if self.adapt_io:
        # adaptive output embedding: the loss is computed here directly
        out = self.out_emb(h, target)
        # auxiliary loss over the embedding parameters (see compute_dummy_loss)
        dummy_loss = compute_dummy_loss(self.in_emb, self.out_emb)
    else:
        out = F.log_softmax(self.out_emb(h), dim=-1)
        dummy_loss = None

    return out, h_cache_next, dummy_loss
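
For reference, here is a minimal sketch of how this forward pass can be driven over a long token stream, one block at a time, reusing the per-layer cache it returns. This is not part of models.py: the `hidden_size` argument, the zero-initialised caches, and the restriction to the non-adapt_io path (where `out` holds log-probabilities of shape B x M x V) are assumptions made for illustration.

import torch
import torch.nn.functional as F

def eval_blocks(model, data, block_size, hidden_size):
    # data: B x T tensor of token ids; model is the module defined above.
    batch_size = data.size(0)
    # one zero-filled cache per layer, sized by that layer's attention cache
    h_cache = [
        torch.zeros(batch_size, layer.attn.attn.get_cache_size(), hidden_size,
                    device=data.device)
        for layer in model.layers
    ]
    total_loss, n_blocks = 0.0, 0
    model.eval()
    with torch.no_grad():
        for start in range(0, data.size(1) - block_size - 1, block_size):
            x = data[:, start:start + block_size]
            y = data[:, start + 1:start + block_size + 1]
            out, h_cache, _ = model(x, h_cache)  # carry the returned cache forward
            total_loss += F.nll_loss(out.reshape(-1, out.size(-1)),
                                     y.reshape(-1)).item()
            n_blocks += 1
    return total_loss / max(n_blocks, 1)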