def _custom_sdpa_for_ring_kv_cache()

in optimum/executorch/attentions/custom_sdpa.py


    from typing import Optional, Union

    import torch

    def _custom_sdpa_for_ring_kv_cache(
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Union[torch.Tensor, "BlockMask"],  # noqa
        scaling: Optional[float] = None,
        softcap: Optional[float] = None,
        head_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        ...  # body not shown in this excerpt
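
The signature follows the attention-callback contract used by transformers' attention dispatch: the callback receives the calling `module`, the projected query/key/value tensors, and a mask, and returns the attention output plus an optional weights tensor. The sketch below is a minimal illustration of that contract only, not the library's implementation: it substitutes `torch.nn.functional.scaled_dot_product_attention` for the ExecuTorch custom SDPA op, it does not model the ring KV cache itself, and the name `_ring_sdpa_sketch` is hypothetical.

    from typing import Optional, Tuple

    import torch
    import torch.nn.functional as F

    def _ring_sdpa_sketch(
        module: torch.nn.Module,
        query: torch.Tensor,   # [batch, num_heads, q_len, head_dim]
        key: torch.Tensor,     # [batch, num_heads, kv_len, head_dim], KV heads
        value: torch.Tensor,   # assumed already expanded to num_heads (GQA)
        attention_mask: Optional[torch.Tensor],
        scaling: Optional[float] = None,
        softcap: Optional[float] = None,
        head_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if softcap is not None or head_mask is not None:
            # Kept out of scope for this sketch.
            raise NotImplementedError("softcap/head_mask not handled here")
        # scale=None lets SDPA fall back to 1/sqrt(head_dim), matching the
        # usual default when `scaling` is not provided.
        attn_output = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            scale=scaling,
        )
        # transformers attention callbacks conventionally return the output
        # reshaped to [batch, q_len, num_heads, head_dim], plus optional
        # attention weights (None when SDPA does not materialize them).
        return attn_output.transpose(1, 2).contiguous(), None

Such a callback could be registered with transformers' `AttentionInterface` (e.g. `AttentionInterface.register("ring_sdpa_sketch", _ring_sdpa_sketch)`) and selected via `attn_implementation="ring_sdpa_sketch"`; whether optimum-executorch wires `_custom_sdpa_for_ring_kv_cache` in exactly this way is an assumption here.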