def forward()

in deepseek_vl2/models/siglip_vit.py [0:0]


    def forward(self, x: torch.Tensor) -> torch.Tensor:
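        # imported lazily so the module can be loaded without xformers installed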
        from xformers.ops import memory_efficient_attention

        B, N, C = x.shape
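        # single projection produces packed qkv of shape (B, N, 3, num_heads, head_dim)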
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)

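        # fast path without q/k normalization: packed FlashAttention or xformers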
        if not self.qk_norm:
            if self.head_dim % 32 == 0 and is_flash_attn_2_available():
                # the packed FlashAttention path is only taken when head_dim is a multiple of 32
                x = flash_attn_qkvpacked_func(qkv, dropout_p=self.attn_drop.p if self.training else 0.,
                                              deterministic=self.deterministic)
            else:
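                # otherwise fall back to xformers memory-efficient attention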
                q, k, v = qkv.unbind(2)
                x = memory_efficient_attention(q, k, v, p=self.attn_drop.p if self.training else 0.)
            x = x.reshape(B, N, C)
            x = self.proj(x)
            x = self.proj_drop(x)
            return x

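        # qk-norm path: unpack to per-tensor (B, num_heads, N, head_dim) and normalize q, k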
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=False):
                # force the FlashAttention backend via the sdp_kernel context manager
                x = F.scaled_dot_product_attention(
                    q, k, v,
                    dropout_p=self.attn_drop.p if self.training else 0.,
                )
        else:
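            # unfused fallback: explicit scaled dot-product attention with dropout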
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

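        # merge heads: (B, num_heads, N, head_dim) -> (B, N, C)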
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
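
For context, here is a minimal sketch of the module-level imports this method relies on, plus a dummy invocation. The import paths and the Attention constructor arguments are assumptions for illustration (timm-style), not necessarily the exact code in siglip_vit.py.

    # module-level imports assumed by forward() above
    import torch
    import torch.nn.functional as F
    from transformers.utils import is_flash_attn_2_available  # assumed import path

    if is_flash_attn_2_available():
        from flash_attn import flash_attn_qkvpacked_func

    # hypothetical usage, assuming a timm-style Attention(dim, num_heads, qk_norm, ...):
    #   attn = Attention(dim=1024, num_heads=16, qk_norm=False).cuda().half()
    #   x = torch.randn(2, 576, 1024, device="cuda", dtype=torch.float16)
    #   out = attn(x)  # -> (2, 576, 1024)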