in src/model/transformer.py [0:0]
def fwd(self, x, lengths, causal, src_enc=None, src_len=None, positions=None, langs=None, cache=None):
    """
    Inputs:
        `x` LongTensor(slen, bs), containing word indices
        `lengths` LongTensor(bs), containing the length of each sentence
        `causal` Boolean, if True, the attention is only done over previous hidden states
        `positions` LongTensor(slen, bs), containing word positions
        `langs` LongTensor(slen, bs), containing language IDs
    """
    # lengths = (x != self.pad_index).float().sum(dim=1)
    # mask = x != self.pad_index

    # check inputs
    slen, bs = x.size()
    assert lengths.size(0) == bs
    assert lengths.max().item() <= slen
    x = x.transpose(0, 1)  # batch size as dimension 0
    assert (src_enc is None) == (src_len is None)
    if src_enc is not None:
        assert self.is_decoder
        assert src_enc.size(0) == bs

    # generate masks
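    # `mask` is a (bs, slen) bool tensor marking non-padded positions; `attn_mask`
    # equals `mask` in the non-causal case and is a (bs, slen, slen) lower-triangular
    # mask (attend to previous positions only) when `causal` is True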
    mask, attn_mask = get_masks(slen, lengths, causal)
    if self.is_decoder and src_enc is not None:
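        # (bs, max_src_len) bool mask marking the real (non-padded) source positions
        # used by the encoder-decoder attention below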
        src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]

    # positions
    if positions is None:
        positions = x.new(slen).long()
        positions = torch.arange(slen, out=positions).unsqueeze(0)
    else:
        assert positions.size() == (slen, bs)
        positions = positions.transpose(0, 1)

    # langs
    if langs is not None:
        assert langs.size() == (slen, bs)
        langs = langs.transpose(0, 1)

    # do not recompute cached elements
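    # During incremental decoding, cache['slen'] is the number of positions already
    # processed; only the new tail of the sequence is embedded below, and the
    # attention layers reuse their cached keys/values for the earlier positions.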
    if cache is not None:
        _slen = slen - cache['slen']
        x = x[:, -_slen:]
        positions = positions[:, -_slen:]
        if langs is not None:
            langs = langs[:, -_slen:]
        mask = mask[:, -_slen:]
        attn_mask = attn_mask[:, -_slen:]

    # embeddings
    tensor = self.embeddings(x)
    tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
    if langs is not None and self.use_lang_emb:
        tensor = tensor + self.lang_embeddings(langs)
    tensor = self.layer_norm_emb(tensor)
    tensor = F.dropout(tensor, p=self.dropout, training=self.training)
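    # zero out hidden states at padded positions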
    tensor *= mask.unsqueeze(-1).to(tensor.dtype)

    # transformer layers
    for i in range(self.n_layers):

        # self attention
        attn = self.attentions[i](tensor, attn_mask, cache=cache)
        attn = F.dropout(attn, p=self.dropout, training=self.training)
        tensor = tensor + attn
        tensor = self.layer_norm1[i](tensor)

        # encoder attention (for decoder only)
        if self.is_decoder and src_enc is not None:
            attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
            attn = F.dropout(attn, p=self.dropout, training=self.training)
            tensor = tensor + attn
            tensor = self.layer_norm15[i](tensor)

        # FFN
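        # a memory layer registered under '%i_in' (if any) replaces the standard
        # feed-forward block for this layer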
        if ('%i_in' % i) in self.memories:
            tensor = tensor + self.memories['%i_in' % i](tensor)
        else:
            tensor = tensor + self.ffns[i](tensor)
        tensor = self.layer_norm2[i](tensor)

        # memory
        if ('%i_after' % i) in self.memories:
            tensor = tensor + self.memories['%i_after' % i](tensor)
        # TODO: add extra layer norm here?

        tensor *= mask.unsqueeze(-1).to(tensor.dtype)

    # update cache length
    if cache is not None:
        cache['slen'] += tensor.size(1)

    # move back sequence length to dimension 0
    tensor = tensor.transpose(0, 1)

    return tensor
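
# Incremental decoding sketch (assumptions: `model` is a decoder-mode instance of
# this class, `src_enc`/`src_len` come from an encoder forward pass, and `bs`,
# `bos_index`, `max_len` and `get_scores` are placeholder names not defined in
# this excerpt):
#
#   cache = {'slen': 0}
#   generated = src_len.new_full((1, bs), bos_index)   # (1, bs) of BOS indices
#   gen_len = src_len.clone().fill_(1)
#   for cur_len in range(1, max_len):
#       h = model.fwd(generated, gen_len, causal=True,
#                     src_enc=src_enc, src_len=src_len, cache=cache)
#       scores = get_scores(h[-1])                     # (bs, vocab) next-token logits
#       next_words = scores.argmax(dim=-1)             # greedy choice
#       generated = torch.cat([generated, next_words[None]], dim=0)
#       gen_len.add_(1)
#
# With the cache, each call only embeds and attends from the newly appended
# position; keys/values of earlier positions are reused inside the attention layers.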