def _gptj_attn()

in optimum/exporters/openvino/model_patcher.py


import torch

from optimum.intel.utils.import_utils import is_transformers_version


def _gptj_attn(self, query, key, value, attention_mask=None, head_mask=None):
    # SDPA has no equivalent of a per-head mask, so defer to the original attention implementation.
    if head_mask is not None:
        return self._orig_attn(query, key, value, attention_mask, head_mask)

    batch_size = query.shape[0]

    # Most negative representable value of the value dtype, used below as the
    # additive mask value for disallowed positions.
    mask_value = torch.finfo(value.dtype).min
    mask_value = torch.full([], mask_value, dtype=value.dtype)

    # In GPT-NeoX and GPT-J the queries and keys are always in fp32,
    # so we need to cast them to the value dtype.
    if getattr(self, "downcast_qk", False):
        query = query.to(value.dtype)
        key = key.to(value.dtype)

    # A large negative value at the last mask position means the final token is
    # padded; the SDPA paths below cannot handle that layout at batch size 1,
    # so defer to the original attention implementation.
    if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, -1, -1] < -1:
        return self._orig_attn(query, key, value, attention_mask, head_mask)

    dropout_p = self.dropout_prob_attn if self.training else 0.0
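    # With batch size 1 (or in training) no padding mask is needed: let SDPA build
    # the causal mask itself during prefill, and skip it entirely for a one-token
    # query, which may attend to every cached key anyway.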
    if batch_size == 1 or self.training:
        if query.shape[2] > 1:
            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
            )
        else:
            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
            )
    else:
        query_length, key_length = query.size(-2), key.size(-2)

        # When query_length == 1 the causal mask would be all True,
        # so building it is unnecessary.
        if query_length > 1:
            if not is_transformers_version(">=", "4.44.99"):
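                # On transformers < 4.45, materialize the causal mask from the
                # module's registered lower-triangular `bias` buffer.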
                causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)

                # Convert the boolean mask to an additive one: 0 where attention
                # is allowed, mask_value where it is not.
                causal_mask = torch.where(causal_mask, 0, mask_value)

                # torch.Tensor.expand broadcasts without a memory copy
                causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
                if attention_mask is not None:
                    # Merge the additive padding mask into the causal mask.
                    attention_mask = causal_mask + attention_mask
                else:
                    # Keep causality even when no padding mask was supplied.
                    attention_mask = causal_mask

            else:
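                # transformers >= 4.45 already provides a prepared 4D additive
                # mask; slice it to the current key length.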
                attention_mask = attention_mask[:, :, :, : key.shape[-2]]

        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
        )

    # Mirror the downcast above: return the attention output in the value dtype.
    if getattr(self, "downcast_qk", False):
        sdpa_result = sdpa_result.to(value.dtype)

    # SDPA does not expose attention weights, so return None in their place.
    return sdpa_result, None
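
For context, here is a minimal sketch of how a patcher might bind this replacement onto a GPT-J attention module, keeping the eager implementation reachable as `_orig_attn` for the fallback paths above. The helper name `patch_gptj_attention` and the attribute assignments are illustrative assumptions, not the exact optimum code:

import types

from transformers import AutoModelForCausalLM


def patch_gptj_attention(attn_module, attn_pdrop=0.0):
    # Keep the original implementation reachable for the fallback paths.
    attn_module._orig_attn = attn_module._attn
    # GPT-J computes queries/keys in fp32; enable the downcast to the value dtype.
    attn_module.downcast_qk = True
    # The wrapper reads dropout_prob_attn; for GPT-J this corresponds to config.attn_pdrop.
    attn_module.dropout_prob_attn = attn_pdrop
    # Bind the SDPA-based wrapper as the module's _attn method.
    attn_module._attn = types.MethodType(_gptj_attn, attn_module)


model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6b")
for layer in model.transformer.h:
    patch_gptj_attention(layer.attn, attn_pdrop=model.config.attn_pdrop)

After patching, attention runs through torch.nn.functional.scaled_dot_product_attention, while the eager implementation is still used whenever a head mask or an unsupported padding layout is encountered.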