# src/model/transformer.py

import math

import torch
import torch.nn.functional as F

def forward(self, input, mask, kv=None, use_cache=False):
    """
    Self-attention (if kv is None) or attention over the source sentence
    (provided by kv).
    Input is (bs, qlen, dim).
    Mask is (bs, klen) (non-causal) or (bs, klen, klen) (causal).
    """
    assert not (use_cache and self.cache is None)
    bs, qlen, dim = input.size()
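
    # Key/value length: qlen for self-attention (plus any previously cached
    # steps during incremental decoding), or the source length for
    # encoder-decoder attention.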
    if kv is None:
        klen = qlen if not use_cache else self.cache['slen'] + qlen
    else:
        klen = kv.size(1)
    assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
    n_heads = self.n_heads
    dim_per_head = dim // n_heads
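
    # Reshape target for the mask so it broadcasts over heads (and, for a 2D
    # mask, over query positions as well).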
    mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)

    def shape(x):
        """ (bs, len, dim) -> (bs, n_heads, len, dim_per_head) """
        return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

    def unshape(x):
        """ (bs, n_heads, len, dim_per_head) -> (bs, len, dim) """
        return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
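
    # q is always projected from `input`; k/v are projected from `input` for
    # self-attention, or from `kv` (the encoded source) for encoder-decoder
    # attention.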
    q = shape(self.q_lin(input))      # (bs, n_heads, qlen, dim_per_head)
    if kv is None:
        k = shape(self.k_lin(input))  # (bs, n_heads, qlen, dim_per_head)
        v = shape(self.v_lin(input))  # (bs, n_heads, qlen, dim_per_head)
    elif not use_cache or self.layer_id not in self.cache:
        k = v = kv
        k = shape(self.k_lin(k))      # (bs, n_heads, klen, dim_per_head)
        v = shape(self.v_lin(v))      # (bs, n_heads, klen, dim_per_head)
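
    # Incremental decoding: append the current step's k/v to the cached ones
    # for self-attention, or reuse the source-side projections computed on the
    # first step for encoder-decoder attention; then refresh the cache.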
    if use_cache:
        if self.layer_id in self.cache:
            if kv is None:
                k_, v_ = self.cache[self.layer_id]
                k = torch.cat([k_, k], dim=2)  # (bs, n_heads, klen, dim_per_head)
                v = torch.cat([v_, v], dim=2)  # (bs, n_heads, klen, dim_per_head)
            else:
                k, v = self.cache[self.layer_id]
        self.cache[self.layer_id] = (k, v)
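
    # Scaled dot-product attention: scale q by 1/sqrt(dim_per_head) so the
    # magnitude of the dot products stays roughly independent of the head
    # size; masked positions are filled with -inf so softmax gives them zero
    # weight, and the softmax itself runs in float32 for numerical stability
    # under mixed precision.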
    q = q / math.sqrt(dim_per_head)                                       # (bs, n_heads, qlen, dim_per_head)
    scores = torch.matmul(q, k.transpose(2, 3))                           # (bs, n_heads, qlen, klen)
    mask = (mask == 0).view(mask_reshape).expand_as(scores)               # (bs, n_heads, qlen, klen)
    scores.masked_fill_(mask, -float('inf'))                              # (bs, n_heads, qlen, klen)
    weights = F.softmax(scores.float(), dim=-1).type_as(scores)           # (bs, n_heads, qlen, klen)
    weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)
    context = torch.matmul(weights, v)                                    # (bs, n_heads, qlen, dim_per_head)
    context = unshape(context)                                            # (bs, qlen, dim)
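
    # Optionally store the attention weights for later inspection (eval only).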
    if TransformerModel.STORE_OUTPUTS and not self.training:
        self.outputs = weights.detach().cpu()

    return self.out_lin(context)
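

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal self-attention call,
# assuming this forward() belongs to a multi-head attention module whose
# projections are nn.Linear(dim, dim) named q_lin/k_lin/v_lin/out_lin, with
# dim, n_heads, dropout, layer_id and cache attributes as read above. The
# `_DemoAttention` wrapper, the `TransformerModel` stub, and all sizes below
# are illustrative assumptions, not the library's API.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    if 'TransformerModel' not in globals():
        class TransformerModel:      # stub for the flag forward() reads;
            STORE_OUTPUTS = False    # the real class lives elsewhere in this file

    class _DemoAttention(nn.Module):
        def __init__(self, n_heads=8, dim=512, dropout=0.0):
            super().__init__()
            self.n_heads, self.dim, self.dropout = n_heads, dim, dropout
            self.layer_id, self.cache = 0, None  # no incremental decoding here
            self.q_lin = nn.Linear(dim, dim)
            self.k_lin = nn.Linear(dim, dim)
            self.v_lin = nn.Linear(dim, dim)
            self.out_lin = nn.Linear(dim, dim)

        forward = forward  # reuse the method defined above

    attn = _DemoAttention()
    x = torch.randn(2, 10, 512)                 # (bs, qlen, dim)
    mask = torch.ones(2, 10, dtype=torch.long)  # (bs, klen): attend everywhere
    with torch.no_grad():
        out = attn(x, mask)                     # self-attention: kv is None
    print(out.shape)                            # torch.Size([2, 10, 512])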