csrc/custom_ops/kernels.cu (135 lines of code) (raw):

#include "custom_ops.h"
#include "dispatch_utils.h"
#include "quant_utils.cuh"

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>

#include <algorithm>
#include <string>
#include <vector>

namespace vllm {

// Copies the key/value vectors of every (token, layer) pair into that layer's
// cache, optionally converting them through fp8::scaled_convert.
//
// Launch layout:
//   gridDim.x  = number of tokens (one block per token; the x dimension is
//                used because it has the largest limit)
//   gridDim.y  = number of layers
//   blockDim.x = any positive value; threads stride over the
//                num_heads * head_size elements of one token
// No shared memory is used.
//
// Parameters:
//   keys, values         Source elements. Element i of (token, layer) is read
//                        at token * {key,value}_stride + layer * n + i, where
//                        n = num_heads * head_size.
//   key_cache_ptrs,
//   value_cache_ptrs     Device arrays holding, per layer, the address of the
//                        layer's cache buffer stored as int64_t.
//   slot_mapping         Per-token destination slot; negative means "skip".
//   block_stride         Element stride between consecutive cache blocks.
//   block_size           Number of slots per cache block.
//   k_scale_ptrs,
//   v_scale_ptrs         Device arrays holding, per layer, the address of a
//                        float scale. Only dereferenced when
//                        kv_dt != Fp8KVCacheDataType::kAuto.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_flash_bulk_kernel(
    const scalar_t* __restrict__ keys, const scalar_t* __restrict__ values,
    int64_t* key_cache_ptrs, int64_t* value_cache_ptrs,
    const int64_t* __restrict__ slot_mapping, const int block_stride,
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, int64_t* k_scale_ptrs,
    int64_t* v_scale_ptrs) {
  const int64_t token_idx = blockIdx.x;
  const int64_t layer_idx = blockIdx.y;

  const int64_t slot_idx = slot_mapping[token_idx];
  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0) {
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;
  const int n = num_heads * head_size;

  cache_t* __restrict__ key_cache =
      reinterpret_cast<cache_t*>(key_cache_ptrs[layer_idx]);
  cache_t* __restrict__ value_cache =
      reinterpret_cast<cache_t*>(value_cache_ptrs[layer_idx]);
  const float* __restrict__ k_scale =
      reinterpret_cast<const float*>(k_scale_ptrs[layer_idx]);
  const float* __restrict__ v_scale =
      reinterpret_cast<const float*>(v_scale_ptrs[layer_idx]);

  // Everything that does not depend on the element index is computed once.
  const int64_t src_key_base = token_idx * key_stride + layer_idx * n;
  const int64_t src_value_base = token_idx * value_stride + layer_idx * n;
  // head_idx * head_size + head_offset is just i, so the destination of
  // element i is dst_base + i; no per-element division is needed.
  const int64_t dst_base =
      block_idx * block_stride + block_offset * num_heads * head_size;

  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t dst_idx = dst_base + i;
    const scalar_t src_key = keys[src_key_base + i];
    const scalar_t src_value = values[src_value_base + i];

    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key_cache[dst_idx] = src_key;
      value_cache[dst_idx] = src_value;
    } else {
      key_cache[dst_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src_key, *k_scale);
      value_cache[dst_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src_value, *v_scale);
    }
  }
}

}  // namespace vllm

// Launches the bulk kernel for one (source type, cache type, KV dtype)
// combination. Expects the following names in the enclosing scope: grid,
// block, stream, keys, values, slot_mapping, block_stride, key_stride,
// value_stride, num_heads, head_size, block_size and the four
// *_ptrs_tensor device tensors.
#define CALL_RESHAPE_AND_CACHE_FLASH_BULK(KV_T, CACHE_T, KV_DTYPE)         \
  vllm::reshape_and_cache_flash_bulk_kernel<KV_T, CACHE_T, KV_DTYPE>       \
      <<<grid, block, 0, stream>>>(                                        \
          reinterpret_cast<KV_T*>(keys.data_ptr()),                        \
          reinterpret_cast<KV_T*>(values.data_ptr()),                      \
          key_cache_ptrs_tensor.data_ptr<int64_t>(),                       \
          value_cache_ptrs_tensor.data_ptr<int64_t>(),                     \
          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,      \
          value_stride, static_cast<int>(num_heads),                       \
          static_cast<int>(head_size), block_size,                         \
          k_scale_ptrs_tensor.data_ptr<int64_t>(),                         \
          v_scale_ptrs_tensor.data_ptr<int64_t>());

