in glide_text2im/xf.py [0:0]
def forward(self, qkv):
    """Run multi-head attention over a packed query/key/value tensor.

    ``qkv`` is 3-D, ``(batch, seq_len, width)``. Its last axis is viewed
    as ``self.n_heads`` groups, and each group is split into three
    consecutive chunks of ``width // self.n_heads // 3`` channels that are
    used as queries, keys and values, in that order.

    Returns a tensor of shape ``(batch, seq_len, n_heads * head_dim)``.
    """
    batch, seq_len, total_width = qkv.shape
    head_dim = total_width // self.n_heads // 3

    # Fourth root of 1/head_dim: applied to q and k separately, the product
    # carries the usual 1/sqrt(head_dim) factor.
    # More stable with f16 than dividing afterwards.
    qk_scale = 1 / math.sqrt(math.sqrt(head_dim))

    per_head = qkv.view(batch, seq_len, self.n_heads, -1)
    queries, keys, values = per_head.split(head_dim, dim=-1)

    logits = th.einsum("bthc,bshc->bhts", queries * qk_scale, keys * qk_scale)

    # Softmax is evaluated in float32, then cast back to the logits' dtype.
    probs = logits.float().softmax(dim=-1).to(logits.dtype)

    attended = th.einsum("bhts,bshc->bthc", probs, values)
    return attended.reshape(batch, seq_len, -1)