def forward()

in glide_text2im/clip/encoders.py


    def forward(self, m: torch.Tensor) -> torch.Tensor:
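        # Round the context length up to a whole number of attention blocks by
        # computing how much query and key/value padding is needed.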
        n_context = m.shape[1]
        n_query_pad = self.attn_fn.ctx_blks_q * self.attn_fn.block_size - n_context
        n_key_pad = self.attn_fn.ctx_blks_k * self.attn_fn.block_size - n_context
        assert n_query_pad >= 0
        assert n_key_pad >= 0

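        # Pre-norm: layer-normalize the input, then project to queries, keys, and values.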
        r = m
        r = self.ln(r)
        q, k, v = self.f_q(r), self.f_k(r), self.f_v(r)

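        # Zero-pad the sequence dimension so queries and keys/values fill whole blocks.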
        if n_query_pad != 0:
            q = F.pad(q, (0, 0, 0, n_query_pad))

        if n_key_pad != 0:
            k = F.pad(k, (0, 0, 0, n_key_pad))
            v = F.pad(v, (0, 0, 0, n_key_pad))

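        # Split heads: (batch, seq, n_state) -> (batch, n_heads, seq, n_head_state).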
        q = q.view([q.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3))
        k = k.view([k.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3))
        v = v.view([v.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3))
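        # Scaled dot-product scores; sqrt(qk_scale) is applied to both q and k,
        # so their product carries the full qk_scale factor.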
        w = torch.einsum(
            "bhcd,bhkd->bhck", q * math.sqrt(self.qk_scale), k * math.sqrt(self.qk_scale)
        )

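        # If a precomputed attention bias is present (e.g. a causal or padding mask),
        # broadcast it over batch (and heads when 2-D) and add it before the softmax.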
        if hasattr(self.attn_fn, "pytorch_attn_bias"):
            bias = self.attn_fn.pytorch_attn_bias
            assert len(bias.shape) in {2, 3}

            if len(bias.shape) == 2:
                w = torch.softmax(w + bias[None, None], dim=-1)
            elif len(bias.shape) == 3:
                w = torch.softmax(w + bias[None], dim=-1)
        else:
            w = torch.softmax(w, dim=-1)

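        # Attend: weighted sum of values, then merge heads back into (batch, seq, n_state).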
        r = torch.einsum("bhck,bhkd->bhcd", w, v)
        r = r.permute((0, 2, 1, 3)).reshape((r.shape[0], -1, self.n_state))

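        # Drop the rows that correspond to query padding.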
        if n_query_pad != 0:
            r = r[:, :-n_query_pad]

        assert r.shape[1] == n_context

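        # Output projection followed by a residual connection.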
        r = self.f_c(r)
        return m + r
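
For orientation, below is a minimal, self-contained sketch of the same pad -> split-heads -> einsum -> merge-heads pattern, with made-up sizes, no attention bias, and an assumed qk_scale; it is not the module's actual configuration.

import math

import torch
import torch.nn.functional as F

# Made-up sizes for illustration only.
b, n_ctx, block, n_heads, d_head = 2, 7, 8, 4, 16
n_state = n_heads * d_head
q = torch.randn(b, n_ctx, n_state)
k = torch.randn(b, n_ctx, n_state)
v = torch.randn(b, n_ctx, n_state)

# Pad the sequence dimension up to a whole block (7 -> 8).
pad = block - n_ctx
q, k, v = (F.pad(t, (0, 0, 0, pad)) for t in (q, k, v))

# Split heads: (b, seq, n_state) -> (b, n_heads, seq, d_head).
def split(t):
    return t.view(b, -1, n_heads, d_head).permute(0, 2, 1, 3)

q, k, v = split(q), split(k), split(v)

# Scaled dot-product attention (qk_scale assumed to be 1/d_head here).
qk_scale = 1.0 / d_head
w = torch.einsum("bhcd,bhkd->bhck", q * math.sqrt(qk_scale), k * math.sqrt(qk_scale))
w = torch.softmax(w, dim=-1)

# Merge heads and drop the query padding again.
out = torch.einsum("bhck,bhkd->bhcd", w, v)
out = out.permute(0, 2, 1, 3).reshape(b, -1, n_state)[:, :n_ctx]
assert out.shape == (2, 7, n_state)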