in cm/unet.py [0:0]
def forward(self, qkv):
"""
Apply QKV attention.
    :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
    :return: an [N x (H * C) x T] tensor after attention.
    """
    bs, width, length = qkv.shape
    assert width % (3 * self.n_heads) == 0
    ch = width // (3 * self.n_heads)
    # Do the attention arithmetic in fp16; the result is cast back to fp32
    # before returning.
    qkv = qkv.half()
    # (bs, 3 * heads * ch, length) -> (bs, length, 3, heads, ch)
    qkv = self.rearrange(
        qkv, "b (three h d) s -> b s three h d", three=3, h=self.n_heads
    )
    # Permute to (bs, heads, 3, ch, length), then split along dim 2 into
    # q, k, v, each of shape (bs, heads, 1, ch, length).
    q, k, v = qkv.transpose(1, 3).transpose(3, 4).split(1, dim=2)
    # Fold the head dimension into the batch dimension.
    q = q.reshape(bs * self.n_heads, ch, length)
    k = k.reshape(bs * self.n_heads, ch, length)
    v = v.reshape(bs * self.n_heads, ch, length)
    scale = 1 / math.sqrt(math.sqrt(ch))
    weight = th.einsum(
        "bct,bcs->bts", q * scale, k * scale
    )  # More stable with f16 than dividing afterwards; see the demo below
    # Compute the softmax in fp32 for numerical stability, then cast back
    # to the fp16 dtype of the attention weights.
    weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
    a = th.einsum("bts,bcs->bct", weight, v)
    # Cast back to fp32 and restore the (bs, heads * ch, length) layout.
    a = a.float()
    return a.reshape(bs, -1, length)
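
# What follows is a minimal demo sketch, not part of the original file: it
# checks the scaling trick used above and exercises forward() with a stand-in
# for `self` that supplies the two attributes the body uses, n_heads and
# rearrange. Treating self.rearrange as einops.rearrange is an assumption
# inferred from the pattern string, not confirmed by the source.
if __name__ == "__main__":
    import math
    from types import SimpleNamespace

    import torch as th
    from einops import rearrange

    # Scaling q and k each by 1/ch**0.25 before the matmul is algebraically
    # identical to dividing the logits by sqrt(ch) afterwards, but the smaller
    # intermediate products stay further from fp16 overflow.
    q64 = th.randn(4, 8, 16, dtype=th.float64)
    k64 = th.randn(4, 8, 16, dtype=th.float64)
    scale = 1 / math.sqrt(math.sqrt(q64.shape[1]))
    pre = th.einsum("bct,bcs->bts", q64 * scale, k64 * scale)
    post = th.einsum("bct,bcs->bts", q64, k64) / math.sqrt(q64.shape[1])
    assert th.allclose(pre, post)

    # forward() casts to fp16 internally, and half-precision matmuls are
    # typically unavailable on CPU, so this part runs only when CUDA exists.
    if th.cuda.is_available():
        fake_self = SimpleNamespace(n_heads=4, rearrange=rearrange)
        bs, ch, length = 2, 8, 16
        qkv = th.randn(bs, 3 * fake_self.n_heads * ch, length, device="cuda")
        out = forward(fake_self, qkv)
        assert out.shape == (bs, fake_self.n_heads * ch, length)
        assert out.dtype == th.float32  # cast back from fp16 before returning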