# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Tuple, Union

import torch
from executorch.extension.llm.custom_ops.custom_ops import custom_sdpa  # noqa


def custom_sdpa_with_start_pos_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Union[torch.Tensor, "BlockMask"],  # noqa
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
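    """Attention forward that lowers to ExecuTorch's ``torch.ops.llama.custom_sdpa``.

    For the causal path, the query's start position is recovered from the attention
    mask; for the sliding-window path (``is_sliding`` in ``kwargs``), the provided
    mask is used as-is. Computation is done in fp32 and the output is cast back to
    the input dtype. Only causal attention is supported; ``softcap`` and
    ``head_mask`` are accepted for interface compatibility but unused.
    """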
    # `key` is still in [batch, num_heads, seq_len, head_dim] layout here,
    # so dim 2 is the (cached) key sequence length.
    max_seq_len = key.shape[2]

    # The custom SDPA op expects the non-transposed [batch, seq_len, num_heads, head_dim]
    # layout (the same layout FlashAttention-2 uses).
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    # Cast the inputs to fp32 for the op; the output is cast back to the
    # original dtype before returning.
    input_dtype = query.dtype
    query = query.to(torch.float32)
    key = key.to(torch.float32)
    value = value.to(torch.float32)

    # Ignore any is_causal flag passed via kwargs and use the module's own setting.
    kwargs.pop("is_causal", None)
    assert module.is_causal, "This variant supports only causal attention."

    is_causal = module.is_causal
    if kwargs.get("is_sliding", False):
        is_causal = False
        attn_mask = attention_mask
        # start_pos is not important when using mask
        # instead of doing causal attention
        start_pos = 0
    else:
        attn_mask = None
        # Calculate the input pos from attention mask.
        # Branch out for float vs bool mask
        # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix."
        attention_mask = attention_mask.reshape(-1, max_seq_len)
        first_row_mask = attention_mask[0, :]
        # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
        start_pos = torch.argmin(first_row_mask.to(torch.long)).item() - 1

    output = torch.ops.llama.custom_sdpa(
        query,
        key,
        value,
        start_pos=start_pos,
        attn_mask=attn_mask,
        drpout_p=0.0,  # keyword spelling matches the op's schema
        is_causal=is_causal,
        scale=scaling,
    )
    return output.to(input_dtype), None


def get_custom_sdpa_for_ring_kv_cache(
    exportable_module: torch.nn.Module,
) -> Callable:
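    """Return an SDPA callable that supports ``CustomRingKVCache`` sliding-window layers.

    Layers flagged with ``is_sliding`` get their attention mask rebuilt from the ring
    buffer's current position before dispatching to ``custom_sdpa_with_start_pos_forward``;
    all other layers take the plain causal path.

    Illustrative usage (a sketch, not the only possible wiring; assumes a
    ``transformers`` version that exposes ``AttentionInterface`` and an
    ``exportable_module`` whose model holds an ``ETCustomHybridCache``)::

        from transformers import AttentionInterface

        sdpa_fn = get_custom_sdpa_for_ring_kv_cache(exportable_module)
        AttentionInterface.register("custom_sdpa_ring_kv_cache", sdpa_fn)
        exportable_module.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
    """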
    # Lazy import: the class below only exists in newer ExecuTorch versions.
    from executorch import version

    try:
        from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
            CustomRingKVCache,
        )
    except ImportError:
        raise ImportError(f"CustomRingKVCache not available in version {version.__version__} of ExecuTorch.")

    def _custom_sdpa_for_ring_kv_cache(
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Union[torch.Tensor, "BlockMask"],  # noqa
        scaling: Optional[float] = None,
        softcap: Optional[float] = None,
        head_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
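        """Dispatch sliding-window layers through the ring-buffer mask path; others use plain causal SDPA."""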
        is_sliding = getattr(module, "is_sliding", False)
        if is_sliding:
            # Lazy import to keep this off the optimum import path when running
            # ExecuTorch <= 0.6.0.
            from optimum.executorch.attentions.custom_kv_cache import ETCustomHybridCache

            layer_idx = module.layer_idx
            assert layer_idx is not None, "layer_idx is not set for sliding window attention."
            hybrid_cache = exportable_module.model.cache
            assert isinstance(hybrid_cache, ETCustomHybridCache), f"Expected ETCustomHybridCache, got {type(hybrid_cache)}"
            ring_cache = hybrid_cache.get_layer_cache(layer_idx)
            assert isinstance(ring_cache, CustomRingKVCache), f"Expected CustomRingKVCache, got {type(ring_cache)}"
            # Rebuild the causal mask relative to the ring buffer's current write
            # position, then dispatch through the mask-based (non-causal) path.
            input_pos = hybrid_cache.cache_position[0].item()
            seqlen = query.shape[2]
            attention_mask = ring_cache.create_causal_mask_for_ring_buffer(input_pos, seqlen)
            kwargs.update({"is_sliding": True})
            return custom_sdpa_with_start_pos_forward(
                module,
                query,
                key,
                value,
                attention_mask,
                scaling,
                softcap,
                head_mask,
                **kwargs,
            )
        else:
            return custom_sdpa_with_start_pos_forward(
                module,
                query,
                key,
                value,
                attention_mask,
                scaling,
                softcap,
                head_mask,
                **kwargs,
            )

    return _custom_sdpa_for_ring_kv_cache
