in glide_text2im/clip/encoders.py [0:0]
def forward(self, m: torch.Tensor) -> torch.Tensor:
    n_context = m.shape[1]

    # Pad the sequence out to the block-sparse attention layout's block boundaries.
    n_query_pad = self.attn_fn.ctx_blks_q * self.attn_fn.block_size - n_context
    n_key_pad = self.attn_fn.ctx_blks_k * self.attn_fn.block_size - n_context
    assert n_query_pad >= 0
    assert n_key_pad >= 0

    # Pre-norm, then project to queries, keys, and values.
    r = m
    r = self.ln(r)
    q, k, v = self.f_q(r), self.f_k(r), self.f_v(r)

    if n_query_pad != 0:
        q = F.pad(q, (0, 0, 0, n_query_pad))

    if n_key_pad != 0:
        k = F.pad(k, (0, 0, 0, n_key_pad))
        v = F.pad(v, (0, 0, 0, n_key_pad))

    # Split the state dimension into heads: [batch, heads, seq, head_state].
    q = q.view([q.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3))
    k = k.view([k.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3))
    v = v.view([v.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3))

    # Scaled dot-product logits; sqrt(qk_scale) is applied to both q and k,
    # so their product carries the full qk_scale factor.
    w = torch.einsum(
        "bhcd,bhkd->bhck", q * math.sqrt(self.qk_scale), k * math.sqrt(self.qk_scale)
    )

    # Add the attention function's additive bias, if it provides one, before the softmax.
    if hasattr(self.attn_fn, "pytorch_attn_bias"):
        bias = self.attn_fn.pytorch_attn_bias
        assert len(bias.shape) in {2, 3}

        if len(bias.shape) == 2:
            w = torch.softmax(w + self.attn_fn.pytorch_attn_bias[None, None], dim=-1)
        elif len(bias.shape) == 3:
            w = torch.softmax(w + self.attn_fn.pytorch_attn_bias[None], dim=-1)
    else:
        w = torch.softmax(w, dim=-1)

    r = torch.einsum("bhck,bhkd->bhcd", w, v)
    # Merge the heads back into the state dimension.
    r = r.permute((0, 2, 1, 3)).reshape((r.shape[0], -1, self.n_state))

    # Strip the query padding, project back out, and apply the residual connection.
    if n_query_pad != 0:
        r = r[:, :-n_query_pad]

    assert r.shape[1] == n_context

    r = self.f_c(r)
    return m + r
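

# --- Illustrative sketch (not part of the repo): the pad/unpad pattern used in
# forward() shown in isolation. The sizes below (block_size, ctx_blks_q, and the
# tensor shapes) are made-up assumptions for demonstration only.
import torch
import torch.nn.functional as F

block_size, ctx_blks_q = 32, 3                        # hypothetical block layout
x = torch.randn(2, 77, 512)                           # [batch, n_context, n_state]
n_query_pad = ctx_blks_q * block_size - x.shape[1]    # 96 - 77 = 19

# F.pad fills trailing dims first: (0, 0) leaves the state dim alone,
# (0, n_query_pad) pads the sequence dim on the right, as in forward().
x_padded = F.pad(x, (0, 0, 0, n_query_pad))
assert x_padded.shape == (2, 96, 512)

# After attention, the padded positions are sliced back off the query axis.
x_unpadded = x_padded[:, :-n_query_pad]
assert x_unpadded.shape == x.shape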