// candle-rotary/kernels/rotary.cu
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <stdint.h>
#include <algorithm>     // std::min, used in the launcher below
#include "cuda_compat.h" // provides VLLM_LDG (read-only cached load)
namespace vllm {
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
  scalar_t* __restrict__ arr,
  const scalar_t* __restrict__ cos_ptr,
  const scalar_t* __restrict__ sin_ptr,
  int rot_offset,
  int rot_dim)
{
  // Note: `rot_dim` here is the number of (cos, sin) pairs per token, i.e.
  // half of the rotated span; the pairs below cover indices [0, 2 * rot_dim).
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
    // GPT-NeoX style rotary embedding: x in the first half, y in the second.
    x_index = rot_offset;
    y_index = rot_dim + rot_offset;
    cos = VLLM_LDG(cos_ptr + x_index);
    sin = VLLM_LDG(sin_ptr + x_index);
  } else {
    // GPT-J style rotary embedding: (x, y) are adjacent elements.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = VLLM_LDG(cos_ptr + x_index / 2);
    sin = VLLM_LDG(sin_ptr + x_index / 2);
  }

  // 2D rotation of the (x, y) pair by the cached angle.
  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * cos - y * sin;
  arr[y_index] = y * cos + x * sin;
}
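
// Worked example of the pairing above (comments only, added for clarity):
// with rot_dim = 4 (the half dimension), the GPT-NeoX layout rotates the
// pairs (0,4), (1,5), (2,6), (3,7), while the GPT-J layout rotates (0,1),
// (2,3), (4,5), (6,7). Each pair (x, y) is rotated by the angle cached for
// this position:
//   x' = x * cos(theta) - y * sin(theta)
//   y' = y * cos(theta) + x * sin(theta)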
template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
  scalar_t* __restrict__ query,            // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,              // [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_cache,  // [num_tokens, rot_dim]
  const scalar_t* __restrict__ sin_cache,  // [num_tokens, rot_dim]
  const int rot_dim,
  const int64_t query_stride,
  const int64_t key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  // Widen to int64_t before multiplying so large token counts cannot
  // overflow 32-bit arithmetic.
  const scalar_t* cos_ptr = cos_cache + static_cast<int64_t>(token_idx) * rot_dim;
  const scalar_t* sin_ptr = sin_cache + static_cast<int64_t>(token_idx) * rot_dim;

  // Rotate the query heads: one thread per (head, rotation offset) pair.
  const int nq = num_heads * rot_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / rot_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % rot_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                              sin_ptr, rot_offset, rot_dim);
  }

  // Rotate the key heads, which may be fewer than the query heads (GQA/MQA).
  const int nk = num_kv_heads * rot_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / rot_dim;
    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % rot_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                              sin_ptr, rot_offset, rot_dim);
  }
}
} // namespace vllm
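
// ---------------------------------------------------------------------------
// Illustrative CPU reference (not part of the original kernel): a minimal
// float-only sketch of the same rotation, useful for unit-testing the GPU
// path. The function name `rotary_embedding_cpu_ref` and its layout
// assumptions (contiguous [num_tokens, num_heads, head_size] query, the same
// half-dimension `rot_dim` convention as above) are ours, not part of the
// candle-rotary API; keys would be handled identically with num_kv_heads.
static void rotary_embedding_cpu_ref(
    float* query,            // [num_tokens, num_heads, head_size], contiguous
    const float* cos_cache,  // [num_tokens, rot_dim]
    const float* sin_cache,  // [num_tokens, rot_dim]
    bool is_neox,
    int num_tokens, int num_heads, int head_size, int rot_dim) {
  for (int t = 0; t < num_tokens; ++t) {
    const float* cos_ptr = cos_cache + (int64_t)t * rot_dim;
    const float* sin_ptr = sin_cache + (int64_t)t * rot_dim;
    for (int h = 0; h < num_heads; ++h) {
      float* head = query + ((int64_t)t * num_heads + h) * head_size;
      for (int i = 0; i < rot_dim; ++i) {
        // Same index pairing as apply_rotary_embedding above.
        const int xi = is_neox ? i : 2 * i;
        const int yi = is_neox ? rot_dim + i : 2 * i + 1;
        const float x = head[xi], y = head[yi];
        head[xi] = x * cos_ptr[i] - y * sin_ptr[i];
        head[yi] = y * cos_ptr[i] + x * sin_ptr[i];
      }
    }
  }
}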
#define CALL_ROTARY(T, IS_NEOX)                                           \
  vllm::rotary_embedding_kernel<T, IS_NEOX><<<grid, block, 0, stream>>>(  \
    reinterpret_cast<T*>(query),                                          \
    reinterpret_cast<T*>(key),                                            \
    reinterpret_cast<T*>(cos_cache),                                      \
    reinterpret_cast<T*>(sin_cache),                                      \
    rot_dim,                                                              \
    query_stride,                                                         \
    key_stride,                                                           \
    num_heads,                                                            \
    num_kv_heads,                                                         \
    head_size);
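
// CALL_ROTARY(half, true), for example, expands to a <<<grid, block>>> launch
// of vllm::rotary_embedding_kernel<half, true> on `stream`, picking up the
// surrounding launcher's local variables by name.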
extern "C" void rotary_embedding(
void *query, // [num_tokens, num_heads, head_size]
void *key, // [num_tokens, num_kv_heads, head_size]
void *cos_cache, // [num_tokens, rot_dim]
void *sin_cache, // [num_tokens, rot_dim]
int32_t is_neox,
int32_t head_size,
int64_t num_tokens,
int32_t rot_dim,
int32_t num_heads,
int32_t num_kv_heads,
int64_t query_stride,
int64_t key_stride,
uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32
) {
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim, 512));
const cudaStream_t stream = 0;
const bool is_neox_bool = is_neox;
if (is_neox_bool) {
if (dtype == 0){
CALL_ROTARY(half, true);
} else if (dtype == 1) {
CALL_ROTARY(__nv_bfloat16, true);
} else if (dtype == 2) {
CALL_ROTARY(float, true);
}
} else {
if (dtype == 0){
CALL_ROTARY(half, false);
} else if (dtype == 1) {
CALL_ROTARY(__nv_bfloat16, false);
} else if (dtype == 2) {
CALL_ROTARY(float, false);
}
}
}
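
// ---------------------------------------------------------------------------
// Example host-side usage (a sketch, not part of the original file). Guarded
// behind a hypothetical ROTARY_EXAMPLE_MAIN flag so the translation unit
// still builds as a plain kernel library by default. Assumes f32 (dtype == 2),
// contiguous q/k tensors, and the half-dimension rot_dim convention above.
#ifdef ROTARY_EXAMPLE_MAIN
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int num_tokens = 4, num_heads = 2, num_kv_heads = 2;
  const int head_size = 8, rot_dim = head_size / 2;  // rot_dim = half dim

  std::vector<float> q(num_tokens * num_heads * head_size, 1.0f);
  std::vector<float> k(num_tokens * num_kv_heads * head_size, 1.0f);
  std::vector<float> cs(num_tokens * rot_dim), sn(num_tokens * rot_dim);
  // Standard RoPE angles: theta = pos / 10000^(2i / head_size).
  for (int t = 0; t < num_tokens; ++t)
    for (int i = 0; i < rot_dim; ++i) {
      const float theta = t / std::pow(10000.0f, 2.0f * i / head_size);
      cs[t * rot_dim + i] = std::cos(theta);
      sn[t * rot_dim + i] = std::sin(theta);
    }

  float *dq, *dk, *dc, *ds;
  cudaMalloc(&dq, q.size() * sizeof(float));
  cudaMalloc(&dk, k.size() * sizeof(float));
  cudaMalloc(&dc, cs.size() * sizeof(float));
  cudaMalloc(&ds, sn.size() * sizeof(float));
  cudaMemcpy(dq, q.data(), q.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dk, k.data(), k.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dc, cs.data(), cs.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(ds, sn.data(), sn.size() * sizeof(float), cudaMemcpyHostToDevice);

  // Per-token strides in elements, since q/k are contiguous here.
  rotary_embedding(dq, dk, dc, ds,
                   /*is_neox=*/1, head_size, num_tokens, rot_dim,
                   num_heads, num_kv_heads,
                   /*query_stride=*/num_heads * head_size,
                   /*key_stride=*/num_kv_heads * head_size,
                   /*dtype=*/2);
  cudaDeviceSynchronize();

  cudaMemcpy(q.data(), dq, q.size() * sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("q[0..3] after RoPE: %f %f %f %f\n", q[0], q[1], q[2], q[3]);
  cudaFree(dq); cudaFree(dk); cudaFree(dc); cudaFree(ds);
  return 0;
}
#endif  // ROTARY_EXAMPLE_MAIN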