// maga_transformer/cpp/kernels/unfused_attention_kernels.cu
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "maga_transformer/cpp/utils/utils.h"
#include "maga_transformer/cpp/kernels/kv_cache/kv_cache_utils.h"
#include "maga_transformer/cpp/cuda/reduce_kernel_utils.cuh"
#include "maga_transformer/cpp/kernels/rotary_position_embedding.h"
#include "maga_transformer/cpp/kernels/unfused_attention_kernels.h"
#include "maga_transformer/cpp/cuda/cuda_type_utils.cuh"
#if USING_CUDA
#include "maga_transformer/cpp/cuda/cuda_utils.h"
#endif
#if USING_ROCM
#include "maga_transformer/cpp/rocm/hip_utils.h"
#endif
#include <cstdlib>
namespace rtp_llm {
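// Maps the index (id1, id2, id3, id4) of a [dim_1, dim_2, dim_3, dim_4] tensor to its offset
// in the permuted [dim_1, dim_3, dim_2, dim_4] layout.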
__inline__ __device__ int target_index(int id1, int id2, int id3, int id4, int dim_1, int dim_2, int dim_3, int dim_4)
{
return id1 * (dim_2 * dim_3 * dim_4) + id3 * (dim_2 * dim_4) + id2 * dim_4 + id4;
}
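// Adds the Q/K/V biases, optionally applies IA3 key/value scaling, and transposes each tensor
// from [batch_size, seq_len, head_num, size_per_head] to [batch_size, head_num, seq_len, size_per_head].
// grid(batch_size, seq_len); threads stride over head_num * size_per_head.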
template<typename T>
__global__ void addQKVBiasIA3Transpose(T* q_out,
T* k_out,
T* v_out,
const T* __restrict q_in,
const T* __restrict bias_q,
const T* __restrict k_in,
const T* __restrict bias_k,
const T* __restrict v_in,
const T* __restrict bias_v,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head)
{
const int n = head_num * size_per_head;
const int batch_id = blockIdx.x;
const int word_id = blockIdx.y;
const int row_id = batch_id * seq_len + word_id;
const bool use_ia3 = ia3_tasks != nullptr;
const int ia3_task = use_ia3 ? ia3_tasks[batch_id] : 0;
const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr);
const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr);
for (int col_id = threadIdx.x; col_id < n; col_id += blockDim.x) {
const int head_id = col_id / size_per_head;
const int size_id = col_id % size_per_head;
const int target_id = batch_id * (head_num * seq_len * size_per_head) + head_id * seq_len * size_per_head
+ word_id * size_per_head + size_id;
const int src_id = row_id * n + col_id;
T q = ldg(&q_in[src_id]);
q_out[target_id] = add(q, ldg(&bias_q[col_id]));
T k = add(ldg(&k_in[src_id]), ldg(&bias_k[col_id]));
if (use_ia3_key) {
k = k * ia3_key_weights[ia3_task * n + col_id];
}
k_out[target_id] = k;
T v = add(ldg(&v_in[src_id]), ldg(&bias_v[col_id]));
if (use_ia3_value) {
v = v * ia3_value_weights[ia3_task * n + col_id];
}
v_out[target_id] = v;
}
}
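// Single-thread kernel that writes the minimum of prefix_lengths[0..batch_size) to *skip_length.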
__global__ void getSkipLength(int* skip_length, int* prefix_lengths, int batch_size)
{
int min_skip_length = prefix_lengths[0];
for (int i = 1; i < batch_size; i++) {
if (min_skip_length > prefix_lengths[i]) {
min_skip_length = prefix_lengths[i];
}
}
*skip_length = min_skip_length;
}
__global__ void float_to_half_kernel(const float* input, half* output, int size)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < size) {
output[idx] = __float2half(input[idx]);
}
}
void float_to_half(const void* input, void* output, int size)
{
const float* float_input = reinterpret_cast<const float*>(input);
half* half_output = reinterpret_cast<half*>(output);
constexpr int THREADS_PER_BLOCK = 256;
int n_blocks = (size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
float_to_half_kernel<<<n_blocks, THREADS_PER_BLOCK>>>(float_input, half_output, size);
cudaDeviceSynchronize();
}
__global__ void
half_to_float_kernel(const __half* __restrict__ input, float* __restrict__ output, const int num_elements)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_elements) {
output[idx] = __half2float(input[idx]);
}
}
void half_to_float(const void* input, void* output, const int num_elements)
{
const half* half_input = reinterpret_cast<const half*>(input);
float* float_output = reinterpret_cast<float*>(output);
const int blockSize = 256;
const int gridSize = (num_elements + blockSize - 1) / blockSize;
half_to_float_kernel<<<gridSize, blockSize>>>(half_input, float_output, num_elements);
cudaDeviceSynchronize();
}
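// Same as addQKVBiasIA3Transpose, but without the bias addition (used when no QKV bias is provided).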
template<typename T>
__global__ void QKVIA3Transpose(T* q_out,
T* k_out,
T* v_out,
const T* __restrict q_in,
const T* __restrict k_in,
const T* __restrict v_in,
const int* ia3_tasks,
const T* __restrict ia3_key_weights,
const T* __restrict ia3_value_weights,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head)
{
const int n = head_num * size_per_head;
const int batch_id = blockIdx.x;
const int word_id = blockIdx.y;
const int row_id = batch_id * seq_len + word_id;
const bool use_ia3 = ia3_tasks != nullptr;
const int ia3_task = use_ia3 ? ia3_tasks[batch_id] : 0;
const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr);
const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr);
for (int col_id = threadIdx.x; col_id < n; col_id += blockDim.x) {
const int head_id = col_id / size_per_head;
const int size_id = col_id % size_per_head;
const int target_id = batch_id * (head_num * seq_len * size_per_head) + head_id * seq_len * size_per_head
+ word_id * size_per_head + size_id;
const int src_id = row_id * n + col_id;
q_out[target_id] = ldg(&q_in[src_id]);
T k = ldg(&k_in[src_id]);
if (use_ia3_key) {
k = k * ia3_key_weights[ia3_task * n + col_id];
}
k_out[target_id] = k;
T v = ldg(&v_in[src_id]);
if (use_ia3_value) {
v = v * ia3_value_weights[ia3_task * n + col_id];
}
v_out[target_id] = v;
}
}
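// Dispatch wrapper: picks the half2/bf162-vectorized kernel when T is a 16-bit type and the hidden
// size is even, and the biased or bias-free variant depending on whether bias_Q is set.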
template<typename T>
void invokeAddQKVBiasIA3Transpose(T* q_buf,
T* k_buf,
T* v_buf,
T* Q,
const T* bias_Q,
T* K,
const T* bias_K,
T* V,
const T* bias_V,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
cudaStream_t stream)
{
const int k = head_num * size_per_head;
dim3 grid(batch_size, seq_len);
bool is_add_bias = bias_Q != nullptr;
if (sizeof(T) == 4 || k % 2 != 0) {
dim3 block(min(k, 512));
if (is_add_bias) {
addQKVBiasIA3Transpose<T><<<grid, block, 0, stream>>>(q_buf,
k_buf,
v_buf,
Q,
bias_Q,
K,
bias_K,
V,
bias_V,
ia3_tasks,
ia3_key_weights,
ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head);
}
else {
QKVIA3Transpose<T><<<grid, block, 0, stream>>>(q_buf,
k_buf,
v_buf,
Q,
K,
V,
ia3_tasks,
ia3_key_weights,
ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head);
}
sync_check_cuda_error();
}
else {
using T2 = typename TypeConverter<T>::Type; // fp16 to half2, bf16 to bf162
dim3 block(min(k / 2, 512));
if (is_add_bias) {
addQKVBiasIA3Transpose<T2><<<grid, block, 0, stream>>>((T2*)q_buf,
(T2*)k_buf,
(T2*)v_buf,
(const T2*)Q,
(const T2*)bias_Q,
(const T2*)K,
(const T2*)bias_K,
(const T2*)V,
(const T2*)bias_V,
ia3_tasks,
(const T2*)ia3_key_weights,
(const T2*)ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head / 2);
}
else {
QKVIA3Transpose<T2><<<grid, block, 0, stream>>>((T2*)q_buf,
(T2*)k_buf,
(T2*)v_buf,
(const T2*)Q,
(const T2*)K,
(const T2*)V,
ia3_tasks,
(const T2*)ia3_key_weights,
(const T2*)ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head / 2);
}
sync_check_cuda_error();
}
}
#define INSTANTIATEADDQKVBIASIA3TRANSPOSE(T) \
template void invokeAddQKVBiasIA3Transpose(T* q_buf, \
T* k_buf, \
T* v_buf, \
T* Q, \
const T* bias_Q, \
T* K, \
const T* bias_K, \
T* V, \
const T* bias_V, \
const int batch_size, \
const int seq_len, \
const int head_num, \
const int size_per_head, \
const int* ia3_tasks, \
const T* ia3_key_weights, \
const T* ia3_value_weights, \
cudaStream_t stream)
INSTANTIATEADDQKVBIASIA3TRANSPOSE(float);
INSTANTIATEADDQKVBIASIA3TRANSPOSE(half);
#ifdef ENABLE_BF16
INSTANTIATEADDQKVBIASIA3TRANSPOSE(__nv_bfloat16);
#endif
#undef INSTANTIATEADDQKVBIASIA3TRANSPOSE
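// Masked softmax over the attention logits:
//   attn_score[bi, hi, qi, ki] = softmax_ki(qk_scale * qk[bi, hi, qi, ki]
//                                           - |slope[hi] * (ki - qi)|
//                                           + (1 - mask[bi, qi, ki]) * -10000)
// where the |slope * (ki - qi)| term is applied only when linear_bias_slopes (ALiBi) is given.
// One block handles one (batch, head, query) row; each thread accumulates up to ITEMS_PER_THREAD
// elements along the key dimension.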
template<typename T, typename T_IN, int ITEMS_PER_THREAD>
__global__ void softmax_kernel(T* attn_score,
const T_IN* qk,
const T* attn_mask,
const float* linear_bias_slopes,
const int batch_size,
const int head_num,
const int q_length,
const int k_length,
const float qk_scale)
{
// attn_score, [batch_size, num_heads, q_length, k_length]
// qk, [batch_size, num_heads, q_length, k_length]
// attn_mask, [batch_size, q_length, k_length]
// linear_bias_slopes, [num_heads]
const int bi = blockIdx.y; // Batch index.
const int hi = blockIdx.z; // Head index.
__shared__ float s_mean, s_max;
const float linear_bias_slope = linear_bias_slopes != nullptr ? (float)linear_bias_slopes[hi] : 0.0f;
// Loop along with Q dimension.
for (int qi = blockIdx.x; qi < q_length; qi += gridDim.x) {
float data[ITEMS_PER_THREAD];
int64_t qk_offset;
float local_max = -1e20f;
// Loop along with K dimension.
for (int i = 0; blockDim.x * i + threadIdx.x < k_length; i++) {
int ki = blockDim.x * i + threadIdx.x; // Index of K dimension.
qk_offset = ((bi * head_num + hi) * q_length + qi) * static_cast<int64_t>(k_length) + ki;
float qk_val = static_cast<float>(qk[qk_offset]);
float qk_bias = 0.0f;
if (linear_bias_slopes != nullptr) {
// We don't handle the upper diagonal (ki > qi) separately; its values
// are negligible due to the negative-infinity mask. This matches the
// HF implementation.
qk_bias -= static_cast<float>(abs(linear_bias_slope * (ki - qi)));
}
int mask_offset = (bi * q_length + qi) * k_length + ki;
float mask_val = static_cast<float>(ldg(&attn_mask[mask_offset]));
qk_bias += (1.0f - mask_val) * -10000.0f;
data[i] = qk_scale * qk_val + qk_bias;
local_max = fmax(local_max, data[i]);
}
float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax<float>(local_max);
if (threadIdx.x == 0) {
s_max = max_val;
}
__syncthreads();
float local_sum = 0;
for (int i = 0; blockDim.x * i + threadIdx.x < k_length; i++) {
data[i] = __expf(data[i] - s_max);
local_sum += data[i];
}
float sum_val = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum<float>(local_sum);
if (threadIdx.x == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
for (int i = 0; blockDim.x * i + threadIdx.x < k_length; i++) {
qk_offset = ((bi * head_num + hi) * q_length + qi) * k_length + blockDim.x * i + threadIdx.x;
attn_score[qk_offset] = (T)(data[i] * s_mean);
}
}
}
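// Half2-vectorized variant of softmax_kernel for 16-bit types with an even k_length: each thread
// processes two adjacent key positions per iteration through T2 (half2/bf162) loads and stores.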
template<typename T, int ITEMS_PER_THREAD>
__global__ void softmax_kernel_h2(T* attn_score,
const T* qk_buf,
const T* attn_mask,
const float* linear_bias_slopes,
const int batch_size,
const int head_num,
const int q_length,
const int k_length,
const T qk_scale)
{
// attn_score, [batch_size, num_heads, q_length, k_length]
// qk, [batch_size, num_heads, q_length, k_length]
// attn_mask, [batch_size, q_length, k_length]
// linear_bias_slopes, [num_heads]
using T2 = typename TypeConverter<T>::Type;
T2* attn_score_h2 = reinterpret_cast<T2*>(attn_score);
const T2* qk_buf_h2 = reinterpret_cast<const T2*>(qk_buf);
const T2* attn_mask_h2 = reinterpret_cast<const T2*>(attn_mask);
const int bi = blockIdx.y; // Batch index
const int hi = blockIdx.z; // Head index.
__shared__ float s_mean, s_max;
// Constant values that will be used repeatedly in the q/k loop.
const T2 ONE = cuda_cast<T2>(1.0f);
const T2 ZERO = cuda_cast<T2>(0.0f);
const T2 NEG_INFTY = cuda_cast<T2>(-10000.0f);
// The normalization factor of QK.
const T2 qk_scale_h2 = cuda_cast<T2>(qk_scale);
// The slope of a linear position bias of the current attention head.
const T2 linear_bias_slope = linear_bias_slopes != nullptr ? cuda_cast<T2>(linear_bias_slopes[hi]) : ZERO;
// Loop over q dimension.
for (int qi = blockIdx.x; qi < q_length; qi += gridDim.x) {
T2 data[ITEMS_PER_THREAD];
int64_t qk_offset;
float local_max = -1e20f;
// Loop over k dimension.
for (int i = 0; blockDim.x * i + threadIdx.x < (k_length / 2) && i < ITEMS_PER_THREAD; i++) {
// Half of the k-dimension index; this thread uses the elements at {2 * ki, 2 * ki + 1}.
int ki = blockDim.x * i + threadIdx.x;
qk_offset = ((bi * head_num + hi) * q_length + qi) * static_cast<int64_t>(k_length / 2) + ki;
int mask_offset = (bi * q_length + qi) * (k_length / 2) + ki;
// The value of QK^T matrix at (qi, ki).
T2 qk = qk_buf_h2[qk_offset];
// The bias value to the position (qi, ki) including both mask and positional bias.
T2 qk_bias = ZERO;
if (linear_bias_slopes != nullptr) {
// The position bias depends on the distance between qi and ki and is zero if qi >= 2*ki
// or qi >= 2*ki+1. For T2 vectorization, we handle every two elements along the k dimension
// simultaneously. To do this, we check qi / 2 > ki at once instead of qi >= 2*ki or
// qi >= 2*ki+1. This works because the diagonal element for an odd qi will be zero due to
// slope * (qi - (2*ki+1)) = 0. Thus, we don't handle the upper diagonal separately; its
// values are negligible due to the negative-infinity mask.
T2 dist(2.0f * ki - qi, 2.0f * ki + 1 - qi);
qk_bias = hadd2<T2>(qk_bias, -cuda_abs(hmul2<T2>(linear_bias_slope, dist)));
}
T2 mask_val = ldg(&attn_mask_h2[mask_offset]);
qk_bias = hadd2<T2>(qk_bias, hmul2<T2>(hsub2<T2>(ONE, mask_val), NEG_INFTY));
data[i] = hadd2<T2>(hmul2<T2>(qk, qk_scale_h2), qk_bias);
local_max = fmax(local_max, fmax((float)data[i].x, (float)data[i].y));
}
float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax<float>(local_max);
if (threadIdx.x == 0) {
s_max = max_val;
}
__syncthreads();
float local_sum = 0.0f;
for (int i = 0; blockDim.x * i + threadIdx.x < (k_length / 2) && i < ITEMS_PER_THREAD; i++) {
data[i] = hexp2<T2>(hsub2<T2>(data[i], cuda_cast<T2>(s_max)));
local_sum += (float)(data[i].x + data[i].y);
}
float sum_val = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum<float>(local_sum);
if (threadIdx.x == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
for (int i = 0; blockDim.x * i + threadIdx.x < (k_length / 2) && i < ITEMS_PER_THREAD; i++) {
qk_offset = ((bi * head_num + hi) * q_length + qi) * (k_length / 2) + blockDim.x * i + threadIdx.x;
attn_score_h2[qk_offset] = hmul2<T2>(data[i], cuda_cast<T2>(s_mean));
}
}
}
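// Further-vectorized variant that additionally processes Q_ITEMS_PER_THREAD query rows per block
// (selected when grid.x is divisible by 4) so the reductions are amortized across several rows.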
template<typename T, int K_ITEMS_PER_THREAD, int Q_ITEMS_PER_THREAD>
__global__ void softmax_kernel_h2_v2(T* attn_score,
const T* qk_buf,
const T* attn_mask,
const float* linear_bias_slopes,
const int batch_size,
const int head_num,
const int q_length,
const int k_length,
const T scalar)
{
// attn_score, [batch_size, num_heads, q_length, k_length]
// qk, [batch_size, num_heads, q_length, k_length]
// attn_mask, [batch_size, q_length, k_length]
// linear_bias_slopes, [num_heads]
using T2 = typename TypeConverter<T>::Type;
// QK^T matrix of shape (batch_size, head_num, q_length, k_length / 2)
T2* attn_score_h2 = reinterpret_cast<T2*>(attn_score);
const T2* qk_buf_h2 = reinterpret_cast<const T2*>(qk_buf);
const T2* attn_mask_h2 = reinterpret_cast<const T2*>(attn_mask);
const int bi = blockIdx.y; // Batch index
const int hi = blockIdx.z; // Head index.
// Constant values that will be used repeatedly in the q/k loop.
const T2 ONE = cuda_cast<T2>(1.0f);
const T2 ZERO = cuda_cast<T2>(0.0f);
const T2 NEG_INFTY = cuda_cast<T2>(-10000.0f);
// The normalization factor of QK.
const T2 qk_scale = cuda_cast<T2>(scalar);
// The slope of a linear position bias of the current attention head.
const T2 linear_bias_slope = linear_bias_slopes != nullptr ? cuda_cast<T2>(linear_bias_slopes[hi]) : ZERO;
__shared__ float s_sum[Q_ITEMS_PER_THREAD], s_max[Q_ITEMS_PER_THREAD];
// Loop over q dimension.
for (int qi = blockIdx.x; qi < q_length; qi += gridDim.x * Q_ITEMS_PER_THREAD) {
T2 data[Q_ITEMS_PER_THREAD][K_ITEMS_PER_THREAD];
int64_t qk_offset[Q_ITEMS_PER_THREAD];
float local_max[Q_ITEMS_PER_THREAD];
#pragma unroll
for (int j = 0; j < Q_ITEMS_PER_THREAD; j++) {
local_max[j] = -1e20f;
}
// Loop over k dimension.
const int Q_ITEMS = min((q_length - qi + gridDim.x - 1) / gridDim.x, Q_ITEMS_PER_THREAD);
for (int i = 0; blockDim.x * i + threadIdx.x < k_length / 2 && i < K_ITEMS_PER_THREAD; ++i) {
// Half of the k-dimension index; this thread uses the elements at {2 * ki, 2 * ki + 1}.
int ki = blockDim.x * i + threadIdx.x;
int mask_offset[Q_ITEMS_PER_THREAD];
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
qk_offset[j] =
((bi * head_num + hi) * q_length + qi + j * gridDim.x) * static_cast<int64_t>(k_length / 2) + ki;
mask_offset[j] = (bi * q_length + qi + j * gridDim.x) * (k_length / 2) + ki;
}
T2 mask_val[Q_ITEMS_PER_THREAD];
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
mask_val[j] = ldg(&attn_mask_h2[mask_offset[j]]);
}
T2 qk[Q_ITEMS_PER_THREAD];
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
qk[j] = qk_buf_h2[qk_offset[j]];
}
T2 pos_bias[Q_ITEMS_PER_THREAD];
if (linear_bias_slopes != nullptr) {
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
// The position bias depends on the distance between qi and ki and is zero if qi >= 2*ki
// or qi >= 2*ki+1. For T2 vectorization, we handle every two elements along the k dimension
// simultaneously. To do this, we check qi / 2 > ki at once instead of qi >= 2*ki or
// qi >= 2*ki+1. This works because the diagonal element for an odd qi will be zero due to
// slope * (qi - (2*ki+1)) = 0. Thus, we don't handle the upper diagonal separately; its
// values are negligible due to the negative-infinity mask.
int qidx = qi + j * gridDim.x;
T2 dist(2.0f * ki - qidx, 2.0f * ki + 1 - qidx);
pos_bias[j] = -cuda_abs(hmul2<T2>(linear_bias_slope, dist));
}
}
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
mask_val[j] = hmul2<T2>(hsub2<T2>(ONE, mask_val[j]), NEG_INFTY);
}
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
T2 val = hadd2<T2>(hmul2<T2>(qk_scale, qk[j]), mask_val[j]);
if (linear_bias_slopes != nullptr) {
val = hadd2<T2>(val, pos_bias[j]);
}
data[j][i] = val;
local_max[j] = fmax(local_max[j], fmax((float)data[j][i].x, (float)data[j][i].y));
}
}
if (blockDim.x <= 32) {
warpReduceMaxV2<float, Q_ITEMS_PER_THREAD>(local_max);
}
else {
blockReduceMaxV2<float, Q_ITEMS_PER_THREAD>(local_max);
}
if (threadIdx.x == 0) {
#pragma unroll
for (int j = 0; j < Q_ITEMS_PER_THREAD; j++) {
s_max[j] = local_max[j];
}
}
__syncthreads();
float local_sum[Q_ITEMS_PER_THREAD];
#pragma unroll
for (int j = 0; j < Q_ITEMS_PER_THREAD; j++) {
local_sum[j] = {0.f};
}
for (int i = 0; blockDim.x * i + threadIdx.x < k_length / 2 && i < K_ITEMS_PER_THREAD; ++i) {
#pragma unroll
for (int j = 0; j < Q_ITEMS; ++j) {
data[j][i] = hexp2<T2>(hsub2<T2>(data[j][i], cuda_cast<T2>(s_max[j])));
}
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
local_sum[j] += (float)(data[j][i].x + data[j][i].y);
}
}
if (blockDim.x <= 32) {
warpReduceSumV2<float, Q_ITEMS_PER_THREAD>(local_sum);
}
else {
blockReduceSumV2<float, Q_ITEMS_PER_THREAD>(local_sum);
}
if (threadIdx.x == 0) {
#pragma unroll
for (int j = 0; j < Q_ITEMS_PER_THREAD; j++) {
s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f);
}
}
__syncthreads();
for (int i = 0; blockDim.x * i + threadIdx.x < k_length / 2 && i < K_ITEMS_PER_THREAD; ++i) {
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
qk_offset[j] = ((bi * head_num + hi) * q_length + qi + j * gridDim.x) * (k_length / 2) + blockDim.x * i
+ threadIdx.x;
}
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
attn_score_h2[qk_offset[j]] = hmul2<T2>(data[j][i], cuda_cast<T2>(s_sum[j]));
}
}
}
}
#define LAUNCH_MAKSED_SOFTMAX_(T_, ITEMS_PER_THREAD) \
block.x /= ITEMS_PER_THREAD; \
block.x = (block.x + 31) / 32 * 32; \
assert(block.x <= 1024); \
if (is_half2) { \
if (grid.x % 4 == 0) { \
grid.x /= 4; \
softmax_kernel_h2_v2<T_, ITEMS_PER_THREAD, 4> \
<<<grid, block, 0, stream>>>((T_*)param.attention_score, \
(const T_*)param.qk, \
(const T_*)param.attention_mask, \
(const float*)param.linear_bias_slopes, \
param.batch_size, \
param.num_heads, \
param.q_length, \
param.k_length, \
(const T_)param.qk_scale); \
} \
else { \
softmax_kernel_h2<T_, ITEMS_PER_THREAD><<<grid, block, 0, stream>>>((T_*)param.attention_score, \
(const T_*)param.qk, \
(const T_*)param.attention_mask, \
(const float*)param.linear_bias_slopes, \
param.batch_size, \
param.num_heads, \
param.q_length, \
param.k_length, \
(const T_)param.qk_scale); \
} \
} \
else { \
softmax_kernel<T, T_IN, ITEMS_PER_THREAD><<<grid, block, 0, stream>>>(param.attention_score, \
param.qk, \
param.attention_mask, \
param.linear_bias_slopes, \
param.batch_size, \
param.num_heads, \
param.q_length, \
param.k_length, \
param.qk_scale); \
}
#define LAUNCH_MAKSED_SOFTMAX(ITEMS_PER_THREAD) LAUNCH_MAKSED_SOFTMAX_(half, ITEMS_PER_THREAD)
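// A minimal usage sketch for the dispatcher below (illustrative only; the buffers and their
// device allocation are assumed):
//   MaskedSoftmaxParam<half, half> param;
//   param.attention_score    = attn_score;  // [b, h, q_len, k_len], softmax output
//   param.qk                 = qk;          // [b, h, q_len, k_len], QK^T input
//   param.attention_mask     = mask;        // [b, q_len, k_len]
//   param.batch_size         = b;
//   param.num_heads          = h;
//   param.q_length           = q_len;
//   param.k_length           = k_len;
//   param.qk_scale           = 1.f / sqrtf((float)size_per_head);
//   param.linear_bias_slopes = nullptr;     // or a device pointer to [h] ALiBi slopes
//   invokeMaskedSoftmax(param, stream);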
template<typename T, typename T_IN>
void invokeMaskedSoftmax(MaskedSoftmaxParam<T, T_IN>& param, cudaStream_t stream)
{
// attention_score, (batch_size, head_num, q_length, k_length), softmax output.
// qk, (batch_size, head_num, q_length, k_length), QK^T.
// attention_mask, (batch_size, q_length, k_length), attention mask.
// linear_bias_slopes, (head_num,) the slopes of the linear position bias.
dim3 grid(param.q_length, param.batch_size, param.num_heads);
if (param.batch_size * param.num_heads > 360) {
grid.x = ceil(float(param.q_length) / 32.0f);
}
bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0;
dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32);
RTP_LLM_CHECK(param.k_length <= 32768);
if (block.x > 16384 && block.x <= 32768) {
LAUNCH_MAKSED_SOFTMAX(32)
}
else if (block.x > 8192) {
LAUNCH_MAKSED_SOFTMAX(16)
}
else if (block.x > 4096) {
LAUNCH_MAKSED_SOFTMAX(8)
}
else if (block.x > 2048) {
LAUNCH_MAKSED_SOFTMAX(4)
}
else if (block.x > 1024) {
LAUNCH_MAKSED_SOFTMAX(2)
}
else if (block.x > 0) {
LAUNCH_MAKSED_SOFTMAX(1)
}
}
template void invokeMaskedSoftmax(MaskedSoftmaxParam<float, float>& param, cudaStream_t stream);
template void invokeMaskedSoftmax(MaskedSoftmaxParam<half, float>& param, cudaStream_t stream);
template void invokeMaskedSoftmax(MaskedSoftmaxParam<half, half>& param, cudaStream_t stream);
#ifdef ENABLE_BF16
template<>
void invokeMaskedSoftmax(MaskedSoftmaxParam<__nv_bfloat16, float>& param, cudaStream_t stream)
{
// attention_score, (batch_size, head_num, q_length, k_length), softmax output.
// qk, (batch_size, head_num, q_length, k_length), QK^T.
// attention_mask, (batch_size, q_length, k_length), attention mask.
// linear_bias_slopes, (head_num,) the slopes of the linear position bias.
using T = __nv_bfloat16;
using T_IN = float;
dim3 grid(param.q_length, param.batch_size, param.num_heads);
if (param.batch_size * param.num_heads > 360) {
grid.x = ceil(float(param.q_length) / 32.0f);
}
bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0;
dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32);
RTP_LLM_CHECK(param.k_length <= 8192);
if (block.x > 4096 && block.x <= 8192) {
LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 8);
}
else if (block.x > 2048) {
LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 4);
}
else if (block.x > 1024) {
LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 2);
}
else if (block.x > 0) {
LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 1);
}
}
template<>
void invokeMaskedSoftmax(MaskedSoftmaxParam<__nv_bfloat16, __nv_bfloat16>& param, cudaStream_t stream)
{
// attention_score, (batch_size, head_num, q_length, k_length), softmax output.
// qk, (batch_size, head_num, q_length, k_length), QK^T.
// attention_mask, (batch_size, q_length, k_length), attention mask.
// linear_bias_slopes, (head_num,) the slopes of the linear position bias.
using T = __nv_bfloat16;
using T_IN = __nv_bfloat16;
dim3 grid(param.q_length, param.batch_size, param.num_heads);
if (param.batch_size * param.num_heads > 360) {
grid.x = ceil(float(param.q_length) / 32.0f);
}
bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0;
dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32);
RTP_LLM_CHECK(param.k_length <= 8192);
if (block.x > 4096 && block.x <= 8192) {
LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 8);
}
else if (block.x > 2048) {
LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 4);
}
else if (block.x > 1024) {
LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 2);
}
else if (block.x > 0) {
LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 1);
}
}
#endif
#undef LAUNCH_MAKSED_SOFTMAX
#undef LAUNCH_MAKSED_SOFTMAX_
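// Transposes the attention output from [batch, head, seq, size_per_head] to
// [batch, seq, head, size_per_head]; when int8_mode == 2 the result is additionally quantized
// to int8 with the given scale.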
template<typename T>
__global__ void transpose(const T* src,
T* dst,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const float* scale,
int int8_mode)
{
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int batch_id = tid / (head_num * seq_len * size_per_head);
int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head);
int seq_id = (tid % (seq_len * size_per_head)) / size_per_head;
int id = tid % size_per_head;
int target_id = target_index(batch_id, head_id, seq_id, id, batch_size, head_num, seq_len, size_per_head);
if (int8_mode == 2) {
using Int8_Packed_T = typename packed_as<int8_t, num_elems<T>::value>::type;
using Float_Packed_T = typename packed_as<float, num_elems<T>::value>::type;
const Float_Packed_T scale_val = cuda_cast<Float_Packed_T>(*scale);
reinterpret_cast<Int8_Packed_T*>(dst)[target_id] =
cuda_cast<Int8_Packed_T>(cuda_cast<Float_Packed_T>(src[tid]) * scale_val);
}
else {
dst[target_id] = src[tid];
}
}
template<>
__global__ void transpose(const float* src,
float* dst,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const float* scale,
int int8_mode)
{
int batch_id = blockIdx.x / (head_num * seq_len);
int seq_id = blockIdx.x % seq_len;
int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len;
const int target_id = batch_id * (head_num * seq_len * size_per_head) + seq_id * head_num * size_per_head
+ head_id * size_per_head + threadIdx.x;
const int src_id = blockIdx.x * size_per_head + threadIdx.x;
if (int8_mode == 2) {
const float scale_val = *scale;
reinterpret_cast<int8_t*>(dst)[target_id] = cuda_cast<int8_t>(src[src_id] * scale_val);
}
else {
dst[target_id] = src[src_id];
}
}
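// Dispatch wrapper for the transpose kernels: 16-bit types with an even size_per_head use the
// half2/bf162 specialization and may fold up to 4 sequence positions into one block.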
template<typename T>
void invokeTransposeQKV(T* dst,
T* src,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const float* scale,
const int int8_mode,
cudaStream_t stream)
{
dim3 grid, block;
if (sizeof(T) == 2) {
int seq_per_block = 1;
grid.x = batch_size * head_num * seq_len / seq_per_block;
while (seq_per_block < 4 && grid.x % 2 == 0) {
grid.x /= 2;
seq_per_block *= 2;
}
RTP_LLM_CHECK(grid.x * seq_per_block == (size_t)batch_size * head_num * seq_len);
if (seq_per_block * size_per_head % 2 == 0) {
block.x = seq_per_block * size_per_head / 2;
if (std::is_same<T, half>::value) {
transpose<half2><<<grid, block, 0, stream>>>(
(half2*)src, (half2*)dst, batch_size, seq_len, head_num, size_per_head / 2, scale, int8_mode);
}
#ifdef ENABLE_BF16
else if constexpr (CompileConfig::enable_bf16) {
transpose<__nv_bfloat162><<<grid, block, 0, stream>>>((__nv_bfloat162*)src,
(__nv_bfloat162*)dst,
batch_size,
seq_len,
head_num,
size_per_head / 2,
scale,
int8_mode);
}
#endif
}
else {
block.x = seq_per_block * size_per_head;
transpose<T>
<<<grid, block, 0, stream>>>(src, dst, batch_size, seq_len, head_num, size_per_head, scale, int8_mode);
}
}
else {
const int seq_per_block = 1;
grid.x = batch_size * head_num * seq_len / seq_per_block;
block.x = seq_per_block * size_per_head;
transpose<T>
<<<grid, block, 0, stream>>>(src, dst, batch_size, seq_len, head_num, size_per_head, scale, int8_mode);
}
}
#define INSTANTIATETRANSPOSEQKV(T) \
template void invokeTransposeQKV(T* src, \
T* dst, \
const int batch_size, \
const int seq_len, \
const int head_num, \
const int size_per_head, \
const float* scale, \
const int int8_mode, \
cudaStream_t stream)
INSTANTIATETRANSPOSEQKV(float);
INSTANTIATETRANSPOSEQKV(half);
#ifdef ENABLE_BF16
INSTANTIATETRANSPOSEQKV(__nv_bfloat16);
#endif
#undef INSTANTIATETRANSPOSEQKV
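// Adds the Q/K/V biases (and optional IA3 key/value scaling), then scatters the packed
// valid_word_num tokens back to their padded [batch, head, seq, size_per_head] positions
// using mask_offset.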
template<typename T>
__global__ void add_QKV_bias_rebuild_padding_ia3(const T* Q,
const T* bias_Q,
const T* K,
const T* bias_K,
const T* V,
const T* bias_V,
T* q_buf_,
T* k_buf_,
T* v_buf_,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* mask_offset)
{
const int bid = blockIdx.x;
const int tgt_batch_id = (bid + mask_offset[bid]) / seq_len;
const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len;
const int n = head_num * size_per_head;
const bool use_ia3 = ia3_tasks != nullptr;
const int ia3_task = use_ia3 ? ia3_tasks[tgt_batch_id] : 0;
const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr);
const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr);
for (int idx = threadIdx.x; idx < n; idx += blockDim.x) {
const int tgt_head_id = idx / size_per_head;
const int tgt_hidden_id = idx % size_per_head;
const int src_id = bid * n + idx;
const int tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + tgt_head_id * seq_len * size_per_head
+ tgt_seq_id * size_per_head + tgt_hidden_id;
q_buf_[tgt_id] = add(ldg(&Q[src_id]), ldg(&bias_Q[idx]));
T k = ldg(&K[src_id]);
if (use_ia3_key) {
k = k * ia3_key_weights[ia3_task * n + idx];
}
k_buf_[tgt_id] = add(k, ldg(&bias_K[idx]));
T v = ldg(&V[src_id]);
if (use_ia3_value) {
v = v * ia3_value_weights[ia3_task * n + idx];
}
v_buf_[tgt_id] = add(v, ldg(&bias_V[idx]));
}
}
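// Same as add_QKV_bias_rebuild_padding_ia3, but without the bias addition.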
template<typename T>
__global__ void rebuild_padding_ia3(const T* Q,
const T* K,
const T* V,
T* q_buf_,
T* k_buf_,
T* v_buf_,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* mask_offset)
{
const int bid = blockIdx.x;
const int tgt_batch_id = (bid + mask_offset[bid]) / seq_len;
const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len;
const int n = head_num * size_per_head;
const bool use_ia3 = ia3_tasks != nullptr;
const int ia3_task = use_ia3 ? ia3_tasks[tgt_batch_id] : 0;
const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr);
const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr);
for (int idx = threadIdx.x; idx < n; idx += blockDim.x) {
const int tgt_head_id = idx / size_per_head;
const int tgt_hidden_id = idx % size_per_head;
const int src_id = bid * n + idx;
const int tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + tgt_head_id * seq_len * size_per_head
+ tgt_seq_id * size_per_head + tgt_hidden_id;
q_buf_[tgt_id] = ldg(&Q[src_id]);
T k = ldg(&K[src_id]);
if (use_ia3_key) {
k = k * ia3_key_weights[ia3_task * n + idx];
}
k_buf_[tgt_id] = k;
T v = ldg(&V[src_id]);
if (use_ia3_value) {
v = v * ia3_value_weights[ia3_task * n + idx];
}
v_buf_[tgt_id] = v;
}
}
template<typename T>
void invokeAddQKVBiasIA3RebuildPadding(T* Q,
const T* bias_Q,
T* K,
const T* bias_K,
T* V,
const T* bias_V,
T* q_buf,
T* k_buf,
T* v_buf,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int valid_word_num,
const int* mask_offset,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
cudaStream_t stream)
{
#ifdef ENABLE_BF16
bool is_half2 = (std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value) && (size_per_head % 2 == 0);
#else
bool is_half2 = (std::is_same<T, half>::value) && (size_per_head % 2 == 0);
#endif
using T2 = typename TypeConverter<T>::Type; // fp16 to half2, bf16 to bf162
int block_size = head_num * size_per_head;
if (is_half2) {
while (block_size > 512) {
if (block_size % 2 == 0) {
block_size /= 2;
}
else {
is_half2 = false;
block_size = std::min(block_size, 512);
break;
}
}
}
else {
block_size = std::min(block_size, 512);
}
if (bias_Q == nullptr && bias_K == nullptr && bias_V == nullptr) {
if (is_half2) {
rebuild_padding_ia3<<<valid_word_num, block_size, 0, stream>>>((T2*)Q,
(T2*)K,
(T2*)V,
(T2*)q_buf,
(T2*)k_buf,
(T2*)v_buf,
ia3_tasks,
(const T2*)ia3_key_weights,
(const T2*)ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head / 2,
mask_offset);
}
else {
rebuild_padding_ia3<<<valid_word_num, block_size, 0, stream>>>(Q,
K,
V,
q_buf,
k_buf,
v_buf,
ia3_tasks,
ia3_key_weights,
ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head,
mask_offset);
}
}
else if (bias_Q != nullptr && bias_K != nullptr && bias_V != nullptr) {
if (is_half2) {
add_QKV_bias_rebuild_padding_ia3<<<valid_word_num, block_size, 0, stream>>>((T2*)Q,
(const T2*)bias_Q,
(T2*)K,
(const T2*)bias_K,
(T2*)V,
(const T2*)bias_V,
(T2*)q_buf,
(T2*)k_buf,
(T2*)v_buf,
ia3_tasks,
(const T2*)ia3_key_weights,
(const T2*)ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head / 2,
mask_offset);
}
else {
add_QKV_bias_rebuild_padding_ia3<<<valid_word_num, block_size, 0, stream>>>(Q,
bias_Q,
K,
bias_K,
V,
bias_V,
q_buf,
k_buf,
v_buf,
ia3_tasks,
ia3_key_weights,
ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head,
mask_offset);
}
}
else {
RTP_LLM_CHECK(false);
}
}
#define INSTANTIATEADDQKVBIASIA3REBUILDPADDING(T) \
template void invokeAddQKVBiasIA3RebuildPadding(T* Q, \
const T* bias_Q, \
T* K, \
const T* bias_K, \
T* V, \
const T* bias_V, \
T* q_buf, \
T* k_buf, \
T* v_buf, \
const int batch_size, \
const int seq_len, \
const int head_num, \
const int size_per_head, \
const int valid_word_num, \
const int* mask_offset, \
const int* ia3_tasks, \
const T* ia3_key_weights, \
const T* ia3_value_weights, \
cudaStream_t stream)
INSTANTIATEADDQKVBIASIA3REBUILDPADDING(float);
INSTANTIATEADDQKVBIASIA3REBUILDPADDING(half);
#ifdef ENABLE_BF16
INSTANTIATEADDQKVBIASIA3REBUILDPADDING(__nv_bfloat16);
#endif
#undef INSTANTIATEADDQKVBIASIA3REBUILDPADDING
template<typename T>
__global__ void transpose_remove_padding(const T* src,
T* dst,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* mask_offset,
const float* scale,
const int int8_mode)
{
// TODO: optimize this kernel?
// do remove_sequence_length_padding
const int bid = blockIdx.x; // batch * seq_len or valid_word_num
const int src_batch_id = (bid + mask_offset[bid]) / seq_len;
const int src_seq_id = (bid + mask_offset[bid]) % seq_len;
const int dst_seq_id = bid;
const int src_offset_base = src_batch_id * seq_len * head_num * size_per_head + src_seq_id * size_per_head;
const int dst_offset_base = dst_seq_id * head_num * size_per_head;
using Int8_Packed_T = typename packed_as<int8_t, num_elems<T>::value>::type;
using Float_Packed_T = typename packed_as<float, num_elems<T>::value>::type;
const Float_Packed_T scale_val =
int8_mode == 2 ? cuda_cast<Float_Packed_T>(*scale) : cuda_cast<Float_Packed_T>(0.0f);
for (int idx = threadIdx.x; idx < head_num * size_per_head; idx += blockDim.x) {
const int head_id = idx / size_per_head;
const int hidden_id = idx % size_per_head;
const T src_elem = ldg(&src[src_offset_base + head_id * seq_len * size_per_head + hidden_id]);
if (int8_mode == 2) {
reinterpret_cast<Int8_Packed_T*>(dst)[dst_offset_base + idx] =
cuda_cast<Int8_Packed_T>(cuda_cast<Float_Packed_T>(src_elem) * scale_val);
}
else {
dst[dst_offset_base + idx] = src_elem;
}
}
}
// clang-format off
template<typename T>
void invokeTransposeAttentionOutRemovePadding(T* src,
T* dst,
const int valid_word_num,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* mask_offset,
const float* scale,
const int int8_mode,
cudaStream_t stream)
{
#ifdef ENABLE_BF16
bool is_half2 = (std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value) && (size_per_head % 2 == 0);
#else
bool is_half2 = (std::is_same<T, half>::value) && (size_per_head % 2 == 0);
#endif
using T2 = typename TypeConverter<T>::Type; // fp16 to half2, bf16 to bf162
int block_size = head_num * size_per_head;
if (is_half2) {
while (block_size > 512) {
if (block_size % 2 == 0) {
block_size /= 2;
}
else {
is_half2 = false;
block_size = std::min(block_size, 1024);
break;
}
}
}
else {
block_size = std::min(block_size, 1024);
}
if (is_half2) {
transpose_remove_padding<T2><<<valid_word_num, block_size, 0, stream>>>(
(T2*)src, (T2*)dst, batch_size, seq_len, head_num, size_per_head / 2, mask_offset, scale, int8_mode);
}
else {
transpose_remove_padding<<<valid_word_num, block_size, 0, stream>>>(
src, dst, batch_size, seq_len, head_num, size_per_head, mask_offset, scale, int8_mode);
}
}
// clang-format on
#define INSTANTIATETRANSPOSEATTENTIONOUTREMOVEPADDING(T) \
template void invokeTransposeAttentionOutRemovePadding(T* src, \
T* dst, \
const int valid_word_num, \
const int batch_size, \
const int seq_len, \
const int head_num, \
const int size_per_head, \
const int* mask_offset, \
const float* scale, \
const int int8_mode, \
cudaStream_t stream)
INSTANTIATETRANSPOSEATTENTIONOUTREMOVEPADDING(float);
INSTANTIATETRANSPOSEATTENTIONOUTREMOVEPADDING(half);
#ifdef ENABLE_BF16
INSTANTIATETRANSPOSEATTENTIONOUTREMOVEPADDING(__nv_bfloat16);
#endif
#undef INSTANTIATETRANSPOSEATTENTIONOUTREMOVEPADDING
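// Per-type vectorization traits: Vec_t<T>::Type is the packed type a single thread loads/stores
// (two elements per access), and QuantizedType is the matching FP8 vector type when FP8 is enabled.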
template<typename T>
struct Vec_t {
static constexpr int size = 0;
};
template<>
struct Vec_t<float> {
using Type = float2;
static constexpr int size = 2;
#ifdef ENABLE_FP8
using QuantizedType = fp8_2_t;
#endif
};
template<>
struct Vec_t<half> {
using Type = uint32_t;
static constexpr int size = 2;
#ifdef ENABLE_FP8
using QuantizedType = fp8_2_t;
#endif
};
#ifdef ENABLE_BF16
template<>
struct Vec_t<__nv_bfloat16> {
using Type = __nv_bfloat162;
static constexpr int size = 2;
#ifdef ENABLE_FP8
using QuantizedType = fp8_2_t;
#endif
};
#endif
// Helper that wraps the two reinterpret_casts: offset the pointer in units of type_in, then view it as type_out.
template <typename type_in, typename type_out>
inline __device__ type_out* reinterpret_ptr(void* ptr, size_t offset)
{
return reinterpret_cast<type_out*>(reinterpret_cast<type_in*>(ptr) + offset);
}
// Bandwidth-bound kernel that reads cos/sin coefficients from global memory (pre-computed and stored as weights).
template<typename T, typename Tcache, bool PREFIX_PROMPT, bool USE_PAGED_FMHA, RopeStyle ROPE_STYLE>
__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
T* k_buf,
T* v_buf,
PrefixPromptBatchWeightsParam param,
T* QKV,
void* QuantizedQKV,
const int* position_ids,
const T* __restrict qkv_bias,
const int* padding_offset,
const int* cu_seqlens,
const int batch_size,
const int seq_len,
const int head_num,
const int head_num_kv,
const int size_per_head,
RopeConfig rope_config,
const bool use_logn_attn,
bool store_qkv,
bool store_q,
bool store_kv,
bool store_cache)
{
// This kernel adds the bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head],
// splits QKV into the three buffers q, k, v, and transposes them to [batch_size, head_num, seq_len, size_per_head].
// For q and k, it also applies the rotary embedding.
// When a prefix prompt is passed, this kernel also concatenates the prefix prompt and key/value along the
// seq_len dimension, like [prompt, key/value].
// So the final shape of q stays the same ([batch_size, head_num, seq_len, size_per_head]), but
// the shapes of key and value become [batch_size, head_num, max_prefix_prompt_length + seq_len, size_per_head].
// NOTE: QKV src shape (batch_size, seq_len, 3, head_num, size_per_head)
// QKV dst shape (3, batch_size, head_num, seq_len, size_per_head)
extern __shared__ __align__(sizeof(float2)) char smem_[]; // align on largest vector type
static constexpr bool ENABLE_8BITS_CACHE = sizeof(Tcache) == 1;
#ifdef ENABLE_FP8
// Quantized output only supports fp8 currently.
using QuantizedEltType = __nv_fp8_e4m3;
using QuantizedVecType = typename Vec_t<T>::QuantizedType;
#endif
constexpr int vec_size = Vec_t<T>::size;
using Vec_t = typename Vec_t<T>::Type;
const int token_idx = blockIdx.x;
const int token_padding_offset = padding_offset == nullptr ? 0 : padding_offset[token_idx];
const int tgt_token_idx = token_idx + token_padding_offset;
const int batch_idx = tgt_token_idx / seq_len;
const int seq_idx = tgt_token_idx % seq_len;
const int head_idx = blockIdx.y;
const int tidx = threadIdx.x;
const int total_seq_len = param.max_prefix_prompt_length + seq_len;
if (tidx * vec_size >= size_per_head) {
return;
}
const int prefix_prompt_length = PREFIX_PROMPT ? param.d_prefix_prompt_lengths[batch_idx] : 0;
const int hidden_idx = head_idx * size_per_head + tidx * vec_size;
const int n = head_num * size_per_head;
const int kv_n = head_num_kv * size_per_head; // MQA
// the [0..seq_len) indices really handle KV [max_pp_len..seq_len+max_pp_len)
// and Q [0..seq_len)
// Note: if !PREFIX_PROMPT, max_pp_len = 0, so it's no-op
const int dst_kv_seq_idx = seq_idx + prefix_prompt_length;
// NOTE: q has seq len excluding prefix prompt
// src QKV: [batch, time, 3, head, hidden]
const int src_q_idx = token_idx * (n + 2 * kv_n) + hidden_idx;
const int src_k_idx = token_idx * (n + 2 * kv_n) + hidden_idx + n;
const int src_v_idx = token_idx * (n + 2 * kv_n) + hidden_idx + kv_n + n;
Vec_t q, k, v;
q = *reinterpret_cast<const Vec_t*>(&QKV[src_q_idx]);
if (head_idx < head_num_kv) {
k = *reinterpret_cast<const Vec_t*>(&QKV[src_k_idx]);
v = *reinterpret_cast<const Vec_t*>(&QKV[src_v_idx]);
}
if (qkv_bias) {
Vec_t q_bias, k_bias, v_bias;
q_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
q = add(q, q_bias);
if (head_idx < head_num_kv) {
k_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + n]);
v_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + n + kv_n]);
k = add(k, k_bias);
v = add(v, v_bias);
}
}
int position_id = -1;
if (rope_config.style == RopeStyle::Mrope) {
int rope_dim = rope_config.mrope_dim1 + rope_config.mrope_dim2 + rope_config.mrope_dim3;
int now_idx = tidx % rope_dim, now_dim = 0;
if (now_idx >= rope_config.mrope_dim1 + rope_config.mrope_dim2) {
now_dim = 2;
} else if (now_idx >= rope_config.mrope_dim1) {
now_dim = 1;
}
position_id = position_ids[token_idx * rope_config.index_factor + now_dim];
} else if (position_ids) {
position_id = position_ids[token_idx * rope_config.index_factor];
}
const int pre_len = cu_seqlens[batch_idx];
const int input_len = cu_seqlens[batch_idx + 1] - pre_len;
context_rope<T, Vec_t, ROPE_STYLE>(
rope_config,
q,
k,
reinterpret_cast<T*>(smem_),
tidx,
seq_idx,
position_id,
seq_len,
input_len,
PREFIX_PROMPT,
prefix_prompt_length,
param.count_length);
if (use_logn_attn) {
logn_attention(q, seq_idx, rope_config.max_pos);
}
__syncthreads();
if (store_qkv) {
*reinterpret_cast<Vec_t*>(&QKV[src_q_idx]) = q;
if (head_idx < head_num_kv) {
#ifdef ENABLE_FP8
if (QuantizedQKV != nullptr) {
// use 1.0f scale currently for qkv input of FP8 FMHA.
convert_to_fp8(reinterpret_cast<QuantizedVecType*>(
reinterpret_cast<QuantizedEltType*>(QuantizedQKV) + src_k_idx),
k);
convert_to_fp8(reinterpret_cast<QuantizedVecType*>(
reinterpret_cast<QuantizedEltType*>(QuantizedQKV) + src_v_idx),
v);
}
#endif
*reinterpret_cast<Vec_t*>(&QKV[src_k_idx]) = k;
*reinterpret_cast<Vec_t*>(&QKV[src_v_idx]) = v;
}
#ifdef ENABLE_FP8
if (QuantizedQKV != nullptr) {
size_t dest_q_idx = batch_idx * size_per_head * seq_len * head_num + head_idx * size_per_head * seq_len
+ seq_idx * size_per_head + tidx * vec_size;
if constexpr (USE_PAGED_FMHA) {
dest_q_idx = (pre_len + seq_idx) * size_per_head * head_num
+ head_idx * size_per_head + tidx * vec_size;
}
*reinterpret_cast<Vec_t*>(&q_buf[dest_q_idx]) = q;
QuantizedVecType* quantized_q_ptr = USE_PAGED_FMHA
? reinterpret_ptr<QuantizedEltType, QuantizedVecType>(q_buf, dest_q_idx)
: reinterpret_ptr<QuantizedEltType, QuantizedVecType>(QuantizedQKV, src_q_idx);
convert_to_fp8(quantized_q_ptr, q);
}
#endif
}
if (store_q) {
size_t dest_q_idx = batch_idx * size_per_head * seq_len * head_num + head_idx * size_per_head * seq_len
+ seq_idx * size_per_head + tidx * vec_size;
if constexpr (USE_PAGED_FMHA) {
dest_q_idx = (pre_len + seq_idx) * size_per_head * head_num
+ head_idx * size_per_head + tidx * vec_size;
}
*reinterpret_cast<Vec_t*>(&q_buf[dest_q_idx]) = q;
}
if (store_kv) {
const int dest_kv_idx = batch_idx * size_per_head * total_seq_len * head_num_kv
+ head_idx * size_per_head * total_seq_len + dst_kv_seq_idx * size_per_head
+ tidx * vec_size;
if (head_idx < head_num_kv) {
*reinterpret_cast<Vec_t*>(&k_buf[dest_kv_idx]) = k;
*reinterpret_cast<Vec_t*>(&v_buf[dest_kv_idx]) = v;
}
}
if (store_cache) {
if (head_idx < head_num_kv) {
KVBlockArray kv_block_array = param.kv_block_array;
Tcache* k_cache = reinterpret_cast<Tcache*>(kv_block_array.getKBlockPtr(batch_idx, dst_kv_seq_idx));
Tcache* v_cache = reinterpret_cast<Tcache*>(kv_block_array.getVBlockPtr(batch_idx, dst_kv_seq_idx));
if constexpr (ENABLE_8BITS_CACHE) {
float* k_scale_ptr = reinterpret_cast<float*>(kv_block_array.getKScalePtr(batch_idx, dst_kv_seq_idx));
float* v_scale_ptr = reinterpret_cast<float*>(kv_block_array.getVScalePtr(batch_idx, dst_kv_seq_idx));
const int inBlockIdx = kv_block_array.getKVLocalIdx(dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size);
const int inScaleIdx = kv_block_array.getKVScaleLocalIdx(dst_kv_seq_idx, head_idx);
__shared__ float s_max[2];
if constexpr (std::is_same<Tcache, int8_t>::value) {
float local_max[2];
local_max[0] = vector_abs_max(k);
local_max[1] = vector_abs_max(v);
blockReduceMaxV2<float, 2>(local_max);
if (threadIdx.x == 0) {
s_max[0] = local_max[0];
s_max[1] = local_max[1];
}
} else {
s_max[0] = float(1 << (8 - 1));
s_max[1] = float(1 << (8 - 1));
}
__syncthreads();
store_8bits_kv_cache_vec(k_cache, k, inBlockIdx, float(1 << (8 - 1)) / s_max[0]);
store_8bits_kv_cache_vec(v_cache, v, inBlockIdx, float(1 << (8 - 1)) / s_max[1]);
if (tidx == 0) {
*reinterpret_cast<float*>(&k_scale_ptr[inScaleIdx]) = s_max[0] / float(1 << (8 - 1));
*reinterpret_cast<float*>(&v_scale_ptr[inScaleIdx]) = s_max[1] / float(1 << (8 - 1));
}
} else {
const int inBlockIdx = kv_block_array.getKVLocalIdx(dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size);
*reinterpret_cast<Vec_t*>(&k_cache[inBlockIdx]) = k;
*reinterpret_cast<Vec_t*>(&v_cache[inBlockIdx]) = v;
}
}
}
}
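// Dispatch wrapper: launches one block per (token, head) with one thread per Vec_t<T>::size
// elements of size_per_head, selecting compile-time variants for prefix prompts, paged FMHA,
// the KV-cache element type, and the RoPE style. The dynamic shared memory holds the rotary
// working buffer of 2 * rope_config.dim elements of T.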
template<typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf,
T* k_buf,
T* v_buf,
PrefixPromptBatchWeightsParam* param_ptr,
T* QKV,
void* QuantizedQKV,
const int* position_ids,
const T* qkv_bias,
const int* padding_offset,
const int* cu_seqlens,
const int batch_size,
const int seq_len,
const int token_num,
const int head_num,
const int head_num_kv,
const int size_per_head,
const RopeConfig rope_config,
const bool use_logn_attn,
const float* scale,
const int int8_mode,
const bool use_paged_fmha,
const bool store_qkv,
const bool store_q,
const bool store_kv,
const bool store_cache,
cudaStream_t stream)
{
auto& param = *param_ptr;
dim3 block((size_per_head / Vec_t<T>::size + 31) / 32 * 32);
dim3 grid(token_num, head_num);
size_t smem_size = rope_config.style == RopeStyle::No ? 0 : 2 * rope_config.dim * sizeof(T);
FT_SWITCH(param.max_prefix_prompt_length != 0, PREFIX_PROMPT, [&]{
FT_SWITCH(use_paged_fmha, USE_PAGED_FMHA, [&]{
FT_SWITCH_KV_CACHE_TYPE_CASE(param.kv_block_array.cache_type, Tcache, [&]{
FT_ROPE_SWITCH(rope_config.style, ROPE_STYLE, [&]{
add_fusedQKV_bias_transpose_kernel<T, Tcache, PREFIX_PROMPT, USE_PAGED_FMHA, ROPE_STYLE>
<<<grid, block, smem_size, stream>>>(
q_buf,
k_buf,
v_buf,
param,
QKV,
QuantizedQKV,
position_ids,
qkv_bias,
padding_offset,
cu_seqlens,
batch_size,
seq_len,
head_num,
head_num_kv,
size_per_head,
rope_config,
use_logn_attn,
store_qkv,
store_q,
store_kv,
store_cache);
});
});
});
});
}
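// Copies the cached prefix-prompt K/V of each batch entry from the paged KV cache into the
// contiguous k_buf/v_buf layout [batch, head_kv, max_prefix_prompt_length + seq_len, size_per_head],
// dequantizing with the stored per-token scales when the cache is 8-bit.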
template<typename T, typename Tcache>
__global__ void load_prefix_KVCache_kernel(T* q_buf,
T* k_buf,
T* v_buf,
PrefixPromptBatchWeightsParam param,
const int seq_len,
const int head_num,
const int head_num_kv,
const int size_per_head)
{
static constexpr bool ENABLE_8BITS_CACHE = sizeof(Tcache) == 1;
constexpr int vec_size = Vec_t<T>::size;
using Vec_t = typename Vec_t<T>::Type;
const int head_idx = blockIdx.y;
const int tidx = threadIdx.x;
const int total_seq_len = param.max_prefix_prompt_length + seq_len;
if (tidx * vec_size >= size_per_head) {
return;
}
// NOTE: blockIdx.x < batch_size * param.max_prefix_prompt_length really handles prefix prompts
if (head_idx < head_num_kv) {
const int prompt_batch_idx = blockIdx.x / param.max_prefix_prompt_length;
const int prompt_seq_idx = blockIdx.x % param.max_prefix_prompt_length;
const int prompt_length = param.d_prefix_prompt_lengths[prompt_batch_idx];
if (prompt_seq_idx < prompt_length) {
const int dest_kv_idx = prompt_batch_idx * size_per_head * total_seq_len * head_num_kv
+ head_idx * size_per_head * total_seq_len + prompt_seq_idx * size_per_head
+ tidx * vec_size;
if (param.kv_block_array.mMaxSeqs > 0) {
Tcache* k_cache = reinterpret_cast<Tcache*>(param.kv_block_array.getKBlockPtr(prompt_batch_idx, prompt_seq_idx));
Tcache* v_cache = reinterpret_cast<Tcache*>(param.kv_block_array.getVBlockPtr(prompt_batch_idx, prompt_seq_idx));
const int inBlockIdx = param.kv_block_array.getKVLocalIdx(
prompt_seq_idx, head_idx, size_per_head, tidx * vec_size);
if constexpr (ENABLE_8BITS_CACHE) {
float* k_scale_ptr = reinterpret_cast<float*>(param.kv_block_array.getKScalePtr(prompt_batch_idx, prompt_seq_idx));
float* v_scale_ptr = reinterpret_cast<float*>(param.kv_block_array.getVScalePtr(prompt_batch_idx, prompt_seq_idx));
int inScaleIdx = param.kv_block_array.getKVScaleLocalIdx(prompt_seq_idx, head_idx);
load_8bits_kv_cache_vec(reinterpret_cast<Vec_t*>(&k_buf[dest_kv_idx]), k_cache, inBlockIdx, k_scale_ptr[inScaleIdx]);
load_8bits_kv_cache_vec(reinterpret_cast<Vec_t*>(&v_buf[dest_kv_idx]), v_cache, inBlockIdx, v_scale_ptr[inScaleIdx]);
} else {
*reinterpret_cast<Vec_t*>(&k_buf[dest_kv_idx]) =
*reinterpret_cast<const Vec_t*>(&k_cache[inBlockIdx]);
*reinterpret_cast<Vec_t*>(&v_buf[dest_kv_idx]) =
*reinterpret_cast<const Vec_t*>(&v_cache[inBlockIdx]);
}
}
}
}
}
template<typename T>
void invokeLoadPrefixKVCache(T* q_buf,
T* k_buf,
T* v_buf,
PrefixPromptBatchWeightsParam* param_ptr,
const int batch_size,
const int seq_len,
const int head_num,
const int head_num_kv,
const int size_per_head,
const float* scale,
const int int8_mode,
cudaStream_t stream)
{
auto& param = *param_ptr;
dim3 block((size_per_head / Vec_t<T>::size + 31) / 32 * 32);
dim3 grid(batch_size * param.max_prefix_prompt_length, head_num);
FT_SWITCH_KV_CACHE_TYPE_CASE(param.kv_block_array.cache_type, Tcache, [&]{
load_prefix_KVCache_kernel<T, Tcache>
<<<grid, block, 0, stream>>>(
q_buf,
k_buf,
v_buf,
param,
seq_len,
head_num,
head_num_kv,
size_per_head);
});
}
template<typename T>
__global__ void SplitQKV_kernel(T* q_buf,
T* k_buf,
T* v_buf,
T* QKV,
const int token_num,
const int head_num,
const int head_num_kv,
const int size_per_head)
{
// QKV: [token_num, head_num + 2 * head_num_kv, size_per_head]
// q_buf: [token_num, head_num, size_per_head]; k_buf, v_buf: [token_num, head_num_kv, size_per_head]
// grid(token_num, head_num + 2 * head_num_kv)
// block(size_per_head)
const int token_idx = blockIdx.x;
const int head_idx = blockIdx.y;
const int size_id = threadIdx.x;
if (size_id >= size_per_head) {
return;
}
const int qkv_offset =
token_idx * (head_num + head_num_kv * 2) * size_per_head + head_idx * size_per_head + size_id;
T val = ldg(&QKV[qkv_offset]);
if (head_idx < head_num) {
q_buf[token_idx * head_num * size_per_head + head_idx * size_per_head + size_id] = val;
}
else if (head_idx < head_num + head_num_kv) {
k_buf[token_idx * head_num_kv * size_per_head + (head_idx - head_num) * size_per_head + size_id] = val;
}
else {
v_buf[token_idx * head_num_kv * size_per_head + (head_idx - head_num - head_num_kv) * size_per_head + size_id] =
val;
}
}
template<typename T>
void invokeSplitQKV(T* q_buf,
T* k_buf,
T* v_buf,
T* QKV,
const int token_num,
const int head_num,
const int head_num_kv,
const int size_per_head,
cudaStream_t stream)
{
dim3 block(size_per_head);
dim3 grid(token_num, head_num + 2 * head_num_kv);
SplitQKV_kernel<<<grid, block, 0, stream>>>(
q_buf, k_buf, v_buf, QKV, token_num, head_num, head_num_kv, size_per_head);
}
#define INSTANTIATESPLITQKV(T) \
template void invokeSplitQKV(T* q_buf, \
T* k_buf, \
T* v_buf, \
T* QKV, \
const int token_num, \
const int head_num, \
const int head_num_kv, \
const int size_per_head, \
cudaStream_t stream)
INSTANTIATESPLITQKV(float);
INSTANTIATESPLITQKV(half);
#ifdef ENABLE_BF16
INSTANTIATESPLITQKV(__nv_bfloat16);
#endif
#undef INSTANTIATESPLITQKV
#define INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(T) \
template void invokeAddFusedQKVBiasTranspose(T* q_buf, \
T* k_buf, \
T* v_buf, \
PrefixPromptBatchWeightsParam* param, \
T* QKV, \
void* QuantizedQKV, \
const int* position_ids, \
const T* qkv_bias, \
const int* padding_offset, \
const int* cu_seqlens, \
const int batch_size, \
const int seq_len, \
const int token_num, \
const int head_num, \
const int head_num_kv, \
const int size_per_head, \
const RopeConfig rope_config, \
const bool use_logn_attn, \
const float* scale, \
const int int8_mode, \
const bool use_paged_fmha, \
const bool store_qkv, \
const bool store_q, \
const bool store_kv, \
const bool store_cache, \
cudaStream_t stream)
INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(float);
INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(half);
#ifdef ENABLE_BF16
INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(__nv_bfloat16);
#endif
#undef INSTANTIATEADDFUSEDQKVBIASTRANSPOSE
#define INSTANTIATEINVOKELOADPREFIXKVCACHE(T) \
template void invokeLoadPrefixKVCache(T* q_buf, \
T* k_buf, \
T* v_buf, \
PrefixPromptBatchWeightsParam* param, \
const int batch_size, \
const int seq_len, \
const int head_num, \
const int head_num_kv, \
const int size_per_head, \
const float* scale, \
const int int8_mode, \
cudaStream_t stream)
INSTANTIATEINVOKELOADPREFIXKVCACHE(float);
INSTANTIATEINVOKELOADPREFIXKVCACHE(half);
#ifdef ENABLE_BF16
INSTANTIATEINVOKELOADPREFIXKVCACHE(__nv_bfloat16);
#endif
#undef INSTANTIATEINVOKELOADPREFIXKVCACHE
template<typename T>
__global__ void transpose_4d(T* dst,
T* src,
const int dim0,
const int dim1,
const int dim2,
const int dim3,
const int dim0_leading_dim,
const int ite)
{
// transpose from [dim0, dim1, dim2, dim3] to [dim2, X, dim1, dim3]
// where the dimension of X is dim0_leading_dim, and offset is ite * dim0
for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < dim0 * dim1 * dim2 * dim3; i += blockDim.x * gridDim.x) {
int index = i;
const int d3 = index % dim3;
index = (index - d3) / dim3;
const int d2 = index % dim2;
index = (index - d2) / dim2;
const int d1 = index % dim1;
index = (index - d1) / dim1;
const int d0 = index % dim0;
index = (index - d0) / dim0;
dst[d2 * dim0_leading_dim * dim1 * dim3 + (d0 + dim0 * ite) * dim1 * dim3 + d1 * dim3 + d3] = src[i];
}
}
template<>
__global__ void transpose_4d(half* dst,
half* src,
const int dim0,
const int dim1,
const int dim2,
const int dim3,
const int dim0_leading_dim,
const int ite)
{
half2* dst_ptr = (half2*)dst;
half2* src_ptr = (half2*)src;
const int half_dim3 = dim3 / 2;
// transpose from [dim0, dim1, dim2, half_dim3] to [dim2, X, dim1, half_dim3]
// where the dimension of X is dim0_leading_dim, and the offset along X is ite * dim0
for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < dim0 * dim1 * dim2 * half_dim3;
i += blockDim.x * gridDim.x) {
int index = i;
const int d3 = index % half_dim3;
index = (index - d3) / half_dim3;
const int d2 = index % dim2;
index = (index - d2) / dim2;
const int d1 = index % dim1;
index = (index - d1) / dim1;
const int d0 = index % dim0;
index = (index - d0) / dim0;
dst_ptr[d2 * dim0_leading_dim * dim1 * half_dim3 + (d0 + dim0 * ite) * dim1 * half_dim3 + d1 * half_dim3 + d3] =
src_ptr[i];
}
}
template<typename T>
void invokeTranspose4d(T* dst,
T* src,
const int local_batch_size,
const int seq_len,
const int size_per_head,
const int local_hidden_units,
const int local_head_num,
const int batch_size,
const int ite,
cudaStream_t stream)
{
transpose_4d<<<local_batch_size * seq_len * local_hidden_units / 512, 512 / (4 / (sizeof(T))), 0, stream>>>(
dst, src, local_batch_size, local_head_num, seq_len, size_per_head, batch_size, ite);
}
#define INSTANTIATETRANSPOSE4D(T) \
template void invokeTranspose4d(T* dst, \
T* src, \
const int local_batch_size, \
const int seq_len, \
const int size_per_head, \
const int local_hidden_units, \
const int local_head_num, \
const int batch_size, \
const int ite, \
cudaStream_t stream)
INSTANTIATETRANSPOSE4D(float);
INSTANTIATETRANSPOSE4D(half);
#undef INSTANTIATETRANSPOSE4D
void invokeGetSkipLength(int* skip_length, int* prefix_lengths, int batch_size, cudaStream_t stream)
{
if (!prefix_lengths || batch_size == 0) {
return;
}
dim3 grid(1);
dim3 block(1);
getSkipLength<<<grid, block, 0, stream>>>(skip_length, prefix_lengths, batch_size);
}
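// Adds the (head, query, key)-indexed relative attention bias to every batch entry of qk_buf.
// One block handles one (head, query-position) row; threads stride over batch and key positions.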
template<typename T>
__global__ void addRelativeAttentionBias(
T* qk_buf, const T* relative_attention_bias, const int batch_size, const int head_num, const int seq_len)
{
for (int i = threadIdx.x; i < batch_size * seq_len; i += blockDim.x) {
int batch_id = i / seq_len;
int seq_id = i % seq_len;
const int bias_index = blockIdx.x * seq_len + seq_id;
const int qk_index = batch_id * gridDim.x * seq_len + bias_index;
qk_buf[qk_index] = add(qk_buf[qk_index], relative_attention_bias[bias_index]);
}
}
template<typename T>
void invokeAddRelativeAttentionBias(T* qk_buf,
const T* relative_attention_bias,
const int batch_size,
const int head_num,
const int seq_len,
cudaStream_t stream)
{
// qk_buf: [batch_size, head_num, seq_len, seq_len]
// relative_attention_bias: [1, head_num, seq_len, seq_len]
dim3 grid(head_num * seq_len);
dim3 block(512);
using T2 = typename TypeConverter<T>::Type;
#ifdef ENABLE_BF16
const bool is_half2 = (std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value) && (seq_len % 2 == 0);
#else
const bool is_half2 = (std::is_same<T, half>::value) && (seq_len % 2 == 0);
#endif
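// When seq_len is even, half/bf16 buffers are reinterpreted as packed 2-element vectors so each thread
// adds two bias values at once; the kernel then sees seq_len / 2 packed columns per row.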
if (is_half2) {
addRelativeAttentionBias<T2><<<grid, block, 0, stream>>>(
(T2*)qk_buf, (const T2*)relative_attention_bias, batch_size, head_num, seq_len / 2);
}
else {
addRelativeAttentionBias<<<grid, block, 0, stream>>>(
qk_buf, relative_attention_bias, batch_size, head_num, seq_len);
}
}
#define INSTANTIATEADDRELATIVEATTENTIONBIAS(T) \
template void invokeAddRelativeAttentionBias(T* qk_buf, \
const T* relative_attention_bias, \
const int batch_size, \
const int head_num, \
const int seq_len, \
cudaStream_t stream)
INSTANTIATEADDRELATIVEATTENTIONBIAS(float);
INSTANTIATEADDRELATIVEATTENTIONBIAS(half);
#ifdef ENABLE_BF16
INSTANTIATEADDRELATIVEATTENTIONBIAS(__nv_bfloat16);
#endif
#undef INSTANTIATEADDRELATIVEATTENTIONBIAS
/******************* invokeAddHead3SizeQKVBias ***********************/
// m = batch*window_num*window_len
// mm_qkv is [m, head*3*size_per_head] row-major
// bias_qkv is [head*3*size_per_head]
// q_buf_, k_buf_, v_buf_ are [batch*window_num, num_head, window_len, size_per_head] row-major
// grid(window_len, window_num, 3*batch);
// block(num_head * size_per_head)
template<typename T>
__global__ void add_head3Size_QKV_bias(const T* mm_qkv,
const T* bias_qkv,
T* q_buf_,
T* k_buf_,
T* v_buf_,
const int batch,
const int window_num,
const int window_len,
const int num_head,
const int size_per_head)
{
T* buf_ptr;
int qkv_id = blockIdx.z / batch;
if (qkv_id == 0) {
buf_ptr = q_buf_;
}
else if (qkv_id == 1) {
buf_ptr = k_buf_;
}
else {
buf_ptr = v_buf_;
}
const int batch_id = blockIdx.z % batch;
const int token_id = blockIdx.x;
const int window_id = blockIdx.y;
const int head_id = threadIdx.x / size_per_head;
const int id_in_head = threadIdx.x % size_per_head;
const int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
const T bias = ldg(bias_qkv + bias_idx);
const int input_idx =
((batch_id * window_num + window_id) * window_len + token_id) * num_head * 3 * size_per_head + bias_idx;
T tmp = mm_qkv[input_idx] + bias;
int target_id = (((batch_id * window_num + window_id) * num_head + head_id) * window_len + token_id) * size_per_head
+ id_in_head;
buf_ptr[target_id] = tmp;
}
// for float2, size_per_head /= 2
// m = batch*window_num*window_len
// mm_qkv is [m, head*3*size_per_head] row-major
// bias_qkv is [head*3*size_per_head]
// q_buf_, k_buf_, v_buf_ are [batch*window_num, num_head, window_len, size_per_head] row-major
// grid(window_len, window_num, 3*batch);
// block(num_head * size_per_head)
template<>
__global__ void add_head3Size_QKV_bias(const float2* mm_qkv,
const float2* bias_qkv,
float2* q_buf_,
float2* k_buf_,
float2* v_buf_,
const int batch,
const int window_num,
const int window_len,
const int num_head,
const int size_per_head)
{
float2* buf_ptr;
int qkv_id = blockIdx.z / batch;
if (qkv_id == 0) {
buf_ptr = q_buf_;
}
else if (qkv_id == 1) {
buf_ptr = k_buf_;
}
else {
buf_ptr = v_buf_;
}
const int batch_id = blockIdx.z % batch;
const int token_id = blockIdx.x;
const int window_id = blockIdx.y;
const int head_id = threadIdx.x / size_per_head;
const int id_in_head = threadIdx.x % size_per_head;
const int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
const float2 bias = ldg(bias_qkv + bias_idx);
const int input_idx =
((batch_id * window_num + window_id) * window_len + token_id) * num_head * 3 * size_per_head + bias_idx;
float2 tmp = mm_qkv[input_idx];
tmp.x += bias.x;
tmp.y += bias.y;
int target_id = (((batch_id * window_num + window_id) * num_head + head_id) * window_len + token_id) * size_per_head
+ id_in_head;
buf_ptr[target_id] = tmp;
}
// for half2, size_per_head /= 2
// m = batch*window_num*window_len
// mm_qkv is [m, head*3*size_per_head] row-major
// bias_qkv is [head*3*size_per_head]
// q_buf_, k_buf_, v_buf_ is [batch*window_num, num_head, window_len, size_per_head] row-major
// grid(window_len, window_num, batch);
// block(num_head * size_per_head)
template<>
__global__ void add_head3Size_QKV_bias(const half2* mm_qkv,
const half2* bias_qkv,
half2* q_buf_,
half2* k_buf_,
half2* v_buf_,
const int batch,
const int window_num,
const int window_len,
const int num_head,
const int size_per_head)
{
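// Unlike the scalar/float2 specializations, one block writes q, k and v for its token, so blockIdx.z
// indexes only the batch (grid.z == batch rather than 3 * batch).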
const int batch_id = blockIdx.z;
const int token_id = blockIdx.x;
const int window_id = blockIdx.y;
const int head_id = threadIdx.x / size_per_head;
const int id_in_head = threadIdx.x % size_per_head;
const int input_offset =
((batch_id * window_num + window_id) * window_len + token_id) * num_head * 3 * size_per_head;
const int target_id =
(((batch_id * window_num + window_id) * num_head + head_id) * window_len + token_id) * size_per_head
+ id_in_head;
int qkv_id = 0;
int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
half2 bias = __ldg(bias_qkv + bias_idx);
int input_idx = input_offset + bias_idx;
half2 tmp = mm_qkv[input_idx];
tmp = __hadd2(tmp, bias);
q_buf_[target_id] = tmp;
qkv_id = 1;
bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
bias = __ldg(bias_qkv + bias_idx);
input_idx = input_offset + bias_idx;
tmp = mm_qkv[input_idx];
tmp = __hadd2(tmp, bias);
k_buf_[target_id] = tmp;
qkv_id = 2;
bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
bias = __ldg(bias_qkv + bias_idx);
input_idx = input_offset + bias_idx;
tmp = mm_qkv[input_idx];
tmp = __hadd2(tmp, bias);
v_buf_[target_id] = tmp;
}
#ifdef ENABLE_BF16
template<>
__global__ void add_head3Size_QKV_bias(const __nv_bfloat162* mm_qkv,
const __nv_bfloat162* bias_qkv,
__nv_bfloat162* q_buf_,
__nv_bfloat162* k_buf_,
__nv_bfloat162* v_buf_,
const int batch,
const int window_num,
const int window_len,
const int num_head,
const int size_per_head)
{
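// Same single-block q/k/v layout as the half2 specialization, using packed bf16 adds (bf16hadd2).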
const int batch_id = blockIdx.z;
const int token_id = blockIdx.x;
const int window_id = blockIdx.y;
const int head_id = threadIdx.x / size_per_head;
const int id_in_head = threadIdx.x % size_per_head;
const int input_offset =
((batch_id * window_num + window_id) * window_len + token_id) * num_head * 3 * size_per_head;
const int target_id =
(((batch_id * window_num + window_id) * num_head + head_id) * window_len + token_id) * size_per_head
+ id_in_head;
int qkv_id = 0;
int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
__nv_bfloat162 bias = ldg(bias_qkv + bias_idx);
int input_idx = input_offset + bias_idx;
__nv_bfloat162 tmp = mm_qkv[input_idx];
tmp = bf16hadd2(tmp, bias);
q_buf_[target_id] = tmp;
qkv_id = 1;
bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
bias = ldg(bias_qkv + bias_idx);
input_idx = input_offset + bias_idx;
tmp = mm_qkv[input_idx];
tmp = bf16hadd2(tmp, bias);
k_buf_[target_id] = tmp;
qkv_id = 2;
bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
bias = ldg(bias_qkv + bias_idx);
input_idx = input_offset + bias_idx;
tmp = mm_qkv[input_idx];
tmp = bf16hadd2(tmp, bias);
v_buf_[target_id] = tmp;
}
#endif
template<typename T>
void invokeAddHead3SizeQKVBias(const T* mm_qkv,
const T* bias_qkv,
T* q_buf_,
T* k_buf_,
T* v_buf_,
const int batch,
const int window_num,
const int window_len,
const int num_head,
const int size_per_head,
cudaStream_t stream)
{
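// Dispatch: float uses the scalar kernel when num_head * size_per_head fits in one block, otherwise the
// float2 kernel with a halved block; half/bf16 always use the packed 2-wide kernel with grid.z == batch.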
if (std::is_same<T, float>::value) {
dim3 grid(window_len, window_num, 3 * batch);
dim3 block(num_head * size_per_head);
if (block.x < 1024) {
add_head3Size_QKV_bias<<<grid, block, 0, stream>>>(
mm_qkv, bias_qkv, q_buf_, k_buf_, v_buf_, batch, window_num, window_len, num_head, size_per_head);
}
else if ((block.x % 2 == 0) && (block.x / 2 < 1024)) {
block.x /= 2;
add_head3Size_QKV_bias<<<grid, block, 0, stream>>>((const float2*)mm_qkv,
(const float2*)bias_qkv,
(float2*)q_buf_,
(float2*)k_buf_,
(float2*)v_buf_,
batch,
window_num,
window_len,
num_head,
size_per_head / 2);
}
else {
printf("[ERROR][invokeAddHead3SizeQKVBias] unsupported block.x!\n");
exit(-1);
}
}
#ifdef ENABLE_BF16
else if (std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value) {
#else
else if (std::is_same<T, half>::value) {
#endif
dim3 grid(window_len, window_num, batch);
dim3 block(num_head * size_per_head / 2);
using T2 = typename TypeConverter<T>::Type; // half2 or __nv_bfloat162
if (block.x > 1024) {
printf("[ERROR][invokeAddHead3SizeQKVBias] block.x > 1024!\n");
exit(-1);
}
add_head3Size_QKV_bias<<<grid, block, 0, stream>>>((const T2*)mm_qkv,
(const T2*)bias_qkv,
(T2*)q_buf_,
(T2*)k_buf_,
(T2*)v_buf_,
batch,
window_num,
window_len,
num_head,
size_per_head / 2);
}
}
#define INSTANTIATEADDHEAD3SIZEQKVBIAS(T) \
template void invokeAddHead3SizeQKVBias<T>(const T* mm_qkv, \
const T* bias_qkv, \
T* q_buf_, \
T* k_buf_, \
T* v_buf_, \
const int batch, \
const int window_num, \
const int window_len, \
const int num_head, \
const int size_per_head, \
cudaStream_t stream)
INSTANTIATEADDHEAD3SIZEQKVBIAS(float);
INSTANTIATEADDHEAD3SIZEQKVBIAS(half);
#ifdef ENABLE_BF16
INSTANTIATEADDHEAD3SIZEQKVBIAS(__nv_bfloat16);
#endif
#undef INSTANTIATEADDHEAD3SIZEQKVBIAS
#if USING_CUDA
/******************* invokeMaskedSoftMaxWithRelPosBias ***********************/
// grid = (window_len/word_per_thread, window_num*num_head, batch_size)
// block.x = max(32, (window_len + 31)/32*32)
// qk_buf is [batch, window_num, num_head, window_len, window_len]
// attn_mask is [window_num, window_len, window_len] + row-major
// relative_pos_bias is [num_head, window_len, window_len] + row-major
template<typename T>
__global__ void softmax_withRelPosBias_element1_kernel(T* qk_buf,
const T* attn_mask,
const T* relative_pos_bias,
const int batch_size,
const int num_head,
const int window_num,
const int window_len,
const int window_len_x_window_len,
const float qk_scale)
{
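// One thread per column: scale qk, add the attention mask (if any) and the relative position bias, then
// apply a numerically stable softmax over the row using block-wide max/sum reductions.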
bool qual = threadIdx.x < window_len;
for (int window_id = blockIdx.x; window_id < window_len; window_id += gridDim.x) {
float tmp = -1e20f;
__shared__ float s_mean, s_max;
int64_t qk_offset;
if (qual) {
const int offset_in_window = window_id * window_len + threadIdx.x;
qk_offset = (blockIdx.z * gridDim.y + blockIdx.y) * static_cast<int64_t>(window_len_x_window_len)
+ offset_in_window;
const int relative_pos_bias_offset = (blockIdx.y % num_head) * window_len_x_window_len + offset_in_window;
float mask_val =
(attn_mask == nullptr) ?
0.0f :
static_cast<float>(
ldg(attn_mask + ((blockIdx.y / num_head) * window_len_x_window_len + offset_in_window)));
tmp = qk_scale * static_cast<float>(qk_buf[qk_offset]) + mask_val
+ static_cast<float>(ldg(relative_pos_bias + relative_pos_bias_offset));
}
float max_val = blockReduceMax<float>(tmp);
if (threadIdx.x == 0) {
s_max = max_val;
}
__syncthreads();
float qk_tmp = qual ? __expf(tmp - s_max) : 0.0f;
float sum_val = blockReduceSum<float>(qk_tmp);
if (threadIdx.x == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
if (qual) {
qk_buf[qk_offset] = (T)(qk_tmp * s_mean);
}
}
}
// grid = (window_len/word_per_thread, window_num*num_head, batch_size)
// block.x = max(32, (window_len/2 + 31)/32*32)
// qk_buf is [batch, window_num, num_head, window_len, window_len]
// attn_mask is [window_num, window_len, window_len] + row-major
// relative_pos_bias is [num_head, window_len, window_len] + row-major
template<typename T2, typename T>
__global__ void softmax_withRelPosBias_element2_kernel(T2* qk_buf,
const T2* attn_mask,
const T2* relative_pos_bias,
const int batch_size,
const int num_head,
const int window_num,
const int window_len,
const int window_len_x_window_len,
const float qk_scale)
{
const int window_len_2 = window_len / 2;
const int tidx = threadIdx.x;
bool qual = tidx < window_len_2;
const T2 zero = {T(0.0f), T(0.0f)};
const int bdim = blockDim.x;
for (int window_id = blockIdx.x; window_id < window_len; window_id += gridDim.x) {
float tmp = -1e20f;
__shared__ float s_mean, s_max;
int64_t qk_offset;
float2 local_qk_val;
T2 qk_val;
if (qual) {
const int offset_in_window = window_id * window_len + 2 * tidx;
qk_offset = ((blockIdx.z * gridDim.y + blockIdx.y) * static_cast<int64_t>(window_len_x_window_len)
+ offset_in_window)
/ 2;
const int relative_pos_bias_offset =
((blockIdx.y % num_head) * window_len_x_window_len + offset_in_window) / 2;
T2 mask_val =
(attn_mask == nullptr) ?
zero :
ldg(attn_mask + ((blockIdx.y / num_head) * window_len_x_window_len + offset_in_window) / 2);
qk_val = qk_buf[qk_offset];
local_qk_val.x = static_cast<float>(qk_val.x);
local_qk_val.y = static_cast<float>(qk_val.y);
const T2 bias_val = ldg(relative_pos_bias + relative_pos_bias_offset);
local_qk_val.x =
qk_scale * local_qk_val.x + static_cast<float>(mask_val.x) + static_cast<float>(bias_val.x);
local_qk_val.y =
qk_scale * local_qk_val.y + static_cast<float>(mask_val.y) + static_cast<float>(bias_val.y);
tmp = local_qk_val.x > local_qk_val.y ? local_qk_val.x : local_qk_val.y;
}
float max_val = bdim <= 32 ? warpReduceMax<float>(tmp) : blockReduceMax<float>(tmp);
if (tidx == 0) {
s_max = max_val;
}
__syncthreads();
local_qk_val.x = qual ? __expf(local_qk_val.x - s_max) : 0.0f;
local_qk_val.y = qual ? __expf(local_qk_val.y - s_max) : 0.0f;
float sum_val = bdim <= 32 ? warpReduceSum<float>(local_qk_val.x + local_qk_val.y) :
blockReduceSum<float>(local_qk_val.x + local_qk_val.y);
if (tidx == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
if (qual) {
local_qk_val.x = local_qk_val.x * s_mean;
local_qk_val.y = local_qk_val.y * s_mean;
qk_val.x = T(local_qk_val.x);
qk_val.y = T(local_qk_val.y);
qk_buf[qk_offset] = qk_val;
}
}
}
// grid = (window_len/word_per_thread, window_num*num_head, batch_size)
// block.x = max(32, (window_len/4 + 31)/32*32)
// qk_buf is [batch, window_num, num_head, window_len, window_len]
// attn_mask is [window_num, window_len, window_len] + row-major
// relative_pos_bias is [num_head, window_len, window_len] + row-major
template<typename T4, typename T>
__global__ void softmax_withRelPosBias_element4_kernel(T4* qk_buf,
const T4* attn_mask,
const T4* relative_pos_bias,
const int batch_size,
const int num_head,
const int window_num,
const int window_len,
const int window_len_x_window_len,
const float qk_scale)
{
const int window_len_4 = window_len / 4;
const int tidx = threadIdx.x;
bool qual = tidx < window_len_4;
const T4 zero = {T(0.0f), T(0.0f), T(0.0f), T(0.0f)};
const int bdim = blockDim.x;
for (int window_id = blockIdx.x; window_id < window_len; window_id += gridDim.x) {
float tmp = -1e20f;
__shared__ float s_mean, s_max;
int64_t qk_offset;
float4 local_qk_val;
T4 qk_val;
if (qual) {
const int offset_in_window = window_id * window_len + 4 * tidx;
qk_offset = ((blockIdx.z * gridDim.y + blockIdx.y) * static_cast<int64_t>(window_len_x_window_len)
+ offset_in_window)
/ 4;
const int relative_pos_bias_offset =
((blockIdx.y % num_head) * window_len_x_window_len + offset_in_window) / 4;
T4 mask_val = (attn_mask == nullptr) ?
zero :
attn_mask[((blockIdx.y / num_head) * window_len_x_window_len + offset_in_window) / 4];
qk_val = qk_buf[qk_offset];
local_qk_val.x = static_cast<float>(qk_val.x);
local_qk_val.y = static_cast<float>(qk_val.y);
local_qk_val.z = static_cast<float>(qk_val.z);
local_qk_val.w = static_cast<float>(qk_val.w);
const T4 bias_val = relative_pos_bias[relative_pos_bias_offset];
local_qk_val.x =
qk_scale * local_qk_val.x + static_cast<float>(mask_val.x) + static_cast<float>(bias_val.x);
local_qk_val.y =
qk_scale * local_qk_val.y + static_cast<float>(mask_val.y) + static_cast<float>(bias_val.y);
local_qk_val.z =
qk_scale * local_qk_val.z + static_cast<float>(mask_val.z) + static_cast<float>(bias_val.z);
local_qk_val.w =
qk_scale * local_qk_val.w + static_cast<float>(mask_val.w) + static_cast<float>(bias_val.w);
tmp = local_qk_val.x > local_qk_val.y ? local_qk_val.x : local_qk_val.y;
tmp = tmp > local_qk_val.z ? tmp : local_qk_val.z;
tmp = tmp > local_qk_val.w ? tmp : local_qk_val.w;
}
float max_val = bdim <= 32 ? warpReduceMax<float>(tmp) : blockReduceMax<float>(tmp);
if (tidx == 0) {
s_max = max_val;
}
__syncthreads();
local_qk_val.x = qual ? __expf(local_qk_val.x - s_max) : 0.0f;
local_qk_val.y = qual ? __expf(local_qk_val.y - s_max) : 0.0f;
local_qk_val.z = qual ? __expf(local_qk_val.z - s_max) : 0.0f;
local_qk_val.w = qual ? __expf(local_qk_val.w - s_max) : 0.0f;
float sum_val = bdim <= 32 ?
warpReduceSum<float>(local_qk_val.x + local_qk_val.y + local_qk_val.z + local_qk_val.w) :
blockReduceSum<float>(local_qk_val.x + local_qk_val.y + local_qk_val.z + local_qk_val.w);
if (tidx == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
if (qual) {
local_qk_val.x = local_qk_val.x * s_mean;
local_qk_val.y = local_qk_val.y * s_mean;
local_qk_val.z = local_qk_val.z * s_mean;
local_qk_val.w = local_qk_val.w * s_mean;
qk_val.x = T(local_qk_val.x);
qk_val.y = T(local_qk_val.y);
qk_val.z = T(local_qk_val.z);
qk_val.w = T(local_qk_val.w);
qk_buf[qk_offset] = qk_val;
}
}
}
template<typename T>
void invokeMaskedSoftMaxWithRelPosBias(T* qk_buf,
const T* attn_mask,
const T* relative_pos_bias,
const int batch_size,
const int num_head,
const int window_num,
const int window_len,
float qk_scale,
cudaStream_t stream)
{
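// Pick the widest vectorized kernel the row length allows: 4-wide when window_len is a multiple of 4 and
// at least 128, 2-wide when it is even, otherwise scalar; bf16 always falls back to the scalar kernel.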
const int word_per_thread = 1;
dim3 grid((window_len + word_per_thread - 1) / word_per_thread, window_num * num_head, batch_size);
if ((window_len % 4 == 0) && window_len / 4 >= 32) {
dim3 block((window_len / 4 + 31) / 32 * 32);
if (std::is_same<T, float>::value) {
softmax_withRelPosBias_element4_kernel<float4, float>
<<<grid, block, 0, stream>>>((float4*)qk_buf,
(const float4*)attn_mask,
(const float4*)relative_pos_bias,
batch_size,
num_head,
window_num,
window_len,
window_len * window_len,
qk_scale);
}
else if (std::is_same<T, half>::value) {
softmax_withRelPosBias_element4_kernel<half4, half>
<<<grid, block, 0, stream>>>((half4*)qk_buf,
(const half4*)attn_mask,
(const half4*)relative_pos_bias,
batch_size,
num_head,
window_num,
window_len,
window_len * window_len,
qk_scale);
}
#ifdef ENABLE_BF16
else {
dim3 block((window_len + 31) / 32 * 32);
softmax_withRelPosBias_element1_kernel<<<grid, block, 0, stream>>>(qk_buf,
attn_mask,
relative_pos_bias,
batch_size,
num_head,
window_num,
window_len,
window_len * window_len,
qk_scale);
}
#endif
}
else if (window_len % 2 == 0) {
dim3 block((window_len / 2 + 31) / 32 * 32);
if (std::is_same<T, float>::value) {
softmax_withRelPosBias_element2_kernel<float2, float>
<<<grid, block, 0, stream>>>((float2*)qk_buf,
(const float2*)attn_mask,
(const float2*)relative_pos_bias,
batch_size,
num_head,
window_num,
window_len,
window_len * window_len,
qk_scale);
}
else if (std::is_same<T, half>::value) {
softmax_withRelPosBias_element2_kernel<half2, half>
<<<grid, block, 0, stream>>>((half2*)qk_buf,
(const half2*)attn_mask,
(const half2*)relative_pos_bias,
batch_size,
num_head,
window_num,
window_len,
window_len * window_len,
qk_scale);
}
#ifdef ENABLE_BF16
else {
dim3 block((window_len + 31) / 32 * 32);
softmax_withRelPosBias_element1_kernel<<<grid, block, 0, stream>>>(qk_buf,
attn_mask,
relative_pos_bias,
batch_size,
num_head,
window_num,
window_len,
window_len * window_len,
qk_scale);
}
#endif
}
else {
dim3 block((window_len + 31) / 32 * 32);
softmax_withRelPosBias_element1_kernel<<<grid, block, 0, stream>>>(qk_buf,
attn_mask,
relative_pos_bias,
batch_size,
num_head,
window_num,
window_len,
window_len * window_len,
qk_scale);
}
}
#define INSTANTIATEMASKEDSOFTMAXWITHRELPOSBIAS(T) \
template void invokeMaskedSoftMaxWithRelPosBias(T* qk_buf, \
const T* attn_mask, \
const T* relative_pos_bias, \
const int batch_size, \
const int num_head, \
const int window_num, \
const int window_len, \
const float qk_scale, \
cudaStream_t stream)
INSTANTIATEMASKEDSOFTMAXWITHRELPOSBIAS(float);
INSTANTIATEMASKEDSOFTMAXWITHRELPOSBIAS(half);
#ifdef ENABLE_BF16
INSTANTIATEMASKEDSOFTMAXWITHRELPOSBIAS(__nv_bfloat16);
#endif
#undef INSTANTIATEMASKEDSOFTMAXWITHRELPOSBIAS
#endif
} // namespace rtp_llm