in backends/python/server/text_embeddings_server/models/flash_bert.py
def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None):
    residual = hidden_states

    # Fused QKV projection: one matmul produces queries, keys, and values.
    qkv = F.linear(hidden_states, self.qkv_weight.T, self.qkv_bias)

    # Inputs may arrive either flattened ([total_tokens, hidden]) or padded
    # ([batch, seq_len, hidden]); remember which so the output can be reshaped back.
    bs = 1
    hidden_dim = hidden_states.size(-1)
    is_flat = True
    if hidden_states.dim() > 2:
        is_flat = False
        bs = hidden_states.size(0)
        q, k, v = qkv.view(bs, -1, self.num_heads * 3, self.head_size).split(
            self.num_heads, dim=2
        )
    else:
        q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(
            self.num_heads, dim=1
        )

    # Variable-length attention kernel; cu_seqlens holds the cumulative sequence
    # lengths that delimit each sequence in the flattened batch.
    attn_output = torch.empty_like(q)
    attention(
        q,
        k,
        v,
        attn_output,
        cu_seqlens,
        max_s,
        self.softmax_scale,
        attn_mask=attn_mask,
    )

    # Output projection back to the hidden size.
    hidden_states = torch.addmm(
        self.dense_bias,
        attn_output.view(-1, self.num_heads * self.head_size),
        self.dense_weight,
    )
    if not is_flat:
        hidden_states = hidden_states.view(bs, -1, hidden_dim)

    # Fused residual add + layer norm.
    hidden_states, _ = self.layer_norm.forward(hidden_states, residual)

    return hidden_states
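
A minimal sketch of how such a forward pass is typically driven in the flattened
(unpadded) case. The shapes, sizes, and the commented-out layer construction are
illustrative assumptions, not this repository's API: two sequences of lengths 3 and
5 are concatenated along the token dimension, and cu_seqlens marks their boundaries
for the varlen attention kernel.

    import torch

    seq_lens = [3, 5]
    total_tokens = sum(seq_lens)                   # 8 flattened tokens
    hidden_size = 768                              # assumed BERT-base hidden width
    hidden_states = torch.randn(total_tokens, hidden_size, dtype=torch.float16)
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)  # cumulative lengths
    max_s = max(seq_lens)                          # longest sequence in the batch

    # layer = BertAttention(...)                   # hypothetical construction
    # out = layer.forward(hidden_states, cu_seqlens, max_s)
    # out.shape == (total_tokens, hidden_size)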