backends/python/server/text_embeddings_server/utils/flash_attn.py:

import os

import torch
from text_embeddings_server.utils.device import use_ipex, is_hpu
from loguru import logger

# Respect an explicit opt-out via the USE_FLASH_ATTENTION environment variable.
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False

is_hpu = is_hpu()
use_ipex = use_ipex()

if use_ipex or is_hpu:
    # IPEX (XPU/CPU) and HPU provide their own fused attention kernels.
    HAS_FLASH_ATTN_V2 = True
else:
    if not torch.cuda.is_available():
        raise ImportError("CUDA is not available")

    # Flash Attention V2 requires sm8x or sm90; V1 additionally supports sm75.
    major, minor = torch.cuda.get_device_capability()
    is_sm75 = major == 7 and minor == 5
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0

    try:
        try:
            import flash_attn_2_cuda
        except ImportError:
            raise ImportError(
                "Flash Attention V2 is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
            )
        if not (is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported for "
                "Flash Attention V2"
            )
        HAS_FLASH_ATTN_V2 = True
    except ImportError as e:
        # Fall back to Flash Attention V1 when V2 is unavailable or unsupported.
        try:
            import flash_attn_cuda
        except ImportError:
            raise ImportError(
                "Flash Attention is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                "or install flash attention with `cd server && make install install-flash-attention`"
            ) from e

        if not (is_sm75 or is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            ) from e
        logger.warning(f"Unable to use Flash Attention V2: {e}")
        HAS_FLASH_ATTN = True


def hpu_attn(
    q,
    k,
    v,
    out,
    attn_mask,
    seqlen_q,
    seqlen_k,
    max_seqlen_q,
    max_seqlen_k,
    softmax_scale,
    is_causal=False,
):
    from habana_frameworks.torch.hpex.kernels import FusedSDPA

    # FusedSDPA expects (batch, num_heads, seq_len, head_size), so swap the
    # sequence and head dimensions before and after the kernel call.
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)

    if is_causal:
        attn_mask = None

    out_ = FusedSDPA.apply(
        q, k, v, attn_mask, 0.0, is_causal, softmax_scale, "fast", False
    )
    out_ = out_.transpose(1, 2)
    out.copy_(out_)
    return out


def attention(
    q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False, attn_mask=None
):
    # Dispatch to the available backend: IPEX (XPU/CPU), HPU fused SDPA,
    # Flash Attention V2 on CUDA, or Flash Attention V1 as a last resort.
    if HAS_FLASH_ATTN_V2:
        if use_ipex:
            import intel_extension_for_pytorch as ipex

            if q.device.type == "xpu":
                return ipex.llm.functional.varlen_attention(
                    q.contiguous(),
                    k.contiguous(),
                    v.contiguous(),
                    out,
                    cu_seqlens,
                    cu_seqlens,
                    None,
                    max_s,
                    max_s,
                    0,
                    softmax_scale,
                    zero_tensors=False,
                    is_causal=False,
                    return_softmax=False,
                    gen_=None,
                )
            elif q.device.type == "cpu":
                return ipex.llm.functional.varlen_attention(
                    q,
                    k,
                    v,
                    out,
                    cu_seqlens,
                    cu_seqlens,
                    max_s,
                    max_s,
                    0,
                    softmax_scale,
                    zero_tensors=False,
                    is_causal=False,
                    return_softmax=False,
                    gen_=None,
                )
        elif is_hpu:
            return hpu_attn(
                q,
                k,
                v,
                out,
                attn_mask,
                cu_seqlens,
                cu_seqlens,
                max_s,
                max_s,
                softmax_scale,
                is_causal,
            )
        else:
            return flash_attn_2_cuda.varlen_fwd(
                q,
                k,
                v,
                out,
                cu_seqlens,
                cu_seqlens,
                max_s,
                max_s,
                0.0,
                softmax_scale,
                False,
                is_causal,
                -1,
                -1,
                False,
                None,
            )
    if HAS_FLASH_ATTN:
        return flash_attn_cuda.fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            is_causal,
            False,
            0,
            None,
        )

    raise NotImplementedError("flash attention is not installed")
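
For context, a minimal usage sketch of the `attention` dispatcher, assuming a CUDA device with the Flash Attention V2 kernels installed; the shapes, sequence lengths, and dtype below are illustrative only and do not come from the repository:

import math
import torch

from text_embeddings_server.utils.flash_attn import attention

# Two sequences of lengths 3 and 5 packed into a single "varlen" batch of
# 8 tokens, with 12 attention heads of size 64 each (illustrative values).
num_heads, head_size = 12, 64
seq_lens = [3, 5]
total_tokens = sum(seq_lens)

q = torch.randn(total_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)

# Cumulative sequence lengths expected by the varlen kernels: [0, 3, 8].
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
max_s = max(seq_lens)

attention(q, k, v, out, cu_seqlens, max_s, softmax_scale=1.0 / math.sqrt(head_size))
# `out` now holds the attention output for every packed token.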