optimum/bettertransformer/models/encoder_models.py
def forward(self, hidden_states, attention_mask, output_attentions: bool, position_bias=None, *_, **__):
    if output_attentions:
        raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

    if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled():
        # Fast path: inference outside autocast goes through the fused BetterTransformer kernel,
        # using nested tensors to skip computation on padding tokens.
        if not hasattr(hidden_states, "original_shape"):
            original_shape = hidden_states.shape
        else:
            original_shape = hidden_states.original_shape

        if hidden_states.is_nested:
            attention_mask = None

        if attention_mask is not None:
            # The attention mask comes in with values 0 and -inf. Convert it to a
            # torch.nn.TransformerEncoder-style bool mask:
            # 0 -> False -> keep this token, -inf -> True -> mask this token.
            if len(attention_mask.shape) == 4:
                attention_mask = attention_mask.squeeze(1)[:, 0]
            attention_mask = attention_mask.bool()
            attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
            hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
            attention_mask = None

        # Single fused call covering self-attention, the feed-forward block and both layer norms.
        hidden_states = torch._transformer_encoder_layer_fwd(
            hidden_states,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.out_proj_weight,
            self.out_proj_bias,
            self.use_gelu,
            self.norm_first,
            self.norm1_eps,
            self.norm1_weight,
            self.norm1_bias,
            self.norm2_weight,
            self.norm2_bias,
            self.linear1_weight,
            self.linear1_bias,
            self.linear2_weight,
            self.linear2_bias,
            attention_mask,
        )
        # Carry the padded shape through intermediate layers so the last layer can
        # convert the nested tensor back to a regular padded tensor.
        if not self.is_last_layer:
            hidden_states.original_shape = original_shape
        elif hidden_states.is_nested and self.is_last_layer:
            hidden_states = hidden_states.to_padded_tensor(0.0, original_shape)
    else:
        # Fallback path (training or autocast): run the layer with regular PyTorch ops,
        # using scaled_dot_product_attention for the attention block.
        # Fused QKV projection, then split into query/key/value of shape (batch, heads, seq, head_dim).
        qkv = F.linear(hidden_states, weight=self.in_proj_weight, bias=self.in_proj_bias)
        qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4)
        query, key, value = qkv[0], qkv[1], qkv[2]

        # NOTE: In PyTorch 2.0, passing an attention_mask will automatically dispatch
        # to the "math" path and will NOT use flash attention / memory-efficient attention.
        # We should support xformers / Hazy-flash / rocm-flash directly and stop relying
        # on PyTorch to do the work.
        if self.training:
            # Dropping the mask during training lets SDPA dispatch to the fused kernels (see NOTE above).
            attention_mask = None

        attention_out = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            is_causal=False,
            dropout_p=self.dropout if self.training else 0.0,
        )
        # Back to (batch, seq, hidden_size).
        attention_out = attention_out.permute(0, 2, 1, 3).contiguous()
        new_attention_out_shape = attention_out.size()[:-2] + (self.num_heads * self.attention_head_size,)
        attention_out = attention_out.view(new_attention_out_shape)

        # BertSelfOutput: output projection + dropout + residual + LayerNorm.
        attention_out = F.layer_norm(
            F.dropout(
                F.linear(attention_out, self.out_proj_weight, self.out_proj_bias),
                p=self.dropout,
                training=self.training,
            )
            + hidden_states,
            normalized_shape=self.norm1_weight.shape,
            weight=self.norm1_weight,
            bias=self.norm1_bias,
        )

        # One additional dropout compared to BERT: activation dropout after the intermediate projection.
        hidden_states = F.dropout(
            self.act_fn_callable(F.linear(attention_out, self.linear1_weight, self.linear1_bias)),
            p=self.activation_dropout,
            training=self.training,
        )
        # Feed-forward output: projection + dropout + residual + LayerNorm.
        hidden_states = F.layer_norm(
            attention_out
            + F.dropout(
                F.linear(hidden_states, self.linear2_weight, self.linear2_bias),
                p=self.dropout,
                training=self.training,
            ),
            normalized_shape=self.norm2_weight.shape,
            weight=self.norm2_weight,
            bias=self.norm2_bias,
        )

    return (hidden_states,)
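
# --- Usage sketch (illustrative, not part of this file) ---
# A minimal example of how the fast path above gets exercised, assuming the public
# BetterTransformer.transform API from optimum and a standard transformers checkpoint
# ("bert-base-uncased" and the inputs below are only examples). With the model in eval
# mode, no autocast and no grad, the forward above takes the fused
# torch._transformer_encoder_layer_fwd branch and drops padding via nested tensors.
import torch
from transformers import AutoModel, AutoTokenizer
from optimum.bettertransformer import BetterTransformer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
model = BetterTransformer.transform(model)  # swaps encoder layers for their BetterTransformer versions
model.eval()

inputs = tokenizer(["hello world", "a noticeably longer second sentence"], padding=True, return_tensors="pt")
with torch.no_grad():  # not training, no autocast -> fused nested-tensor path
    outputs = model(**inputs)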