src/model.py
import torch
import torch.nn.functional as F


def cross_attention_forward(
    self,
    input,
    mask=None,
    kv=None,
    position_bias=None,
    past_key_value_state=None,
    head_mask=None,
    query_length=None,
    use_cache=False,
    output_attentions=False,
):
"""
This only works for computing cross attention over the input
"""
assert(kv != None)
assert(head_mask == None)
assert(position_bias != None or self.has_relative_attention_bias)
    bsz, qlen, dim = input.size()
    n_heads, d_heads = self.n_heads, self.d_kv
    klen = kv.size(1)

    # Project queries from the decoder input; reuse cached keys/values when available.
    q = self.q(input).view(bsz, -1, n_heads, d_heads).transpose(1, 2)
    if past_key_value_state is None:
        k = self.k(kv).view(bsz, -1, n_heads, d_heads).transpose(1, 2)
        v = self.v(kv).view(bsz, -1, n_heads, d_heads).transpose(1, 2)
    else:
        k, v = past_key_value_state
    # Raw attention scores, shape (bsz, n_heads, qlen, klen).
    scores = torch.einsum("bnqd,bnkd->bnqk", q, k)

    if mask is not None:
        scores += mask

    if position_bias is None:
        position_bias = self.compute_bias(qlen, klen)
    scores += position_bias

    # Cache the first scores tensor of the forward pass so it can be read back later.
    if self.score_storage is None:
        self.score_storage = scores
    attn = F.softmax(scores.float(), dim=-1).type_as(scores)
    attn = F.dropout(attn, p=self.dropout, training=self.training)

    output = torch.matmul(attn, v)
    output = output.transpose(1, 2).contiguous().view(bsz, -1, self.inner_dim)
    output = self.o(output)
    # Assemble the output tuple: (hidden_states, present_key_value, [attn], [position_bias]).
    if use_cache:
        output = (output,) + ((k, v),)
    else:
        output = (output,) + (None,)

    if output_attentions:
        output = output + (attn,)

    if self.has_relative_attention_bias:
        output = output + (position_bias,)

    return output
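

# Usage sketch (an assumption, not part of the original file): the references to
# `self.q`, `self.k`, `self.v`, `self.o`, `self.compute_bias`, and a custom
# `self.score_storage` attribute suggest this function is meant to be bound as the
# forward method of the cross-attention modules of a Hugging Face T5 decoder. The
# module path `decoder.block[i].layer[1].EncDecAttention` and the signature using
# `input`/`kv`/`past_key_value_state` match older transformers releases; adjust
# these names for other versions.
import types


def wrap_cross_attention(t5_model):
    """Bind cross_attention_forward to every decoder cross-attention module."""
    for block in t5_model.decoder.block:
        attn = block.layer[1].EncDecAttention
        attn.score_storage = None  # per-module cache filled by cross_attention_forward
        attn.forward = types.MethodType(cross_attention_forward, attn)
    return t5_model


def collect_cross_attention_scores(t5_model):
    """Read and reset the cached cross-attention scores after a forward pass."""
    scores = []
    for block in t5_model.decoder.block:
        attn = block.layer[1].EncDecAttention
        scores.append(attn.score_storage)
        attn.score_storage = None
    return scores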