def forward()

in models.py


    def forward(self, x, h_cache, target=None):
        # x size = B x M
        block_size = x.size(1)
        h = self.in_emb(x)  # B x M x H
        if self.emb_dropout is not None:
            h = self.emb_dropout(h)

        # Update the per-layer hidden-state cache before running each layer,
        # so the next block can attend to the most recent `cache_size` states.
        h_cache_next = []
        for l, layer in enumerate(self.layers):
            cache_size = layer.attn.attn.get_cache_size()
            if cache_size > block_size:
                # Keep the tail of the previous cache and append the current block.
                h_cache_next_l = torch.cat(
                    [h_cache[l][:, -cache_size + block_size:, :], h],
                    dim=1).detach()
            else:
                # The current block alone is enough to fill the cache.
                h_cache_next_l = h[:, -cache_size:, :].detach()
            h_cache_next.append(h_cache_next_l)
            h = layer(h, h_cache[l], self.key_pe)  # B x M x H

        if self.emb_dropout is not None:
            h = self.emb_dropout(h)
        if self.adapt_io:
            # with adaptive input/output embeddings, out_emb computes the loss
            # directly from `target` here
            out = self.out_emb(h, target)
            dummy_loss = compute_dummy_loss(self.in_emb, self.out_emb)
        else:
            # standard softmax output: log-probabilities over the vocabulary
            out = F.log_softmax(self.out_emb(h), dim=-1)
            dummy_loss = None

        return out, h_cache_next, dummy_loss
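
A minimal usage sketch for context: it assumes `model` is an already constructed instance of the class this `forward` belongs to; the helper names `init_cache` and `run_block`, along with `batch_size` and `hidden_size`, are illustrative and not part of the repository.

    import torch

    def init_cache(model, batch_size, hidden_size, device='cpu'):
        # One zero-filled cache tensor per layer, sized to that layer's
        # attention cache (B x cache_size x H).
        return [
            torch.zeros(batch_size, layer.attn.attn.get_cache_size(),
                        hidden_size, device=device)
            for layer in model.layers
        ]

    def run_block(model, x, h_cache):
        # x: LongTensor of token ids, B x M. The returned cache is detached,
        # so it can be fed into the next call without growing the autograd graph.
        out, h_cache_next, dummy_loss = model(x, h_cache)
        return out, h_cache_next, dummy_loss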