in optimum/exporters/openvino/model_patcher.py [0:0]
def _gptj_attn(self, query, key, value, attention_mask=None, head_mask=None):
if head_mask is not None:
return self._orig_attn(query, key, value, attention_mask, head_mask)
batch_size = query.shape[0]
mask_value = torch.finfo(value.dtype).min
mask_value = torch.full([], mask_value, dtype=value.dtype)
# in gpt-neo-x and gpt-j the query and keys are always in fp32
# thus we need to cast them to the value dtype
if getattr(self, "downcast_qk", False):
query = query.to(value.dtype)
key = key.to(value.dtype)
if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, -1, -1] < -1:
return self._orig_attn(query, key, value, attention_mask, head_mask)
dropout_p = self.dropout_prob_attn if self.training else 0.0
if batch_size == 1 or self.training:
if query.shape[2] > 1:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
else:
query_length, key_length = query.size(-2), key.size(-2)
# causal_mask is always [True, ..., True] otherwise, so executing this
# is unnecessary
if query_length > 1:
if not is_transformers_version(">=", "4.44.99"):
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
causal_mask = torch.where(causal_mask, 0, mask_value)
# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
if attention_mask is not None:
attention_mask = causal_mask + attention_mask
else:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)
# in gpt-neo-x and gpt-j the query and keys are always in fp32
# thus we need to cast them to the value dtype
if getattr(self, "downcast_qk", False):
sdpa_result = sdpa_result.to(value.dtype)
return sdpa_result, None