def gpt_bigcode_attn()

in optimum/exporters/openvino/model_patcher.py


import torch

# is_torch_version comes from optimum-intel's import utilities (assumed import path).
from optimum.intel.utils.import_utils import is_torch_version


def gpt_bigcode_attn(self, query, key, value, attention_mask=None, head_mask=None):
    if head_mask is not None:
        # The dispatch to the eager (super) implementation is done in the forward.
        raise ValueError("PyTorch SDPA does not support head_mask. Please open an issue in the Transformers repository.")

    scale = None
    if not self.scale_attn_weights:
        scale = 1
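    # Passing scale=None lets SDPA apply its default 1/sqrt(head_dim) scaling;
    # scale=1 disables that scaling when the config sets scale_attn_weights=False.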

    # MQA models: (batch_size, query_length, num_heads * head_dim)
    # MHA models: (batch_size, num_heads, query_length, head_dim)
    query_shape = query.shape
    batch_size = query_shape[0]

    if self.multi_query:
        query_length = query_shape[1]

        # SDPA requires the dimension [..., sequence_length, head_dim].
        query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)

        # Without these unsqueezes, SDPA complains because the query and the key/value tensors have different numbers of dimensions.
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
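        # e.g. with batch_size=2, num_heads=4, head_dim=8, past_length=5, the shared
        # key/value go from (2, 5, 8) to (2, 1, 5, 8) and broadcast against the
        # (2, 4, query_length, 8) query inside SDPA.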

        # Although these expands do not change the result numerically, PyTorch cannot dispatch to the
        # memory-efficient or flash attention backends ("No available kernel. Aborting execution.") with the shapes
        # query = [batch_size, num_heads, query_length, head_dim]
        # key = [batch_size, 1, past_length, head_dim]
        # value = [batch_size, 1, past_length, head_dim]
        #
        # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check.
        if is_torch_version(">=", "2.2.0"):
            key = key.expand(-1, self.num_heads, -1, -1)
            value = value.expand(-1, self.num_heads, -1, -1)
    else:
        query_length = query_shape[-2]  # (batch_size, num_heads, query_length, head_dim)

        # See the torch==2.1.2 non-contiguous input bug referenced above.
        if query.device.type == "cuda" and attention_mask is not None:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
    # create a causal mask in case query_length == 1.
    is_causal = self.is_causal and attention_mask is None and query_length > 1
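    # Note: during cached generation query_length is typically 1, and SDPA's is_causal
    # aligns its triangular mask to the top-left, which would hide past keys; the
    # explicit attention_mask path handles that case instead.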
    # Unlike the original implementation: when model weights are loaded in their original
    # format, the transformer.wte dtype may differ from the query dtype, so cast the mask.
    if attention_mask is not None:
        attention_mask = attention_mask.to(query.dtype)
    sdpa_result = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=self.attn_pdrop if self.training else 0.0,
        is_causal=is_causal,
        scale=scale,
    )

    if self.multi_query:
        # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
        sdpa_result = sdpa_result.transpose(1, 2)

        # The reshape is somewhat expensive here, as it does a memory copy,
        # but I did not manage to do away with it (logits do not match when using view).
        # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
        sdpa_result = sdpa_result.reshape(query_shape)

    return sdpa_result, None
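
For reference, here is a minimal, self-contained sketch (not part of model_patcher.py; sizes are illustrative) showing that the unsqueeze/expand trick in the multi-query branch is numerically a no-op: SDPA over the broadcast key/value matches a naive per-head reference.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_heads, head_dim = 2, 4, 8
query_length, past_length = 3, 5

# MQA layout: the query packs all heads; key/value hold a single shared head.
query = torch.randn(batch_size, query_length, num_heads * head_dim)
key = torch.randn(batch_size, past_length, head_dim)
value = torch.randn(batch_size, past_length, head_dim)

# Same moves as the patched function: split heads on the query, then give
# key/value a broadcastable head dimension.
q = query.view(batch_size, query_length, num_heads, head_dim).transpose(1, 2)
k = key.unsqueeze(1).expand(-1, num_heads, -1, -1)
v = value.unsqueeze(1).expand(-1, num_heads, -1, -1)

out = F.scaled_dot_product_attention(q, k, v)  # (batch, num_heads, q_len, head_dim)

# Naive reference: every head attends to the same shared key/value.
ref = torch.empty_like(out)
for h in range(num_heads):
    weights = (q[:, h] @ key.transpose(-1, -2)) / head_dim**0.5
    ref[:, h] = weights.softmax(dim=-1) @ value

assert torch.allclose(out, ref, atol=1e-6)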