def _forward()

in model/mem_transformer.py [0:0]


    def _forward(self, dec_inp, reset_mems, mems=None, status_vec=None):
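        """Run the decoder stack over one input segment.

        Args:
            dec_inp: input token indices of shape (qlen, bsz), optionally with
                an extra trailing feature dimension.
            reset_mems: optional bool mask over the batch; True marks sequences
                whose cached memory should be masked out and reset.
            mems: list with one cached hidden-state tensor per layer, or None.
            status_vec: extra per-token status information forwarded to the
                embedding layer.

        Returns:
            Tuple (core_out, new_mems): top-layer hidden states after dropout
            and the updated memory cache.
        """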

        qlen, bsz = dec_inp.size(0), dec_inp.size(1)
        word_emb = self.word_emb(dec_inp, status_vec)

        # mlen: number of cached memory steps; klen: total key length (memory
        # plus the current segment) available to each query.
        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen

        # Generate the mask between query and all the keys
        # TODO Think about how to enable masking when we reach BOS.
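        # Convention: a True entry in dec_attn_mask marks a key position that
        # the query may NOT attend to.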
        if self.pad_type == "model":
            if self.same_length:
                # With same_length, every query attends over a window of the
                # same size: the triu term below blocks future positions and
                # the tril term blocks keys too far in the past.
                all_ones = word_emb.new_ones(qlen, klen)
                mask_len = klen - self.mem_len
                if mask_len > 0:
                    mask_shift_len = qlen - mask_len
                else:
                    mask_shift_len = qlen

            # start_token = 1 if self.replace_start_with_pad else 0

            # if len(dec_inp.shape) == 2:
            #     indices = dec_inp[0] == start_token
            # elif len(dec_inp.shape) == 3:
            #     indices = dec_inp[0, :, 0] == start_token

            if reset_mems is None:
                # Default: no sequence in the batch is resetting its memory.
                indices = torch.zeros(
                    bsz, dtype=torch.bool, device=word_emb.device
                )
            else:
                indices = reset_mems

            if self.same_length:
                dec_attn_mask = (
                    torch.triu(all_ones, 1 + mlen)
                    + torch.tril(all_ones, -mask_shift_len)
                ).bool().repeat(len(indices), 1, 1)
            else:
                dec_attn_mask = torch.triu(
                    word_emb.new_ones(qlen, klen), diagonal=1 + mlen
                ).bool().repeat(len(indices), 1, 1)

            # Sequences flagged in `indices` are additionally blocked from
            # attending to any cached memory position.
            dec_attn_mask[indices, :, :mlen] = True
        else:
            # Other pad types share a single (qlen, klen) causal mask across
            # the whole batch.
            if self.same_length:
                all_ones = word_emb.new_ones(qlen, klen)
                mask_len = klen - self.mem_len
                if mask_len > 0:
                    mask_shift_len = qlen - mask_len
                else:
                    mask_shift_len = qlen
                dec_attn_mask = (
                    torch.triu(all_ones, 1 + mlen)
                    + torch.tril(all_ones, -mask_shift_len)
                ).bool()
            else:
                dec_attn_mask = torch.triu(
                    word_emb.new_ones(qlen, klen), diagonal=1 + mlen
                ).bool()

        hids = []
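        # Relative positions run from klen-1 (oldest key) down to 0 (newest);
        # distances beyond clamp_len (if set) are clamped so very distant
        # positions share one positional embedding.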
        pos_seq = torch.arange(
            klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype
        )
        if self.clamp_len > 0:
            pos_seq.clamp_(max=self.clamp_len)
        pos_emb = self.pos_emb(pos_seq)

        core_out = self.drop(word_emb)
        pos_emb = self.drop(pos_emb)

        hids.append(core_out)
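        # hids collects the input embedding plus every layer's output so that
        # the memory cache can be extended in _update_mems.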

        for i, layer in enumerate(self.layers):
            mems_i = None if mems is None else mems[i]
            core_out = layer(
                core_out,
                pos_emb,
                self.r_w_bias,
                self.r_r_bias,
                dec_attn_mask=dec_attn_mask,
                mems=mems_i,
            )
            hids.append(core_out)
        core_out = self.drop(core_out)

        # Update the per-layer memory with this segment's hidden states;
        # reset_mems is forwarded so flagged sequences can be handled there.
        new_mems = self._update_mems(hids, mems, mlen, qlen, reset_mems)
        return core_out, new_mems
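
A minimal standalone sketch (not from the repository) of how the mask built above behaves, assuming qlen=4 new tokens attending over mlen=3 cached memory steps; True marks a key position the query may not attend to:

    import torch

    qlen, mlen = 4, 3
    klen = mlen + qlen

    # Causal part: block keys that lie in the future of each query position.
    dec_attn_mask = torch.triu(torch.ones(qlen, klen), diagonal=1 + mlen).bool()
    # dec_attn_mask (F = may attend, T = blocked):
    #   [[F, F, F, F, T, T, T],
    #    [F, F, F, F, F, T, T],
    #    [F, F, F, F, F, F, T],
    #    [F, F, F, F, F, F, F]]

    # Per-sequence reset (pad_type == "model"): a flagged sequence is also
    # blocked from the first mlen columns, i.e. from all cached memory.
    reset = torch.tensor([False, True])  # hypothetical batch of two sequences
    batch_mask = dec_attn_mask.repeat(len(reset), 1, 1)
    batch_mask[reset, :, :mlen] = True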