// Writes the keys/values of all layers into the per-layer KV caches with a
// single kernel launch.
//
// Parameters:
//   keys, values      Source tensors; stride(0) is used as the per-token
//                     stride and the layers are read back to back within a
//                     token (presumably [num_tokens, num_layers, num_heads,
//                     head_size] - TODO confirm against the caller).
//   key_caches,
//   value_caches      One cache tensor per layer. size(1) is taken as the
//                     block size and stride(0) as the block stride; both must
//                     be identical for every layer.
//   slot_mapping      int64 tensor; size(0) defines the number of tokens.
//   kv_cache_dtype    Forwarded to DISPATCH_BY_KV_CACHE_DTYPE.
//   k_scales,
//   v_scales          One scale tensor per layer; the kernel reads the first
//                     element as float in the non-"auto" paths
//                     (NOTE(review): dtype is not verified here).
//   num_heads,
//   head_size         Dimensions of one token's key/value within a layer.
//
// Raises (TORCH_CHECK) when the per-layer vectors disagree in length, when
// the caches do not share one block size / block stride, when there are more
// layers than gridDim.y supports, when num_heads/head_size are not positive,
// or when the kernel launch itself is rejected.
void reshape_and_cache_flash_bulk(
    torch::Tensor& keys, torch::Tensor& values,
    std::vector<torch::Tensor> const& key_caches,
    std::vector<torch::Tensor> const& value_caches,
    torch::Tensor& slot_mapping, const std::string& kv_cache_dtype,
    std::vector<torch::Tensor> const& k_scales,
    std::vector<torch::Tensor> const& v_scales, int64_t num_heads,
    int64_t head_size) {
  const size_t layer_count = key_caches.size();
  if (layer_count == 0) {
    return;
  }

  TORCH_CHECK(layer_count == value_caches.size(),
              "value_caches must have one entry per layer");
  TORCH_CHECK(layer_count == k_scales.size(),
              "k_scales must have one entry per layer");
  TORCH_CHECK(layer_count == v_scales.size(),
              "v_scales must have one entry per layer");
  // Layers are mapped to gridDim.y, which CUDA limits to 65535.
  TORCH_CHECK(layer_count <= 65535,
              "too many layers for a single launch: ", layer_count);
  TORCH_CHECK(num_heads > 0 && head_size > 0,
              "num_heads and head_size must be positive");
  const int num_layers = static_cast<int>(layer_count);

  const int num_tokens = static_cast<int>(slot_mapping.size(0));
  if (num_tokens == 0) {
    // A zero-sized grid is an invalid launch configuration.
    return;
  }

  const int block_size = static_cast<int>(key_caches[0].size(1));
  const int key_stride = static_cast<int>(keys.stride(0));
  const int value_stride = static_cast<int>(values.stride(0));
  const int block_stride = static_cast<int>(key_caches[0].stride(0));

  // One host table with four rows: key caches, value caches, k scales,
  // v scales. It is uploaded with a single copy instead of four.
  std::vector<int64_t> host_ptr_table(4 * layer_count);
  int64_t* key_cache_ptrs = host_ptr_table.data();
  int64_t* value_cache_ptrs = key_cache_ptrs + layer_count;
  int64_t* k_scale_ptrs = value_cache_ptrs + layer_count;
  int64_t* v_scale_ptrs = k_scale_ptrs + layer_count;

  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    const torch::Tensor& key_cache = key_caches[layer_idx];
    const torch::Tensor& value_cache = value_caches[layer_idx];
    // The kernel uses a single block size / stride for every layer.
    TORCH_CHECK(key_cache.size(1) == block_size,
                "key cache of layer ", layer_idx, " has a different block size");
    TORCH_CHECK(key_cache.stride(0) == block_stride &&
                    value_cache.stride(0) == block_stride,
                "caches of layer ", layer_idx, " have a different block stride");

    key_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(key_cache.data_ptr());
    value_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(value_cache.data_ptr());
    k_scale_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(k_scales[layer_idx].data_ptr());
    v_scale_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(v_scales[layer_idx].data_ptr());
  }

  torch::Device device_of_key = keys.device();
  const at::cuda::OptionalCUDAGuard device_guard(device_of_key);

  // .to() is called without non_blocking, so host_ptr_table only has to
  // outlive this statement.
  torch::Tensor ptr_table_tensor =
      torch::from_blob(host_ptr_table.data(),
                       {int64_t{4}, static_cast<int64_t>(num_layers)},
                       torch::kInt64)
          .to(device_of_key);
  // Row views; data_ptr() of a view already includes the row offset.
  torch::Tensor key_cache_ptrs_tensor = ptr_table_tensor.select(0, 0);
  torch::Tensor value_cache_ptrs_tensor = ptr_table_tensor.select(0, 1);
  torch::Tensor k_scale_ptrs_tensor = ptr_table_tensor.select(0, 2);
  torch::Tensor v_scale_ptrs_tensor = ptr_table_tensor.select(0, 3);

  const int elems_per_token =
      static_cast<int>(num_heads) * static_cast<int>(head_size);
  // Tokens go on x (limit 2^31 - 1), layers on y (limit 65535).
  dim3 grid(num_tokens, num_layers);
  dim3 block(std::min(elems_per_token, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(keys.dtype(), kv_cache_dtype,
                             CALL_RESHAPE_AND_CACHE_FLASH_BULK);
  // Kernel launches do not report errors themselves; surface a bad launch
  // configuration here instead of at some later, unrelated CUDA call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}