kernels/cache_kernels.cu:

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <stdint.h>

#include <algorithm>

#include "cuda_compat.h"

namespace vllm {

template <typename scalar_t>
__global__ void reshape_and_cache_kernel(
    const scalar_t* __restrict__ key,          // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,        // [num_tokens, num_heads, head_size]
    scalar_t* __restrict__ key_cache,          // [num_blocks, num_heads, head_size/x, block_size, x]
    scalar_t* __restrict__ value_cache,        // [num_blocks, num_heads, head_size, block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride,
    const int value_stride,
    const int num_heads,
    const int head_size,
    const int block_size,
    const int x) {
  // One thread block per token; threads cooperatively copy that token's
  // num_heads * head_size key/value elements into the paged caches.
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    // The key cache splits the head dimension into chunks of x elements so
    // that consecutive tokens of a chunk are contiguous in memory.
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t tgt_key_idx =
        block_idx * num_heads * (head_size / x) * block_size * x
        + head_idx * (head_size / x) * block_size * x
        + x_idx * block_size * x
        + block_offset * x
        + x_offset;
    const int64_t tgt_value_idx =
        block_idx * num_heads * head_size * block_size
        + head_idx * head_size * block_size
        + head_offset * block_size
        + block_offset;
    key_cache[tgt_key_idx] = key[src_key_idx];
    value_cache[tgt_value_idx] = value[src_value_idx];
  }
}

} // namespace vllm

#define CALL_RESHAPE_AND_CACHE(T)                                \
  vllm::reshape_and_cache_kernel<T><<<grid, block, 0, stream>>>( \
      reinterpret_cast<T*>(key),                                 \
      reinterpret_cast<T*>(value),                               \
      reinterpret_cast<T*>(key_cache),                           \
      reinterpret_cast<T*>(value_cache),                         \
      slot_mapping,                                              \
      key_stride,                                                \
      value_stride,                                              \
      num_heads,                                                 \
      head_size,                                                 \
      block_size,                                                \
      x);

extern "C" void reshape_and_cache(
    void* key,              // [num_tokens, num_heads, head_size]
    void* value,            // [num_tokens, num_heads, head_size]
    void* key_cache,        // [num_blocks, num_heads, head_size/x, block_size, x]
    void* value_cache,      // [num_blocks, num_heads, head_size, block_size]
    int64_t* slot_mapping,  // [num_tokens]
    int32_t num_tokens,
    int32_t num_heads,
    int32_t head_size,
    int32_t block_size,
    int32_t x,
    int32_t key_stride,
    int32_t value_stride,
    uint32_t dtype) {  // 0 => f16; 1 => bf16; 2 => f32
  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const cudaStream_t stream = 0;  // launch on the default stream
  if (dtype == 0) {
    // fp16 elements are moved as raw 16-bit words; no conversion is needed,
    // so a plain uint16_t instantiation suffices.
    CALL_RESHAPE_AND_CACHE(uint16_t);
  } else if (dtype == 1) {
    CALL_RESHAPE_AND_CACHE(__nv_bfloat16);
  } else if (dtype == 2) {
    CALL_RESHAPE_AND_CACHE(float);
  }
}
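
For unit-testing the key-cache layout, it can help to have a host-side mirror of the kernel's index arithmetic. The following is a minimal sketch, not part of the original file; the helper name `ref_key_cache_index` is hypothetical, but the arithmetic deliberately mirrors `tgt_key_idx` in `reshape_and_cache_kernel` above.

// Hypothetical host-side reference of the key-cache index math (for tests).
#include <cstdint>

// Flat offset of (block_idx, head_idx, head_offset, block_offset) in a
// [num_blocks, num_heads, head_size/x, block_size, x] key cache.
static inline int64_t ref_key_cache_index(int64_t block_idx, int head_idx,
                                          int head_offset, int64_t block_offset,
                                          int num_heads, int head_size,
                                          int block_size, int x) {
  const int x_idx = head_offset / x;     // which x-wide chunk of the head dim
  const int x_offset = head_offset % x;  // position inside that chunk
  return block_idx * num_heads * (head_size / x) * block_size * x
       + head_idx * (head_size / x) * block_size * x
       + x_idx * block_size * x
       + block_offset * x
       + x_offset;
}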
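
To show how the extern "C" entry point might be driven from the host, here is a minimal usage sketch, again not part of the original file. The shape values, the choice of fp32 (dtype == 2), and the test data are illustrative assumptions; it assumes the program is linked against the compiled kernel above.

// Hypothetical host-side usage sketch for reshape_and_cache (assumed shapes).
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <vector>

extern "C" void reshape_and_cache(
    void* key, void* value, void* key_cache, void* value_cache,
    int64_t* slot_mapping, int32_t num_tokens, int32_t num_heads,
    int32_t head_size, int32_t block_size, int32_t x,
    int32_t key_stride, int32_t value_stride, uint32_t dtype);

int main() {
  // Illustrative shapes; nothing here is mandated by the API.
  const int32_t num_tokens = 4, num_heads = 8, head_size = 64;
  const int32_t block_size = 16, x = 8, num_blocks = 2;
  // Contiguous [num_tokens, num_heads, head_size] inputs: the per-token
  // stride is num_heads * head_size elements.
  const int32_t key_stride = num_heads * head_size;
  const int32_t value_stride = num_heads * head_size;

  const size_t kv_elems = (size_t)num_tokens * num_heads * head_size;
  const size_t cache_elems =
      (size_t)num_blocks * num_heads * head_size * block_size;

  float *d_key, *d_value, *d_key_cache, *d_value_cache;
  int64_t* d_slot_mapping;
  cudaMalloc(&d_key, kv_elems * sizeof(float));
  cudaMalloc(&d_value, kv_elems * sizeof(float));
  cudaMalloc(&d_key_cache, cache_elems * sizeof(float));
  cudaMalloc(&d_value_cache, cache_elems * sizeof(float));
  cudaMalloc(&d_slot_mapping, num_tokens * sizeof(int64_t));

  std::vector<float> h_key(kv_elems), h_value(kv_elems);
  for (size_t i = 0; i < kv_elems; ++i) {
    h_key[i] = (float)i;
    h_value[i] = -(float)i;
  }
  // Map token t to slot t; a slot of -1 would mark a padding token to skip.
  std::vector<int64_t> h_slots = {0, 1, 2, 3};

  cudaMemcpy(d_key, h_key.data(), kv_elems * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_value, h_value.data(), kv_elems * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_slot_mapping, h_slots.data(), num_tokens * sizeof(int64_t),
             cudaMemcpyHostToDevice);

  reshape_and_cache(d_key, d_value, d_key_cache, d_value_cache, d_slot_mapping,
                    num_tokens, num_heads, head_size, block_size, x,
                    key_stride, value_stride, /*dtype=*/2 /* f32 */);
  cudaDeviceSynchronize();
  printf("reshape_and_cache: %s\n", cudaGetErrorString(cudaGetLastError()));

  cudaFree(d_key);
  cudaFree(d_value);
  cudaFree(d_key_cache);
  cudaFree(d_value_cache);
  cudaFree(d_slot_mapping);
  return 0;
}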