in backends/python/server/text_embeddings_server/models/flash_bert.py [0:0]
def forward(self, hidden_states, residual=None):
    """Apply LayerNorm to ``hidden_states``, optionally after adding ``residual``.

    Dispatches on the configured backend, in this order:

    * ``self.device.type == "cuda"``: fused ``dropout_layer_norm.dropout_add_ln_fwd``.
    * ``self.use_ipex``: ``ipex.llm.functional.add_layer_norm``.
    * ``self.device.type == "hpu"``: ``hpu_add_layer_norm``.
    * anything else: a plain ``torch.nn.functional.layer_norm`` fallback.
      Previously this case silently returned ``(None, None)``.

    Args:
        hidden_states: Input tensor. The fallback normalizes over its last
            dimension; the fused kernels are assumed to do the same (confirm).
        residual: Optional tensor added to ``hidden_states`` before the norm.

    Returns:
        A ``(normed_hidden_states, res)`` tuple. ``res`` is ``hidden_states``
        when no residual is given. Otherwise it is the residual stream after
        the add (see the per-branch notes below).
    """
    if self.device.type == "cuda":
        # Imported lazily so non-CUDA deployments don't need the extension.
        import dropout_layer_norm

        normed_hidden_states, res, *_ = dropout_layer_norm.dropout_add_ln_fwd(
            hidden_states,
            residual,
            self.weight,
            self.bias,
            None,
            None,
            None,
            None,
            0.0,  # presumably the dropout probability (disabled) - confirm
            self.variance_epsilon,
            1.0,
            0,
            None,
            False,
            False,
        )
        # The kernel may return None for its residual output; use the input then.
        if res is None:
            res = hidden_states
        return normed_hidden_states, res

    if self.use_ipex:
        import intel_extension_for_pytorch as ipex

        normed_hidden_states = ipex.llm.functional.add_layer_norm(
            residual,
            hidden_states,
            self.weight,
            self.bias,
            self.variance_epsilon,
            residual is not None,
        )
        # NOTE(review): assumes the last argument (add_back=True) makes IPEX
        # write hidden_states + residual back into `residual` in place, so
        # returning `residual` yields the summed stream - confirm.
        res = residual if residual is not None else hidden_states
        return normed_hidden_states, res

    if self.device.type == "hpu":
        normed_hidden_states = hpu_add_layer_norm(
            residual,
            hidden_states,
            self.weight,
            self.bias,
            self.variance_epsilon,
            residual is not None,
        )
        # NOTE(review): presumably mirrors the IPEX in-place write-back - verify.
        res = residual if residual is not None else hidden_states
        return normed_hidden_states, res

    # Generic fallback for any other device (e.g. plain CPU without IPEX).
    # The add is done out of place so the caller's `residual` is not mutated.
    import torch

    if residual is not None:
        hidden_states = hidden_states + residual
    normed_hidden_states = torch.nn.functional.layer_norm(
        hidden_states,
        hidden_states.shape[-1:],
        self.weight,
        self.bias,
        self.variance_epsilon,
    )
    return normed_hidden_states, hidden_states