in optimum/exporters/openvino/model_patcher.py [0:0]
def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
    from transformers.modeling_attn_mask_utils import AttentionMaskConverter

    if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None:
        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
        # in order to dispatch on Flash Attention 2.
        if AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
        ):
            return None

    dtype, device = input_tensor.dtype, input_tensor.device
    # difference with original modeling
    # using the minimum of a wider dtype (float32) may lead to overflow
    # during execution on platforms with a lower default precision (bfloat16, float16)
    min_dtype = torch.finfo(torch.float16).min

    sequence_length = input_tensor.shape[1]
    # difference with original modeling
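    # Decide how wide the mask has to be: with a static cache it must span the whole
    # pre-allocated window (max_position_embeddings); with a dynamic cache it only needs to
    # cover the positions seen so far plus the new tokens (or the provided attention mask).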
    if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
        target_length = self.config.max_position_embeddings
    else:  # dynamic cache
        if past_seen_tokens is not None:
            current_length = past_seen_tokens + sequence_length + 1
        # TODO: remove after support of transformers >= v4.40.0
        else:
            current_length = cache_position[-1] + 1

        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length

    # difference with original modeling
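    # Start from an all-masked (min_dtype) matrix, then zero out every position a token is
    # allowed to attend to: triu keeps future positions within the new chunk masked, and the
    # cache_position comparison unmasks everything up to each token's absolute position.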
    causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
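    # Fold a user-provided attention mask into the causal mask: a 2D mask marks padded tokens
    # with min_dtype, while a 4D mask is copied in directly (shifted by the cache offset when
    # the 4D mask is shorter than the full sequence including the cache).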
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        if attention_mask.dim() == 2:
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
        elif attention_mask.dim() == 4:
            # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
            # cache. In that case, the 4D attention mask attends to the newest tokens only.
            if attention_mask.shape[-2] < cache_position[0] + sequence_length:
                offset = cache_position[0]
            else:
                offset = 0
            mask_shape = attention_mask.shape
            mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
            causal_mask[
                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
            ] = mask_slice

    if (
        self.config._attn_implementation == "sdpa"
        and attention_mask is not None
        and attention_mask.device.type == "cuda"
    ):
        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        # Details: https://github.com/pytorch/pytorch/issues/110213
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

    return causal_mask
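

# Illustrative sketch (not part of the patcher above): a hypothetical helper showing how the
# mask construction in _llama_gemma_update_causal_mask_legacy behaves for a toy decoding step
# with three new tokens appended after two cached ones. Only `torch` is assumed.
def _causal_mask_sketch():
    import torch

    sequence_length, target_length = 3, 5  # 3 new query tokens, 5 key/value positions in total
    cache_position = torch.tensor([2, 3, 4])  # absolute positions of the new tokens
    min_dtype = torch.finfo(torch.float16).min

    causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=torch.float32) * min_dtype
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    # Each row is now 0.0 at every position that token may attend to (the cache, earlier new
    # tokens and itself) and min_dtype at every future position:
    # [[0, 0, 0, m, m],
    #  [0, 0, 0, 0, m],
    #  [0, 0, 0, 0, 0]]   where m == min_dtype
    return causal_mask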