in model/mem_transformer.py [0:0]
def _forward(self, dec_inp, reset_mems, mems=None, status_vec=None):
    qlen, bsz = dec_inp.size()[0], dec_inp.size()[1]

    word_emb = self.word_emb(dec_inp, status_vec)

    # mlen: number of cached key positions from previous segments;
    # klen: total key length the current queries can attend to.
    mlen = mems[0].size(0) if mems is not None else 0
    klen = mlen + qlen

    # Generate the mask between query and all the keys.
    # TODO Think about how to enable masking when we reach BOS.
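    # Note: entries of dec_attn_mask that are True are *disallowed* key
    # positions; the attention layers are expected to fill those score
    # entries with -inf (or equivalent) before the softmax.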
    if self.pad_type == "model":
        if self.same_length:
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
        # start_token = 1 if self.replace_start_with_pad else 0
        # if len(dec_inp.shape) == 2:
        #     indices = dec_inp[0] == start_token
        # elif len(dec_inp.shape) == 3:
        #     indices = dec_inp[0, :, 0] == start_token
        # Per-sample reset flags: True marks batch elements whose memory
        # must not be attended to in this segment.
        if reset_mems is None:
            indices = torch.zeros(
                dec_inp.shape[1], dtype=torch.bool, device=dec_inp.device
            )
        else:
            indices = reset_mems
        if self.same_length:
            # Band mask (True = blocked): each query sees a fixed-size
            # window of keys ending at its own position. Repeated per batch
            # element so resets can be applied per sample.
            dec_attn_mask = (
                torch.triu(all_ones, 1 + mlen)
                + torch.tril(all_ones, -mask_shift_len)
            ).bool().repeat(len(indices), 1, 1)
        else:
            # Causal mask only: queries may attend to all memory plus the
            # non-future positions of the current segment.
            dec_attn_mask = torch.triu(
                word_emb.new_ones(qlen, klen), diagonal=1 + mlen
            ).bool().repeat(len(indices), 1, 1)
        # For samples being reset, block attention to every memory position.
        dec_attn_mask[indices, :, :mlen] = True
    else:
        if self.same_length:
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (
                torch.triu(all_ones, 1 + mlen)
                + torch.tril(all_ones, -mask_shift_len)
            ).bool()
        else:
            dec_attn_mask = torch.triu(
                word_emb.new_ones(qlen, klen), diagonal=1 + mlen
            ).bool()
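    # dec_attn_mask is (bsz, qlen, klen) when pad_type == "model" (one mask
    # per sample, so resets can differ across the batch) and (qlen, klen)
    # otherwise.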
    hids = []
    # Relative distances for positional embeddings, ordered from farthest
    # (klen - 1) to nearest (0).
    pos_seq = torch.arange(
        klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype
    )
    if self.clamp_len > 0:
        pos_seq.clamp_(max=self.clamp_len)
    pos_emb = self.pos_emb(pos_seq)

    core_out = self.drop(word_emb)
    pos_emb = self.drop(pos_emb)

    hids.append(core_out)
    for i, layer in enumerate(self.layers):
        mems_i = None if mems is None else mems[i]
        core_out = layer(
            core_out,
            pos_emb,
            self.r_w_bias,
            self.r_r_bias,
            dec_attn_mask=dec_attn_mask,
            mems=mems_i,
        )
        hids.append(core_out)
    core_out = self.drop(core_out)

    # Cache this segment's hidden states as memory for the next segment.
    new_mems = self._update_mems(hids, mems, mlen, qlen, reset_mems)

    return core_out, new_mems
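# ---------------------------------------------------------------------------
# A minimal sketch (not part of the module) of how the `same_length` mask
# above behaves. The values of `qlen`, `mlen`, and `mem_len` are made up for
# illustration; the triu/tril arithmetic mirrors the branch in `_forward`.
import torch

qlen, mlen, mem_len = 4, 3, 3
klen = mlen + qlen                                 # 3 memory + 4 current keys
all_ones = torch.ones(qlen, klen)
mask_len = klen - mem_len                          # 4
mask_shift_len = qlen - mask_len if mask_len > 0 else qlen   # 0
mask = (
    torch.triu(all_ones, 1 + mlen)                 # block future keys
    + torch.tril(all_ones, -mask_shift_len)        # block keys older than the window
).bool()
print(mask.int())
# tensor([[1, 0, 0, 0, 1, 1, 1],
#         [1, 1, 0, 0, 0, 1, 1],
#         [1, 1, 1, 0, 0, 0, 1],
#         [1, 1, 1, 1, 0, 0, 0]])
# With these values every query row attends to exactly `mem_len` keys ending
# at its own position, which is what keeps the effective context length
# constant across segments.

# Per-sample reset, as in the "model" branch above: flagged samples lose
# access to all memory keys (here, sample 0 of a batch of 2).
indices = torch.tensor([True, False])
batch_mask = mask.repeat(2, 1, 1)
batch_mask[indices, :, :mlen] = True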