in src/modules/multihead_attention.py [0:0]
def forward(self,
            query,
            key,
            value,
            mask_future_timesteps=False,
            key_padding_mask=None,
            incremental_state=None,
            need_weights=True,
            static_kv=False):
    """Compute multi-head scaled dot-product attention.

    Input shape: Time x Batch x Channel

    Self-attention can be implemented by passing in the same arguments for
    query, key and value. Future timesteps can be masked with the
    `mask_future_timesteps` argument. Padding elements can be excluded from
    the key by passing a mask (`key_padding_mask`) with shape:
    batch x src_len, where padding elements are indicated by non-zero
    (1 / True) entries. Both byte and boolean masks are accepted.

    Args:
        query: tensor of shape (tgt_len, bsz, embed_dim); embed_dim must
            equal `self.embed_dim`.
        key: tensor with the same size as `value`
            (presumably (src_len, bsz, embed_dim) - confirm with callers).
        value: tensor with the same size as `key`.
        mask_future_timesteps: when True and `incremental_state` is None,
            the result of `self.buffered_mask(...)` is added to the
            attention scores. Requires query and key of equal size.
        key_padding_mask: optional (bsz, src_len) mask; positions with
            non-zero entries get zero attention weight.
        incremental_state: optional state object handed to
            `self._get_input_buffer` / `self._set_input_buffer`. When
            given, previously cached keys/values are combined with the
            current ones and the cache is written back.
        need_weights: when False, None is returned in place of the
            attention weights.
        static_kv: when True and a cache already exists in
            `incremental_state`, the cached key/value are reused as-is
            and `key` / `value` are not projected again. Requires key
            and value to share storage and to differ from query.

    Returns:
        A tuple `(attn, attn_weights)`. `attn` is `self.out_proj` applied
        to a (tgt_len, bsz, embed_dim) tensor. `attn_weights` has shape
        (bsz, tgt_len, src_len), averaged over heads, or is None when
        `need_weights` is False.
    """
    # Identity of the inputs is detected through the start address of
    # their storage; this selects the projection path below.
    qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
    kv_same = key.data_ptr() == value.data_ptr()

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == self.embed_dim
    assert key.size() == value.size()

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if 'prev_key' in saved_state and static_kv:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            assert kv_same and not qkv_same
            key = value = None
    else:
        saved_state = None

    if qkv_same:
        # self-attention
        q, k, v = self.in_proj_qkv(query)
    elif kv_same:
        # encoder-decoder attention
        q = self.in_proj_q(query)
        if key is None:
            assert value is None
            # Nothing new to project: the cached key/value are picked up
            # unchanged from saved_state below.
            k = v = None
        else:
            k, v = self.in_proj_kv(key)
    else:
        q = self.in_proj_q(query)
        k = self.in_proj_k(key)
        v = self.in_proj_v(value)
    q *= self.scaling

    if saved_state is not None:
        if 'prev_key' in saved_state:
            prev_key = saved_state['prev_key']
            # Reuse the cached tensor directly for static key/value rather
            # than concatenating it with an empty placeholder, which would
            # copy the whole cache on every step.
            k = prev_key if k is None else torch.cat((prev_key, k), dim=0)
        if 'prev_value' in saved_state:
            prev_value = saved_state['prev_value']
            v = prev_value if v is None else torch.cat((prev_value, v), dim=0)
        saved_state['prev_key'] = k
        saved_state['prev_value'] = v
        self._set_input_buffer(incremental_state, saved_state)

    src_len = k.size(0)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    # Fold the heads into the batch dimension:
    # (len, bsz, embed_dim) -> (bsz * num_heads, len, head_dim)
    q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    k = k.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    v = v.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)

    attn_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    # only apply masking at training time (when incremental state is None)
    if mask_future_timesteps and incremental_state is None:
        assert query.size() == key.size(), \
            'mask_future_timesteps only applies to self-attention'
        attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
    if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        # `!= 0` yields the comparison dtype of the running PyTorch, so
        # byte and boolean masks are both accepted by masked_fill.
        padding = key_padding_mask.unsqueeze(1).unsqueeze(2) != 0
        attn_weights = attn_weights.float().masked_fill(
            padding,
            float('-inf'),
        ).type_as(attn_weights)  # FP16 support: cast to float and back
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    # softmax is computed in float and cast back (FP16 support)
    attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
    attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

    attn = torch.bmm(attn_weights, v)
    assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
    # Unfold the heads: (bsz * num_heads, tgt_len, head_dim) -> (tgt_len, bsz, embed_dim)
    attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn = self.out_proj(attn)

    if need_weights:
        # average attention weights over heads
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights.sum(dim=1) / self.num_heads
    else:
        attn_weights = None

    return attn, attn_weights