# forward() — excerpt from
# backends/python/server/text_embeddings_server/models/flash_bert.py

    def forward(self, hidden_states, residual=None):
        """Fused (residual-add +) LayerNorm dispatched on the backend device.

        Args:
            hidden_states: input activations to normalize.
            residual: optional residual tensor added before normalization.

        Returns:
            Tuple ``(normed_hidden_states, res)`` where ``res`` is the
            pre-normalization sum to feed into the next residual connection.
            On a device with no fused kernel (not cuda/ipex/hpu) both values
            are ``None`` — presumably callers never hit that path; verify.
        """
        if self.device.type == "cuda":
            # Fused dropout + residual-add + layernorm CUDA kernel from the
            # flash-attention extension. Imported lazily so non-CUDA
            # deployments never need the extension installed.
            import dropout_layer_norm

            normed, pre_norm_sum, *_ = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                self.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                False,
            )
            # The kernel may not hand back the residual sum; fall back to the
            # raw input in that case.
            return normed, pre_norm_sum if pre_norm_sum is not None else hidden_states

        if self.use_ipex:
            # Intel extension path: fused add + layernorm on XPU/CPU.
            import intel_extension_for_pytorch as ipex

            normed = ipex.llm.functional.add_layer_norm(
                residual,
                hidden_states,
                self.weight,
                self.bias,
                self.variance_epsilon,
                residual is not None,
            )
            return normed, hidden_states if residual is None else residual

        if self.device.type == "hpu":
            # Habana Gaudi path via module-level helper.
            normed = hpu_add_layer_norm(
                residual,
                hidden_states,
                self.weight,
                self.bias,
                self.variance_epsilon,
                residual is not None,
            )
            return normed, hidden_states if residual is None else residual

        # No fused kernel for this device type: mirror the original
        # fall-through behavior and return (None, None).
        return None, None