in src/model/transformer.py
def fwd(self, x, lengths, causal, src_enc=None, src_len=None, positions=None, use_cache=False):
    """
    Inputs:
        `x` LongTensor(slen, bs), containing word indices
        `lengths` LongTensor(bs), containing the length of each sentence
        `causal` Boolean, if True, the attention is only done over previous hidden states
        `src_enc` FloatTensor(bs, src_slen, dim), encoder output used for cross-attention (decoder only)
        `src_len` LongTensor(bs), containing the length of each source sentence (decoder only)
        `positions` LongTensor(slen, bs), containing word positions
        `use_cache` Boolean, if True, reuse previously computed hidden states (incremental decoding)
    Output:
        FloatTensor(slen, bs, dim), hidden states of the last layer
        (only the positions not yet in the cache when `use_cache` is True)
    """
    # lengths = (x != self.pad_index).float().sum(dim=1)
    # mask = x != self.pad_index
    # check inputs
    slen, bs = x.size()
    assert lengths.size(0) == bs
    assert lengths.max().item() <= slen
    x = x.transpose(0, 1)  # batch size as dimension 0
    assert (src_enc is None) == (src_len is None)
    if src_enc is not None:
        assert self.is_decoder
        assert src_enc.size(0) == bs
    assert not (use_cache and self.cache is None)
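    # `use_cache` requires `self.cache` to be initialized by the caller: a dict shared
    # with the attention layers, with the number of already-processed positions stored
    # under the 'slen' key.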
    # generate masks
    mask, attn_mask = get_masks(slen, lengths, causal)
    if self.is_decoder and src_enc is not None:
        src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
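    # `mask` is the (bs, slen) padding mask; `attn_mask` additionally hides future
    # positions when `causal` is True; `src_mask` marks the non-padded source
    # positions and is used by the decoder's cross-attention below.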
    # positions
    if positions is None:
        positions = x.new(slen).long()
        positions = torch.arange(slen, out=positions).unsqueeze(0)
    else:
        assert positions.size() == (slen, bs)
        positions = positions.transpose(0, 1)
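    # default positions are simply 0 .. slen-1, shared across the batch (shape (1, slen));
    # caller-provided positions are transposed to (bs, slen) to match `x`.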
    # do not recompute cached elements
    if use_cache:
        _slen = slen - self.cache['slen']
        x = x[:, -_slen:]
        positions = positions[:, -_slen:]
        mask = mask[:, -_slen:]
        attn_mask = attn_mask[:, -_slen:]
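    # with a cache, only the last `slen - cache['slen']` positions (typically the newly
    # generated token) are embedded and processed here; states for earlier positions
    # are reused from the cache inside the attention layers.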
# all layer outputs
if TransformerModel.STORE_OUTPUTS and not self.training:
self.outputs = []
# embeddings
tensor = self.embeddings(x)
tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
tensor = self.layer_norm_emb(tensor)
tensor = F.dropout(tensor, p=self.dropout, training=self.training)
tensor *= mask.unsqueeze(-1).to(tensor.dtype)
if TransformerModel.STORE_OUTPUTS and not self.training:
self.outputs.append(tensor.detach().cpu())
    # transformer layers
    for i in range(self.n_layers):
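        # each layer is a post-norm residual block: self-attention, optional
        # encoder (cross-)attention when decoding, then a position-wise FFN,
        # each followed by a residual connection and LayerNorm.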
        # self attention
        self.attentions[i].cache = self.cache
        attn = self.attentions[i](tensor, attn_mask, use_cache=use_cache)
        attn = F.dropout(attn, p=self.dropout, training=self.training)
        tensor = tensor + attn
        tensor = self.layer_norm1[i](tensor)
        # encoder attention (for decoder only)
        if self.is_decoder and src_enc is not None:
            self.encoder_attn[i].cache = self.cache
            attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, use_cache=use_cache)
            attn = F.dropout(attn, p=self.dropout, training=self.training)
            tensor = tensor + attn
            tensor = self.layer_norm15[i](tensor)
        # FFN
        tensor = tensor + self.ffns[i](tensor)
        tensor = self.layer_norm2[i](tensor)
        tensor *= mask.unsqueeze(-1).to(tensor.dtype)
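        # re-apply the padding mask so padded positions stay exactly zero after every layer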
        if TransformerModel.STORE_OUTPUTS and not self.training:
            self.outputs.append(tensor.detach().cpu())
    # update cache length
    if use_cache:
        self.cache['slen'] += tensor.size(1)
    # move back sequence length to dimension 0
    tensor = tensor.transpose(0, 1)
    return tensor
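    # Illustrative usage (a minimal sketch, not part of the original file; the names
    # `model` and `n_words` and the sizes below are assumptions for illustration only):
    #
    #     slen, bs = 10, 4
    #     x = torch.randint(0, n_words, (slen, bs))             # word indices, n_words = vocabulary size
    #     lengths = torch.full((bs,), slen, dtype=torch.long)   # here: no padding
    #     hidden = model.fwd(x=x, lengths=lengths, causal=False)
    #     # hidden has shape (slen, bs, dim)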