backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py (138 lines of code) (raw):
from typing import Tuple
from dataclasses import dataclass, field
import torch
from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.weights import Weights
@dataclass
class KVScales:
    """
    Scaling factors applied to keys and values for an FP8 KV cache.

    Each scale is kept in two forms: the original tensor (for callers that
    need a tensor argument, e.g. scaling kernels) and a plain Python number
    obtained with ``Tensor.item()`` (for callers that need a host-side
    scalar). The scalar copies are derived once in ``__post_init__`` and are
    not constructor arguments.

    Raises:
        ValueError: if either scale tensor holds more than one element
            (or none).
    """

    # Scale tensors as supplied by the caller; each must hold one element.
    key_scale: torch.Tensor
    value_scale: torch.Tensor
    # Host-side scalar copies, filled in by __post_init__.
    key_scale_cpu: float = field(init=False)
    value_scale_cpu: float = field(init=False)

    def __post_init__(self):
        # Validate the key scale first, then the value scale; the first
        # offending tensor aborts construction.
        for scale in (self.key_scale, self.value_scale):
            if scale.numel() != 1:
                raise ValueError("Key and value scales must be scalar tensors.")

        # .item() copies the single element out of the tensor.
        self.key_scale_cpu = self.key_scale.item()
        self.value_scale_cpu = self.value_scale.item()
class KVCache:
    """
    Per-layer key/value cache.

    The cache is a pair of zero-initialised tensors ``(key, value)``, each of
    shape ``(num_blocks * BLOCK_SIZE, num_heads, head_size)``; the first
    dimension is indexed by slot.
    """

    kv_cache: Tuple[torch.Tensor, torch.Tensor]

    def __init__(
        self,
        *,
        num_blocks: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """
        Allocate the key and value buffers for one layer.

        Raises:
            ValueError: if ``dtype`` is ``torch.float8_e5m2``.
        """
        # TODO: FP8 kv cache support
        if dtype is torch.float8_e5m2:
            raise ValueError("torch.float8_e5m2 is not supported in hpu. ")

        cache_shape = (num_blocks * BLOCK_SIZE, num_heads, head_size)
        # Two independent buffers with identical layout: index 0 holds keys,
        # index 1 holds values.
        self.kv_cache = tuple(
            torch.zeros(cache_shape, dtype=dtype, device=device) for _ in range(2)
        )

    @property
    def dtype(self):
        """Data type of the cache (read from the key buffer)."""
        return self.kv_cache[0].dtype

    @property
    def key(self):
        """The key buffer."""
        return self.kv_cache[0]

    @property
    def value(self):
        """The value buffer."""
        return self.kv_cache[1]

    def store(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        slots: torch.Tensor,
        kv_scales: KVScales,
    ):
        """
        Write ``key`` and ``value`` into the cache rows selected by ``slots``.

        The work is delegated to ``paged_reshape_and_cache``, which receives
        the tensor forms of the key and value scales.
        """
        # TODO: FP8 kv cache support
        paged_reshape_and_cache(
            key,
            value,
            self.kv_cache[0],
            self.kv_cache[1],
            slots,
            kv_scales.key_scale,
            kv_scales.value_scale,
        )
class KVCompressCache(KVCache):
    """
    Single-tensor variant of the key-value cache.

    Instead of a (key, value) pair, one zero-initialised tensor of shape
    ``(num_blocks * BLOCK_SIZE, 1, head_size)`` is kept; the ``key`` and
    ``value`` properties both return that same tensor.
    NOTE(review): presumably intended for attention variants that cache one
    compressed/latent vector per token -- confirm against callers.
    """

    kv_cache: torch.Tensor

    def __init__(
        self,
        *,
        num_blocks: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """
        Allocate the single cache buffer for one layer.

        The parent initialiser is not invoked; this class sets ``kv_cache``
        itself with its own layout.

        Raises:
            ValueError: if ``dtype`` is ``torch.float8_e5m2``.
        """
        # TODO: FP8 kv cache support
        if dtype is torch.float8_e5m2:
            raise ValueError("torch.float8_e5m2 is not supported in hpu. ")

        cache_shape = (num_blocks * BLOCK_SIZE, 1, head_size)
        self.kv_cache = torch.zeros(cache_shape, dtype=dtype, device=device)

    @property
    def dtype(self):
        """Data type of the cache tensor."""
        return self.kv_cache.dtype

    @property
    def key(self):
        """The cache tensor (same object as ``value``)."""
        return self.kv_cache

    @property
    def value(self):
        """The cache tensor (same object as ``key``)."""
        return self.kv_cache

    def store(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        slots: torch.Tensor,
        kv_scales: KVScales,
    ):
        """
        Write ``key`` into the cache rows selected by ``slots``.

        ``value`` is accepted for interface parity with ``KVCache.store`` but
        is not used here. When the cache dtype is ``torch.float8_e4m3fn`` the
        key is first converted with the HPU FP8 cast op using the key scale.
        """
        # TODO: FP8 kv cache support
        cache = self.kv_cache
        entry = key
        if cache.dtype == torch.float8_e4m3fn:
            # The cast op returns a sequence; element 0 is the converted
            # tensor that gets stored.
            entry = torch.ops.hpu.cast_to_fp8_v2(
                key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
            )[0]
        cache.index_copy_(0, slots, entry)
def paged_reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
):
if key_cache.dtype == torch.float8_e4m3fn:
key = torch.ops.hpu.cast_to_fp8_v2(
key, k_scale, False, False, torch.float8_e4m3fn
)[0]
value = torch.ops.hpu.cast_to_fp8_v2(
value, v_scale, False, False, torch.float8_e4m3fn
)[0]
key_cache.index_copy_(0, slots, key)
value_cache.index_copy_(0, slots, value)
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
    """
    Load the KV cache scales stored under ``prefix``.

    Lookup order:
      1. ``{prefix}.k_scale`` and ``{prefix}.v_scale`` (both must exist);
      2. the older single ``{prefix}.kv_scale``, used for both key and value;
      3. a default scale of 1.0 (float32, on ``weights.device``) shared by
         key and value.

    Every scale read from the checkpoint is fetched with ``to_dtype=False``
    and then converted with ``.float()``, so both checkpoint layouts are
    loaded the same way.
    NOTE(review): ``to_dtype=False`` presumably stops the loader from casting
    the scale to the model dtype before the float32 conversion -- confirm
    against ``Weights.get_tensor``.

    Raises:
        ValueError: propagated from ``KVScales`` when a loaded scale is not a
            single-element tensor.
    """
    k_name = f"{prefix}.k_scale"
    v_name = f"{prefix}.v_scale"
    kv_name = f"{prefix}.kv_scale"

    if weights.has_tensor(k_name) and weights.has_tensor(v_name):
        key_scale = weights.get_tensor(k_name, to_dtype=False).float()
        value_scale = weights.get_tensor(v_name, to_dtype=False).float()
    elif weights.has_tensor(kv_name):
        # Fall back to older more coarse-grained scale when available.
        # Loaded with to_dtype=False for parity with the per-tensor branch.
        key_scale = weights.get_tensor(kv_name, to_dtype=False).float()
        value_scale = key_scale
    else:
        # No scales in the checkpoint: use a neutral scale for both.
        key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
        value_scale = key_scale

    return KVScales(key_scale=key_scale, value_scale=value_scale)