def forward()

in backends/python/server/text_embeddings_server/models/flash_bert.py [0:0]
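
Attention forward pass: a fused QKV projection, variable-length attention over the packed token sequences, a dense output projection, and a fused residual add plus layer norm.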


    def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None):
        residual = hidden_states
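        # Fused QKV projection: a single matmul produces queries, keys and values for all heads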
        qkv = F.linear(hidden_states, self.qkv_weight.T, self.qkv_bias)
        bs = 1
        hidden_dim = hidden_states.size(-1)
        is_flat = True
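        # Padded [batch, seq, hidden] input keeps its batch dimension; packed [total_tokens, hidden]
        # input stays flat. Either way, split the fused QKV into per-head q, k and v.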
        if hidden_states.dim() > 2:
            is_flat = False
            bs = hidden_states.size(0)
            q, k, v = qkv.view(bs, -1, self.num_heads * 3, self.head_size).split(
                self.num_heads, dim=2
            )
        else:
            q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(
                self.num_heads, dim=1
            )
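        # The attention kernel writes its result into the preallocated attn_output buffer;
        # cu_seqlens and max_s describe the boundaries of the packed sequences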
        attn_output = torch.empty_like(q)
        attention(
            q,
            k,
            v,
            attn_output,
            cu_seqlens,
            max_s,
            self.softmax_scale,
            attn_mask=attn_mask,
        )

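        # Output projection: dense_bias + attn_output @ dense_weight, with the heads flattened back to hidden_dim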
        hidden_states = torch.addmm(
            self.dense_bias,
            attn_output.view(-1, self.num_heads * self.head_size),
            self.dense_weight,
        )
        if not is_flat:
            hidden_states = hidden_states.view(bs, -1, hidden_dim)
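        # Fused residual add + layer norm over the attention output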
        hidden_states, _ = self.layer_norm.forward(hidden_states, residual)

        return hidden_states
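
For reference, a minimal sketch of the packed ("flat") input convention this forward expects. The `layer` name and the exact construction of `cu_seqlens` below are illustrative assumptions, not taken from flash_bert.py; only the general layout (tokens concatenated without padding, cumulative sequence boundaries, longest sequence length) reflects the code above.

    # A minimal, self-contained sketch (not from flash_bert.py) of building the packed inputs.
    import torch

    hidden_dim = 768
    seq_lens = torch.tensor([3, 5], dtype=torch.int32)            # two sequences in the batch
    total_tokens = int(seq_lens.sum())
    hidden_states = torch.randn(total_tokens, hidden_dim)         # tokens concatenated without padding
    cu_seqlens = torch.nn.functional.pad(
        seq_lens.cumsum(0, dtype=torch.int32), (1, 0)
    )                                                             # tensor([0, 3, 8]): cumulative boundaries
    max_s = int(seq_lens.max())                                   # length of the longest sequence

    # Hypothetical call, assuming `layer` is an initialized attention block from this module:
    # out = layer.forward(hidden_states, cu_seqlens, max_s)
    # out would have the same shape as hidden_states: (total_tokens, hidden_dim)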