in optimum/exporters/openvino/model_patcher.py
def _jais_attn(self, query, key, value, attention_mask=None, head_mask=None, position_bias=None):
    # SDPA-based replacement for the eager JAIS attention; relies on `torch` and
    # `torch.nn.functional as F` being imported at the top of model_patcher.py.
    scale = 1.0
    if self.scale_attn_weights:
        scale = 1 / self.head_dim**self.attn_scale_power

    # Layer-wise attention scaling
    if self.scale_attn_by_inverse_layer_idx:
        scale = scale / float(self.layer_idx + 1)

    # Build an additive attention mask of shape (batch, heads, query_len, key_len).
    attention_mask_sdpa = torch.ones(
        (query.shape[0], query.shape[1], query.shape[2], key.shape[2]),
        dtype=query.dtype,
    )

    if not self.is_cross_attention:
        # only the "normal" self-attention layer implements the causal mask
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
        mask_value = torch.finfo(torch.float16).min
        attention_mask_sdpa.masked_fill_(~causal_mask, mask_value)

    if attention_mask is not None:
        # Apply the attention mask
        attention_mask_sdpa = attention_mask_sdpa + attention_mask

    if position_bias is not None:
        attention_mask_sdpa += position_bias.type_as(attention_mask_sdpa).unsqueeze(0)

    # Mask heads if we want to
    if head_mask is not None:
        attention_mask_sdpa = attention_mask_sdpa * head_mask

    attn_output = F.scaled_dot_product_attention(
        query, key, value, attention_mask_sdpa, dropout_p=self.attn_dropout.p, scale=scale
    )

    return attn_output, None
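
During export, an SDPA-based attention like this is bound onto each attention block of the JAIS model in place of its eager `_attn` method. The snippet below is a minimal sketch of that monkey-patching step, not the patcher class from this file: the attribute paths (`model.transformer.h`, `block.attn._attn`) follow the GPT-2-style layout that JAIS uses and are assumptions here, as is the helper name `apply_jais_sdpa_patch`.

import types


def apply_jais_sdpa_patch(model):
    # Rebind _jais_attn as a bound method on every attention module so that
    # `self` resolves to the module (giving access to self.bias, self.head_dim,
    # self.attn_dropout, ...). Keep the originals so the patch can be reverted.
    # NOTE: `model.transformer.h` / `block.attn._attn` are assumed attribute
    # paths based on the GPT-2-style JAIS architecture.
    originals = []
    for block in model.transformer.h:
        attn = block.attn
        originals.append((attn, attn._attn))
        attn._attn = types.MethodType(_jais_attn, attn)
    return originals


def revert_jais_sdpa_patch(originals):
    # Restore the eager implementations after export.
    for attn, orig in originals:
        attn._attn = orig

In the surrounding file, rebinding like this is typically wrapped in a model patcher's __enter__/__exit__ so the model is only modified for the duration of the export and restored afterwards.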