def forward()

in optimum/bettertransformer/models/encoder_models.py [0:0]


    def forward(self, hidden_states, attention_mask, *_):
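        # Fast path: at inference time, outside of autocast, dispatch to PyTorch's fused
        # BetterTransformer encoder kernel; otherwise fall back to the eager implementation below.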
        if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled():
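            # A nested tensor already encodes per-sequence lengths, so an incoming mask is redundant.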
            if hidden_states.is_nested:
                attention_mask = None

            if attention_mask is not None:
                # The attention mask comes in with values 0 and -inf. Convert it to a
                # torch.nn.TransformerEncoder-style boolean mask:
                # 0 -> False -> keep this token; -inf -> True -> mask this token.
                attention_mask = attention_mask.bool()
                attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
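                # Pack the padded batch into a nested tensor so the fused kernel can skip
                # computation on padding tokens; the boolean mask is no longer needed afterwards.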
                hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
                attention_mask = None

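            # Single fused kernel covering self-attention, the feed-forward block,
            # residual connections and both layer norms.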
            hidden_states = torch._transformer_encoder_layer_fwd(
                hidden_states,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj_weight,
                self.out_proj_bias,
                self.use_gelu,
                self.norm_first,
                self.norm1_eps,
                self.norm1_weight,
                self.norm1_bias,
                self.norm2_weight,
                self.norm2_bias,
                self.linear1_weight,
                self.linear1_bias,
                self.linear2_weight,
                self.linear2_bias,
                attention_mask,
            )
            if hidden_states.is_nested and self.is_last_layer:
                hidden_states = hidden_states.to_padded_tensor(0.0)
        else:
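            # Eager fallback for training or autocast: the same layer computed with standard
            # PyTorch ops so that dropout, mixed precision and autograd behave as usual.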
            qkv = F.linear(hidden_states, weight=self.in_proj_weight, bias=self.in_proj_bias)

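            # Reshape to (3, batch, num_heads, seq_len, head_dim) and split into query, key and value.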
            qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4)
            query, key, value = qkv[0], qkv[1], qkv[2]

            # NOTE: In PyTorch 2.0, passing an attention_mask will automatically dispatch
            # to the "math" path and will NOT use flash attention / memory-efficient attention.
            # We should support xformers / Hazy-flash / rocm-flash directly and stop relying on PyTorch to do the work.
            if self.training:
                attention_mask = None
            attention_out = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=attention_mask,
                is_causal=False,
                dropout_p=self.attention_probs_dropout_prob if self.training else 0.0,
            )

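            # Merge the attention heads back into a single hidden dimension.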
            attention_out = attention_out.permute(0, 2, 1, 3).contiguous()
            new_attention_out_shape = attention_out.size()[:-2] + (self.num_heads * self.attention_head_size,)
            attention_out = attention_out.view(new_attention_out_shape)

            # BertSelfOutput
            attention_out = F.layer_norm(
                F.dropout(
                    F.linear(attention_out, self.out_proj_weight, self.out_proj_bias),
                    p=self.hidden_dropout_prob,
                    training=self.training,
                )
                + hidden_states,
                normalized_shape=self.norm1_weight.shape,
                weight=self.norm1_weight,
                bias=self.norm1_bias,
            )

            # BertIntermediate
            hidden_states = self.act_fn_callable(F.linear(attention_out, self.linear1_weight, self.linear1_bias))

            # BertOutput
            hidden_states = F.layer_norm(
                attention_out
                + F.dropout(
                    F.linear(hidden_states, self.linear2_weight, self.linear2_bias),
                    p=self.hidden_dropout_prob,
                    training=self.training,
                ),
                normalized_shape=self.norm2_weight.shape,
                weight=self.norm2_weight,
                bias=self.norm2_bias,
            )

        return (hidden_states,)
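
For reference, here is a minimal usage sketch (not part of the file above) showing how this forward ends up being called. It assumes a BERT checkpoint purely as an example; any architecture supported by BetterTransformer works the same way. With the model in eval mode and autocast disabled, the fused fast path above is taken; in training mode the eager scaled_dot_product_attention branch runs instead.

    import torch
    from transformers import AutoModel, AutoTokenizer
    from optimum.bettertransformer import BetterTransformer

    # Example checkpoint only; replace with any supported encoder model.
    model = BetterTransformer.transform(AutoModel.from_pretrained("bert-base-uncased"))
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    inputs = tokenizer(["hello world", "a longer, padded example"], padding=True, return_tensors="pt")

    # eval mode + no autocast -> the fused torch._transformer_encoder_layer_fwd path runs.
    with torch.inference_mode():
        outputs = model(**inputs)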