// maga_transformer/cpp/kernels/gpt_kernels.cu
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <assert.h>
#include <type_traits>
#include "maga_transformer/cpp/cuda/cuda_type_utils.cuh"
#include "maga_transformer/cpp/cuda/cuda_fp8_utils.h"
#if USING_CUDA
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif
#endif
#include "maga_transformer/cpp/kernels/gpt_kernels.h"
#include "maga_transformer/cpp/cuda/memory_utils.h"
namespace rtp_llm {
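// Illustrative note on the lookup below (tensor shapes are assumptions inferred from the
// indexing, not documented elsewhere in this file): input_ids / input_pos / input_type /
// input_mask are [token_num], embedding_table is [vocab, hidden_units], pos_table and
// type_table are [*, hidden_units], and from_tensor is [token_num, hidden_units].
// Each iteration of the grid-stride loop produces one element:
//   from_tensor[t, h] = scale * embedding_table[input_ids[t], h]
//                       (+ pos_table[input_pos[t], h]) (+ type_table[input_type[t], h])
// and, when USE_MASK is set and input_mask[t] == 0, the word-embedding term is dropped.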
template<typename T, bool USE_POS_EMB, bool USE_TYPE_ID_EMB, bool USE_MASK>
__global__ void embedding_lookup_kernel(T* from_tensor,
const T* embedding_table,
double input_embedding_scalar,
const T* pos_table,
const T* type_table,
const int* input_ids,
const int* input_pos,
const int* input_type,
const int* input_mask,
const int token_num,
const int64_t hidden_units)
{
for (int64_t index = blockIdx.x * blockDim.x + threadIdx.x; index < (int64_t)(token_num * hidden_units);
index += blockDim.x * gridDim.x) {
const int64_t token_index = index / hidden_units;
const int64_t col_index = index % hidden_units;
const int input_id = input_ids[token_index];
T embedding = (T)0.0f;
T pos_embed = (T)0.0f;
T type_embed = (T)0.0f;
if constexpr(USE_POS_EMB) {
assert(pos_table != nullptr);
pos_embed = pos_table[input_pos[token_index] * hidden_units + col_index];
}
if constexpr(USE_TYPE_ID_EMB) {
assert(type_table != nullptr);
type_embed = type_table[input_type[token_index] * hidden_units + col_index];
}
if constexpr(USE_MASK) {
assert(input_mask != nullptr);
if (input_mask[token_index] == 0) {
from_tensor[index] = pos_embed + type_embed;
continue;
}
}
embedding = embedding_table[input_id * hidden_units + col_index];
// Scale the word embedding by input_embedding_scalar; half/bf16 need an explicit conversion of the double scalar.
if constexpr (std::is_same<T, __nv_bfloat16>::value) {
embedding *= __double2bfloat16(input_embedding_scalar);
} else if constexpr (std::is_same<T, __half>::value){
embedding *= static_cast<T>(input_embedding_scalar);
} else {
embedding *= input_embedding_scalar;
}
from_tensor[index] = embedding + pos_embed + type_embed;
}
}
#define INVOKE_WORD_EMBED_LOOKUP(USE_POS, USE_TYPE, USE_MASK) \
embedding_lookup_kernel<T, USE_POS, USE_TYPE, USE_MASK><<<grid, block, 0, stream>>>(from_tensor, \
embedding_table, \
input_embedding_scalar, \
pos_table, \
type_table, \
input_ids, \
input_pos, \
input_type, \
input_mask, \
token_num, \
hidden_units);
template<typename T>
void invokeEmebeddingLookup(T* from_tensor,
const T* embedding_table,
double input_embedding_scalar,
const T* pos_table,
const T* type_table,
const int* input_ids,
const int* input_pos,
const int* input_type,
const int* input_mask,
const int token_num,
const int hidden_units,
cudaStream_t stream)
{
dim3 grid(std::min(token_num, 65536));
dim3 block(std::min(hidden_units, 1024));
if (!pos_table) {
if (!type_table) {
if (!input_mask) {
INVOKE_WORD_EMBED_LOOKUP(false, false, false);
} else {
INVOKE_WORD_EMBED_LOOKUP(false, false, true);
}
} else {
if (!input_mask) {
INVOKE_WORD_EMBED_LOOKUP(false, true, false);
} else {
INVOKE_WORD_EMBED_LOOKUP(false, true, true);
}
}
} else {
if (!type_table) {
if (!input_mask) {
INVOKE_WORD_EMBED_LOOKUP(true, false, false);
} else {
INVOKE_WORD_EMBED_LOOKUP(true, false, true);
}
} else {
if (!input_mask) {
INVOKE_WORD_EMBED_LOOKUP(true, true, false);
} else {
INVOKE_WORD_EMBED_LOOKUP(true, true, true);
}
}
}
}
#undef INVOKE_WORD_EMBED_LOOKUP
// PROMPT_SRC: 0 --> no prompts, 1 --> from loaded prompts, 2 --> from request prompts
template<typename T, bool OUTPUT_ID, int PROMPT_SRC>
__global__ void start_id_embedding_position_lookups_kernel(T* from_tensor,
int* output_ids,
const T* embedding_table,
const T* pos_table,
pPromptTuningParam<T> prompt_param,
const int* input_ids,
const int start_step,
const int length,
const int max_length,
const int batch_size,
const int64_t hidden_units)
{
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * length * hidden_units;
index += blockDim.x * gridDim.x) {
// transpose the input_ids [batch, length] (part of [batch, max_length]) to output_ids [length, batch]
if (OUTPUT_ID && index < batch_size * max_length) {
// for p/prompt-tuning, inputs may interleave prompts (e.g. [input1, prompt1, input2, prompt2]);
// we reorder them to [input1, input2, prompt1, prompt2] and strip the prompts during
// post-processing
if (PROMPT_SRC > 0) {
if (index < batch_size) {
int no_prompt_output_seq_id = 0;
#pragma unroll 1
for (int seq_id = 0; seq_id < max_length; seq_id++) {
int current_input_id = input_ids[index * max_length + seq_id];
if (current_input_id < prompt_param.p_prompt_tuning_id_start) {
output_ids[no_prompt_output_seq_id * batch_size + index] = current_input_id;
no_prompt_output_seq_id++;
}
}
}
}
else {
const int seq_id = index % max_length;
const int batch_id = index / max_length;
if (seq_id < length) {
output_ids[seq_id * batch_size + batch_id] = input_ids[index];
}
}
}
// embedding lookup from word ids [batch, length] (part of [batch, max_length]) and [vocab, hidden] to generate
// embedding [batch, length, hidden]
const int word_index = index / hidden_units;
const int word_index_row = word_index / length; // batch_id
const int word_index_col = word_index % length;
const int real_word_index = word_index_row * max_length + word_index_col;
const int step = start_step + word_index % length;
const int col_index = index % hidden_units;
const int input_id = input_ids == nullptr ? real_word_index : input_ids[real_word_index];
const int prompt_id = input_id - prompt_param.p_prompt_tuning_id_start;
T embedding = (T)0.0f;
if (PROMPT_SRC > 0 && prompt_id >= 0) {
if (PROMPT_SRC == 1) {
// from loaded prompt embedding tables
embedding =
prompt_param.p_prompt_tuning_batch_weights[word_index_row][prompt_id * hidden_units + col_index];
}
else {
// from request prompt embedding
embedding =
prompt_param
.request_prompt_embedding[word_index_row * prompt_param.request_prompt_max_length * hidden_units
+ prompt_id * hidden_units + col_index];
}
}
else {
embedding = embedding_table[input_id * hidden_units + col_index];
}
T pos_embed = pos_table == nullptr ? (T)0.f : pos_table[(step - 1) * hidden_units + col_index];
from_tensor[index] = embedding + pos_embed;
}
}
#define WORD_POS_EMBEDDING_LOOKUP_KERNEL(OUTPUT_ID, PROMPT_SRC) \
start_id_embedding_position_lookups_kernel<T, OUTPUT_ID, PROMPT_SRC><<<grid, block, 0, stream>>>(from_tensor, \
output_ids, \
embedding_table, \
pos_table, \
prompt_param, \
input_ids, \
start_step, \
length, \
max_length, \
batch_size, \
hidden_units);
template<typename T>
void invokeInputIdsEmbeddingLookupPosEncoding(T* from_tensor,
int* output_ids,
const T* embedding_table, // can also be inputs_embeds
const T* pos_table,
pPromptTuningParam<T> prompt_param,
const int* input_ids,
const int start_step,
const int length,
const int max_length,
const int batch_size,
const int hidden_units,
cudaStream_t stream)
{
dim3 grid(min(batch_size * length, 65536));
dim3 block(min(hidden_units, 512));
const bool has_output_ids = output_ids != nullptr;
RTP_LLM_CHECK(!(has_output_ids && input_ids == nullptr));
if (has_output_ids) {
if (prompt_param.use_request_p_prompt_embedding) {
WORD_POS_EMBEDDING_LOOKUP_KERNEL(true, 2);
}
else if (prompt_param.p_prompt_tuning_batch_weights != nullptr) {
WORD_POS_EMBEDDING_LOOKUP_KERNEL(true, 1);
}
else {
WORD_POS_EMBEDDING_LOOKUP_KERNEL(true, 0);
}
}
else {
if (prompt_param.use_request_p_prompt_embedding) {
WORD_POS_EMBEDDING_LOOKUP_KERNEL(false, 2);
}
else if (prompt_param.p_prompt_tuning_batch_weights != nullptr) {
WORD_POS_EMBEDDING_LOOKUP_KERNEL(false, 1);
}
else {
WORD_POS_EMBEDDING_LOOKUP_KERNEL(false, 0);
}
}
}
template void invokeInputIdsEmbeddingLookupPosEncoding(float* from_tensor,
int* output_ids,
const float* embedding_table,
const float* pos_table,
pPromptTuningParam<float> prompt_param,
const int* input_ids,
const int start_step,
const int length,
const int max_length,
const int batch_size,
const int hidden_units,
cudaStream_t stream);
template void invokeInputIdsEmbeddingLookupPosEncoding(half* from_tensor,
int* output_ids,
const half* embedding_table,
const half* pos_table,
pPromptTuningParam<half> prompt_param,
const int* input_ids,
const int start_step,
const int length,
const int max_length,
const int batch_size,
const int hidden_units,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeInputIdsEmbeddingLookupPosEncoding(__nv_bfloat16* from_tensor,
int* output_ids,
const __nv_bfloat16* embedding_table,
const __nv_bfloat16* pos_table,
pPromptTuningParam<__nv_bfloat16> prompt_param,
const int* input_ids,
const int start_step,
const int length,
const int max_length,
const int batch_size,
const int hidden_units,
cudaStream_t stream);
#endif
#define INSTANTIATE_INVOKE_EMBEDDING_LOOKUP(T) \
template void invokeEmebeddingLookup( \
T* from_tensor, \
const T* embedding_table, \
double input_embedding_scalar, \
const T* pos_table, \
const T* type_table, \
const int* input_ids, \
const int* input_pos, \
const int* input_type, \
const int* input_mask, \
const int token_num, \
const int hidden_units, \
cudaStream_t stream)
INSTANTIATE_INVOKE_EMBEDDING_LOOKUP(float);
INSTANTIATE_INVOKE_EMBEDDING_LOOKUP(half);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_EMBEDDING_LOOKUP(__nv_bfloat16);
#endif
template<typename T>
__global__ void inputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLookupPosEncodingSoftPromptParam<T> param)
{
// 1. Copy the input ids to output ids and transpose output ids to [seq_len, batch_size, beam_width].
// 2. Embedding lookup by input ids and concat with soft prompt. The axis of concatenation is on axis of seq_len.
// Assume batch size is 2 and prompts are [[t1, t2], [t3], [t4, t5]], input_ids are [[s1, s2], [s3], [s4]]
// then the order of output_ids is
// [ [?, ?, s1, s2]
// [?, s3, padding, padding]
// [?, ?, s4, padding] ]
// and the order of embedding is
// [ [t1, t2, s1, s2]
// [t3, s3, padding, padding]
// [t4, t5, s4, padding] ]
// where "?" means undefined values and we should attach it.
for (int index = blockIdx.x * blockDim.x + threadIdx.x;
index < param.batch_size * param.beam_width * (param.max_prefix_soft_prompt_length + param.max_input_length)
* param.hidden_units;
index += blockDim.x * gridDim.x) {
// transpose the input_ids [batch, length] (part of [batch, beam, max_input_length]) to
// output_ids [length, batch, beam].
// output_ids needs padding at the beginning to make room for the soft prompt.
if (index < param.batch_size * param.beam_width * param.max_input_length) {
int tmp_index = index;
const int seq_id = tmp_index % param.max_input_length;
tmp_index = (tmp_index - seq_id) / param.max_input_length;
const int beam_id = tmp_index % param.beam_width;
tmp_index = (tmp_index - beam_id) / param.beam_width;
const int batch_id = tmp_index % param.batch_size;
if (seq_id < param.max_input_length) {
param.output_ids[(param.prefix_soft_prompt_lengths[batch_id] + seq_id) * param.batch_size
* param.beam_width
+ batch_id * param.beam_width + beam_id] = param.input_ids[index];
}
}
// embedding lookup from word ids [batch, beam, length] (part of [batch, beam, max_input_length]), [vocab,
// hidden] and [batch, max_prefix_soft_prompt_length, hidden] to generate embedding [batch, beam, length +
// max_prefix_soft_prompt_length, hidden]
int tmp_index = index;
const int hidden_id = tmp_index % param.hidden_units;
tmp_index = (tmp_index - hidden_id) / param.hidden_units;
const int seq_id = tmp_index % (param.max_prefix_soft_prompt_length + param.max_input_length);
tmp_index = (tmp_index - seq_id) / (param.max_prefix_soft_prompt_length + param.max_input_length);
const int beam_id = tmp_index % param.beam_width;
tmp_index = (tmp_index - beam_id) / param.beam_width;
const int batch_id = tmp_index % param.batch_size;
const int64_t hidden_units = param.hidden_units;
T embedding =
(seq_id < param.prefix_soft_prompt_lengths[batch_id]) ?
(T)param.prefix_soft_prompt_embedding[batch_id * param.max_prefix_soft_prompt_length * hidden_units
+ seq_id * hidden_units + hidden_id] :
param.embedding_table[param.input_ids[batch_id * param.beam_width * param.max_input_length
+ beam_id * param.max_input_length
+ (seq_id - param.prefix_soft_prompt_lengths[batch_id])]
* hidden_units
+ hidden_id];
T pos_embed = param.pos_table == nullptr ?
(T)0.0f :
param.pos_table[(param.start_step + seq_id - 1) * hidden_units + hidden_id];
param.from_tensor[index] = embedding + pos_embed;
if (seq_id == 0 && hidden_id == 0) {
param.input_lengths[batch_id * param.beam_width + beam_id] += param.prefix_soft_prompt_lengths[batch_id];
}
}
}
template<typename T>
void invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLookupPosEncodingSoftPromptParam<T> param)
{
dim3 grid(min(param.batch_size * param.beam_width * (param.max_input_length + param.max_prefix_soft_prompt_length),
65536));
dim3 block(min(param.hidden_units, 512));
inputIdsEmbeddingLookupPosEncodingSoftPrompt<T><<<grid, block, 0, param.stream>>>(param);
}
template void
invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLookupPosEncodingSoftPromptParam<float> param);
template void
invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLookupPosEncodingSoftPromptParam<half> param);
#ifdef ENABLE_BF16
template void invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(
inputIdsEmbeddingLookupPosEncodingSoftPromptParam<__nv_bfloat16> param);
#endif
// TODO Add half2 implementation
template<typename T>
__global__ void transposeAxis01(T* out, T* in, const int dim0, const int dim1, const int dim2)
{
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < dim0 * dim1 * dim2) {
const int input_dim2_index = index % dim2;
index = (index - input_dim2_index) / dim2;
const int input_dim1_index = index % dim1;
index = (index - input_dim1_index) / dim1;
const int input_dim0_index = index % dim0;
out[input_dim1_index * dim0 * dim2 + input_dim0_index * dim2 + input_dim2_index] =
in[input_dim0_index * dim1 * dim2 + input_dim1_index * dim2 + input_dim2_index];
}
}
template<typename T>
void invokeTransposeAxis012(T* out, T* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream)
{
dim3 block(512);
dim3 grid((int)(ceil(dim0 * dim1 * dim2 / 512.)));
transposeAxis01<<<grid, block, 0, stream>>>(out, in, dim0, dim1, dim2);
}
template<typename T>
__global__ void transposeAxis01(T* out, T* in, const int* in_skipping_dim1, const int dim0, const int dim1)
{
// out: [dim1, dim0]
// in: [dim0, dim1]
// in_skipping_dim1: [dim1]
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < dim0 * dim1) {
const int input_dim1_index = index % dim1;
index = (index - input_dim1_index) / dim1;
const int input_dim0_index = index % dim0;
const int in_offset = in_skipping_dim1 == nullptr ? 0 : in_skipping_dim1[input_dim1_index] * dim1;
out[input_dim1_index * dim0 + input_dim0_index] = in[in_offset + input_dim0_index * dim1 + input_dim1_index];
}
}
template<typename T>
void invokeTransposeAxis01(
T* out, T* in, const int dim0, const int dim1, cudaStream_t stream)
{
dim3 block(512);
dim3 grid((int)(ceil(dim0 * dim1 / 512.)));
transposeAxis01<<<grid, block, 0, stream>>>(out, in, nullptr, dim0, dim1);
}
#define DEFINE_INVOKETRANSPOSE(T) \
template void invokeTransposeAxis01(T* out, T* in, const int dim0, const int dim1, cudaStream_t stream); \
template void invokeTransposeAxis012( \
T* out, T* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream)
DEFINE_INVOKETRANSPOSE(int32_t);
DEFINE_INVOKETRANSPOSE(int8_t);
DEFINE_INVOKETRANSPOSE(uint8_t);
DEFINE_INVOKETRANSPOSE(uint32_t);
DEFINE_INVOKETRANSPOSE(int64_t);
DEFINE_INVOKETRANSPOSE(uint64_t);
DEFINE_INVOKETRANSPOSE(float);
DEFINE_INVOKETRANSPOSE(half);
#ifdef ENABLE_BF16
DEFINE_INVOKETRANSPOSE(__nv_bfloat16);
#endif
#ifdef ENABLE_FP8
DEFINE_INVOKETRANSPOSE(__nv_fp8_e4m3);
#endif
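// Descriptive note: transposeAxis12 swaps axes 1 and 2 of a 4-D tensor, i.e. an input of
// shape [dim0, dim1, dim2, dim3] becomes an output of shape [dim0, dim2, dim1, dim3].
// For example, with (dim0, dim1, dim2, dim3) = (1, 2, 3, 4), element in[0][1][0][2]
// (linear index 1*12 + 0*4 + 2 = 14) is written to out[0][0][1][2]
// (linear index 0*8 + 1*4 + 2 = 6).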
template<typename T>
__global__ void transposeAxis12(T* out, T* in, const int dim0, const int dim1, const int dim2, const int dim3)
{
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < dim0 * dim1 * dim2 * dim3) {
const int input_dim3_index = index % dim3;
index = (index - input_dim3_index) / dim3;
const int input_dim2_index = index % dim2;
index = (index - input_dim2_index) / dim2;
const int input_dim1_index = index % dim1;
index = (index - input_dim1_index) / dim1;
const int input_dim0_index = index % dim0;
out[input_dim0_index * dim1 * dim2 * dim3 + input_dim2_index * dim1 * dim3 + input_dim1_index * dim3 + input_dim3_index] =
in[input_dim0_index * dim1 * dim2 * dim3 + input_dim1_index * dim2 * dim3 + input_dim2_index * dim3 + input_dim3_index];
}
}
template<typename T>
void invokeTransposeAxis12(T* out, T* in, const int dim0, const int dim1, const int dim2, const int dim_3, cudaStream_t stream)
{
dim3 block(512);
dim3 grid((int)(ceil(dim0 * dim1 * dim2 * dim_3 / 512.)));
transposeAxis12<<<grid, block, 0, stream>>>(out, in, dim0, dim1, dim2, dim_3);
}
template void
invokeTransposeAxis12(float* out, float* in, const int dim0, const int dim1, const int dim2, const int dim_3, cudaStream_t stream);
template void
invokeTransposeAxis12(half* out, half* in, const int dim0, const int dim1, const int dim2, const int dim_3, cudaStream_t stream);
template void
invokeTransposeAxis12(int* out, int* in, const int dim0, const int dim1, const int dim2, const int dim_3, cudaStream_t stream);
#ifdef ENABLE_BF16
template void
invokeTransposeAxis12(__nv_bfloat16* out, __nv_bfloat16* in, const int dim0, const int dim1, const int dim2, const int dim_3, cudaStream_t stream);
#endif
template<typename T, bool PREFIX_PROMPT, bool IS_CAUSAL>
__global__ void buildDecoderAttentionMaskKernel(T* attention_mask,
const int* sequence_lengths,
const int* prefix_prompt_lengths,
const int max_seq_len,
const int max_prompt_length)
{
// sequence_lengths: [batch_size]
// attention_mask: [batch_size, 1, max_seq_len, max_seq_len + max_prompt_length]
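// Worked example (illustrative values): with max_seq_len = 4, max_prompt_length = 2,
// seq_length = 3 and prompt_length = 2, the causal mask rows for this batch entry are
//   row 0: 1 1 1 0 0 0   (the 2 prompt tokens plus itself)
//   row 1: 1 1 1 1 0 0
//   row 2: 1 1 1 1 1 0
//   row 3: 0 0 0 0 0 0   (row_id >= seq_length, fully masked)
// In the non-causal case, every row with row_id < seq_length is 1 for col_id <= seq_length - 1.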
const int max_prompt_seq_length = max_seq_len + max_prompt_length;
const int mask_size_per_seq = max_seq_len * max_prompt_seq_length;
attention_mask += blockIdx.x * mask_size_per_seq;
const int seq_length = sequence_lengths[blockIdx.x];
const int prompt_length = PREFIX_PROMPT ? prefix_prompt_lengths[blockIdx.x] : 0;
for (int i = threadIdx.x; i < mask_size_per_seq; i += blockDim.x) {
int row_id = i / max_prompt_seq_length;
int col_id = i % max_prompt_seq_length;
int column_bound = IS_CAUSAL ? row_id + prompt_length: seq_length - 1;
if (row_id < seq_length && col_id <= (column_bound)) {
attention_mask[i] = (T)(1.0f);
}
else {
attention_mask[i] = (T)(0.0f);
}
}
}
template<typename T>
void invokeBuildDecoderAttentionMask(T* attention_mask,
const int* sequence_lengths,
const int* prefix_prompt_lengths,
const int batch_size,
const int max_seq_len,
const int max_prompt_length,
const bool is_causal,
cudaStream_t stream)
{
#define RUN_KERNEL(has_prefix, is_causal) \
buildDecoderAttentionMaskKernel<T, has_prefix, is_causal><<<batch_size, 256, 0, stream>>>( \
attention_mask, sequence_lengths, prefix_prompt_lengths, max_seq_len, max_prompt_length)
if (max_prompt_length == 0) {
if (is_causal) {
RUN_KERNEL(false, true);
} else {
RUN_KERNEL(false, false);
}
}
else {
if (is_causal) {
RUN_KERNEL(true, true);
} else{
RUN_KERNEL(true, false);
}
}
}
template void invokeBuildDecoderAttentionMask(float* attention_mask,
const int* sequence_lengths,
const int* prefix_prompt_lengths,
const int batch_size,
const int max_seq_len,
const int max_prompt_length,
const bool is_causal,
cudaStream_t stream);
template void invokeBuildDecoderAttentionMask(half* attention_mask,
const int* sequence_lengths,
const int* prefix_prompt_lengths,
const int batch_size,
const int max_seq_len,
const int max_prompt_length,
const bool is_causal,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeBuildDecoderAttentionMask(__nv_bfloat16* attention_mask,
const int* sequence_lengths,
const int* prefix_prompt_lengths,
const int batch_size,
const int max_seq_len,
const int max_prompt_length,
const bool is_causal,
cudaStream_t stream);
#endif
#ifdef ENABLE_FP8
template void invokeBuildDecoderAttentionMask(__nv_fp8_e4m3* attention_mask,
const int* sequence_lengths,
const int* prefix_prompt_lengths,
const int batch_size,
const int max_seq_len,
const int max_prompt_length,
const bool is_causal,
cudaStream_t stream);
#endif
// The attention_mask is only used in the context (encoding) phase, so rows with row_id >= seq_length can be ignored.
template<typename T>
__global__ void buildGlmDecoderAttentionMaskKernel(T* attention_mask, const int* sequence_lengths, const int max_seq_len)
{
// sequence_lengths: [batch_size]
// attention_mask: [batch_size, 1, max_seq_len, max_seq_len]
attention_mask += blockIdx.x * max_seq_len * max_seq_len;
const int seq_length = sequence_lengths[blockIdx.x];
for (int i = threadIdx.x; i < max_seq_len * max_seq_len; i += blockDim.x) {
int row_id = i / max_seq_len;
int col_id = i % max_seq_len;
if (row_id < seq_length && col_id <= row_id) {
attention_mask[i] = (T)(1.0f);
}
else if (col_id < seq_length - 1) {
attention_mask[i] = (T)(1.0f);
}
else {
attention_mask[i] = (T)(0.0f);
}
}
}
template<typename T>
void invokeBuildGlmDecoderAttentionMask(
T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream)
{
buildGlmDecoderAttentionMaskKernel<<<batch_size, 256, 0, stream>>>(attention_mask, sequence_lengths, max_seq_len);
}
template void invokeBuildGlmDecoderAttentionMask(float* attention_mask,
const int* sequence_lengths,
const int batch_size,
const int max_seq_len,
cudaStream_t stream);
template void invokeBuildGlmDecoderAttentionMask(half* attention_mask,
const int* sequence_lengths,
const int batch_size,
const int max_seq_len,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeBuildGlmDecoderAttentionMask(__nv_bfloat16* attention_mask,
const int* sequence_lengths,
const int batch_size,
const int max_seq_len,
cudaStream_t stream);
#endif
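// Descriptive note (assumption: input_lengths holds cumulative, cu_seqlens-style token
// counts into the packed hidden_state buffer): lookupHiddenStateOfLastToken gathers, for
// every batch entry b, the hidden_state row at input_lengths[b] + idx_offset (e.g. with
// idx_offset = -1 this would be the last token of sequence b), while
// lookupHiddenStateOfFirstToken takes input_lengths[b - 1] (or 0 for b == 0) as the start
// of sequence b and copies its first row.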
template<typename T>
__launch_bounds__(1024, 1) __global__ void lookupHiddenStateOfLastToken(T* from_tensor,
const T* hidden_state,
const int* input_lengths,
const int batch_size,
const int hidden_units,
const int idx_offset)
{
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * hidden_units;
index += blockDim.x * gridDim.x) {
const int col_index = index % hidden_units;
const int batch_id = index / hidden_units;
from_tensor[index] = hidden_state[(input_lengths[batch_id] + idx_offset) * hidden_units + col_index];
}
}
template<typename T>
__launch_bounds__(1024, 1) __global__ void lookupHiddenStateOfFirstToken(T* from_tensor,
const T* hidden_state,
const int* input_lengths,
const int batch_size,
const int hidden_units)
{
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * hidden_units;
index += blockDim.x * gridDim.x) {
const int col_index = index % hidden_units;
const int batch_id = index / hidden_units;
const int base_index = batch_id == 0 ? 0 : input_lengths[batch_id - 1] * hidden_units;
from_tensor[index] = hidden_state[base_index + col_index];
}
}
template<typename T>
void invokeLookupHiddenStateOfLastToken(T* from_tensor,
const T* hidden_state,
const int* input_lengths,
const int batch_size,
const int hidden_units,
const int idx_offset,
cudaStream_t stream)
{
const int grid_size = (int)(ceil(batch_size * hidden_units / 1024.));
dim3 grid(min(grid_size, 65536));
dim3 block(min(hidden_units, 1024));
lookupHiddenStateOfLastToken<T><<<grid, block, 0, stream>>>(
from_tensor, hidden_state, input_lengths, batch_size, hidden_units, idx_offset);
}
template<typename T>
void invokeLookupHiddenStateOfFirstToken(T* from_tensor,
const T* hidden_state,
const int* input_lengths,
const int batch_size,
const int hidden_units,
cudaStream_t stream)
{
const int grid_size = (int)(ceil(batch_size * hidden_units / 1024.));
dim3 grid(min(grid_size, 65536));
dim3 block(min(hidden_units, 1024));
lookupHiddenStateOfFirstToken<T><<<grid, block, 0, stream>>>(
from_tensor, hidden_state, input_lengths, batch_size, hidden_units);
}
#define INSTANTIATE_INVOKE_LOOKUP_HIDDEN_OF_LAST(T) \
template void invokeLookupHiddenStateOfLastToken(T* from_tensor, \
const T* hidden_state, \
const int* input_lengths, \
const int batch_size, \
const int hidden_units, \
const int idx_offset, \
cudaStream_t stream)
INSTANTIATE_INVOKE_LOOKUP_HIDDEN_OF_LAST(float);
INSTANTIATE_INVOKE_LOOKUP_HIDDEN_OF_LAST(half);
INSTANTIATE_INVOKE_LOOKUP_HIDDEN_OF_LAST(int32_t);
INSTANTIATE_INVOKE_LOOKUP_HIDDEN_OF_LAST(int8_t);
INSTANTIATE_INVOKE_LOOKUP_HIDDEN_OF_LAST(uint8_t);
INSTANTIATE_INVOKE_LOOKUP_HIDDEN_OF_LAST(uint32_t);
INSTANTIATE_INVOKE_LOOKUP_HIDDEN_OF_LAST(int64_t);
INSTANTIATE_INVOKE_LOOKUP_HIDDEN_OF_LAST(uint64_t);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_LOOKUP_HIDDEN_OF_LAST(__nv_bfloat16);
#endif
#ifdef ENABLE_FP8
INSTANTIATE_INVOKE_LOOKUP_HIDDEN_OF_LAST(__nv_fp8_e4m3);
#endif
#define INSTANTIATE_INVOKE_LOOKUP_HIDDEN_OF_FIRST(T) \
template void invokeLookupHiddenStateOfFirstToken(T* from_tensor, \
const T* hidden_state, \
const int* input_lengths, \
const int batch_size, \
const int hidden_units, \
cudaStream_t stream)
INSTANTIATE_INVOKE_LOOKUP_HIDDEN_OF_FIRST(float);
INSTANTIATE_INVOKE_LOOKUP_HIDDEN_OF_FIRST(half);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_LOOKUP_HIDDEN_OF_FIRST(__nv_bfloat16);
#endif
template<bool PREFIX_PROMPT>
__global__ void tileGptPromptInputs(int* tiled_input_ids,
int* tiled_input_lengths,
int* tiled_prompt_lengths,
const int* input_ids,
const int* input_lengths,
const int* prefix_prompt_lengths,
const int max_input_length)
{
if (threadIdx.x == 0) {
tiled_input_lengths[blockIdx.x * gridDim.y + blockIdx.y] = input_lengths[blockIdx.x];
if (PREFIX_PROMPT) {
tiled_prompt_lengths[blockIdx.x * gridDim.y + blockIdx.y] = prefix_prompt_lengths[blockIdx.x];
}
}
for (int index = threadIdx.x; index < max_input_length; index += blockDim.x) {
tiled_input_ids[(blockIdx.x * gridDim.y + blockIdx.y) * max_input_length + index] =
input_ids[blockIdx.x * max_input_length + index];
}
}
void invokeTileGptPromptInputs(int* tiled_input_ids,
int* tiled_input_lengths,
int* tiled_prompt_lengths,
const int* input_ids,
const int* input_lengths,
const int* prefix_prompt_lengths,
const int batch_size,
const int beam_width,
const int max_input_length,
cudaStream_t stream)
{
dim3 grid(batch_size, beam_width);
dim3 block(min(1024, max_input_length));
if (prefix_prompt_lengths != nullptr) {
tileGptPromptInputs<true><<<grid, block, 0, stream>>>(tiled_input_ids,
tiled_input_lengths,
tiled_prompt_lengths,
input_ids,
input_lengths,
prefix_prompt_lengths,
max_input_length);
}
else {
tileGptPromptInputs<false><<<grid, block, 0, stream>>>(tiled_input_ids,
tiled_input_lengths,
tiled_prompt_lengths,
input_ids,
input_lengths,
prefix_prompt_lengths,
max_input_length);
}
}
void invokeTileGptInputs(int* tiled_input_ids,
int* tiled_input_lengths,
const int* input_ids,
const int* input_lengths,
const int batch_size,
const int beam_width,
const int max_input_length,
cudaStream_t stream)
{
invokeTileGptPromptInputs(tiled_input_ids,
tiled_input_lengths,
nullptr,
input_ids,
input_lengths,
nullptr,
batch_size,
beam_width,
max_input_length,
stream);
}
#if USING_CUDA
template<int TB_SIZE>
__global__ void
find_context_dups(int* shared_contexts, const int* input_ids, const size_t batch_size, const size_t input_seq_len)
{
/* We compare all context pairs (i, j), with i (tgt) < j (src), to detect duplicate
* inputs. If contexts i and j match, we store i at the j-th position of
* shared_contexts, so that we know j can be represented by i. shared_contexts is
* initialized as shared_contexts[i] = i, and on a match we apply
* shared_contexts[j] = min(shared_contexts[j], i), so that in the end shared_contexts
* holds, for each context, the index of its matching context with the lowest index.
* Note that shared_contexts[i] <= i, a property that is used when uncompacting
* inputs.
*/
typedef cub::BlockReduce<int, TB_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ bool match;
/* Each block is responsible for a (i, j) pair. To map the block space to
* the i < j space, we need to convert a linear addressing to a triangle, of
* size (batch_size * (batch_size - 1)) / 2
* For more information, check https://en.wikipedia.org/wiki/Triangular_number
*/
// blockIdx = [0, 1, 2, ... n(n-1)/2] -> base_index = [0, 1, 1, 2, 2, 2, 3, 3, 3, 3, ..., n - 2]
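// Worked example: for batch_size = 4 there are 4*3/2 = 6 blocks, and
//   blockIdx.x : 0      1      2      3      4      5
//   (tgt, src) : (0,1)  (0,2)  (1,2)  (0,3)  (1,3)  (2,3)
// i.e. base_index = floor((sqrt(1 + 8*blockIdx.x) - 1) / 2), src_idx = base_index + 1,
// tgt_idx = blockIdx.x - base_index*(base_index + 1)/2.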
const int base_index = floorf(0.5f * (sqrtf(1 + 8 * blockIdx.x) - 1));
const int src_idx = base_index + 1; // src_idx \in [1, batch_size)
const int rev_base_index = base_index * (base_index + 1) / 2;
const int tgt_idx = blockIdx.x - rev_base_index; // tgt_idx \in [0, src_idx)
const int padded_length = TB_SIZE * ((input_seq_len + TB_SIZE - 1) / TB_SIZE);
int sum = 0;
for (int i = threadIdx.x; i < padded_length; i += TB_SIZE) {
int compare =
(i >= input_seq_len) ? 1 : input_ids[tgt_idx * input_seq_len + i] == input_ids[src_idx * input_seq_len + i];
sum = BlockReduce(temp_storage).Sum(compare);
if (threadIdx.x == 0) {
match = (sum == TB_SIZE);
}
__syncthreads();
if (!match) {
break;
}
}
if (threadIdx.x == 0 && match) {
atomicMin(&shared_contexts[src_idx], tgt_idx);
}
}
constexpr int DUPS_INDICES_BLOCK_SIZE = 128;
__global__ void generate_dups_indices(int* batch_to_compact,
int* compact_to_batch,
int* compact_size,
const int* shared_contexts,
const size_t batch_size,
const size_t input_seq_len)
{
const int padded_batchsize = blockDim.x * ((batch_size + blockDim.x - 1) / blockDim.x);
typedef cub::BlockScan<int, DUPS_INDICES_BLOCK_SIZE, cub::BLOCK_SCAN_WARP_SCANS> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
__shared__ int scan_offset;
int scan = 0;
for (int batch = threadIdx.x; batch < padded_batchsize; batch += blockDim.x) {
bool masked = (batch >= batch_size);
bool first_iter = batch < blockDim.x;
int is_first_occur = masked ? 0 : shared_contexts[batch] == batch;
BlockScan(temp_storage).ExclusiveSum(is_first_occur, scan);
if (!masked && is_first_occur) {
int compact_idx = scan + (first_iter ? 0 : scan_offset);
// The representative (first occurrence) of each context writes its compact index
batch_to_compact[batch] = compact_idx;
compact_to_batch[compact_idx] = batch;
}
if (threadIdx.x == blockDim.x - 1) {
scan_offset = scan + is_first_occur + (first_iter ? 0 : scan_offset);
}
__syncthreads();
if (!masked && !is_first_occur) {
// Fill the rest of batch_to_compact from what the representative wrote
const int src_idx = batch_to_compact[shared_contexts[batch]];
batch_to_compact[batch] = src_idx;
}
}
if (threadIdx.x == 0) {
*compact_size = scan_offset;
}
}
__global__ void init_shared_contexts(int* shared_contexts, const size_t batch_size)
{
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (global_idx >= batch_size) {
return;
}
shared_contexts[global_idx] = global_idx;
}
void invokeFindContextDups(int* shared_contexts,
int* batch_to_compact,
int* compact_to_batch,
int* compact_size,
const int* input_ids,
const size_t batch_size,
const size_t input_seq_len,
cudaStream_t stream)
{
dim3 block{512};
dim3 grid{((int)batch_size + block.x - 1) / block.x};
init_shared_contexts<<<grid, block, 0, stream>>>(shared_contexts, batch_size);
grid = dim3{(unsigned int)(batch_size * (batch_size - 1)) / 2};
if (input_seq_len <= 128) {
block = 128;
find_context_dups<128><<<grid, block, 0, stream>>>(shared_contexts, input_ids, batch_size, input_seq_len);
}
else {
block = 256;
find_context_dups<256><<<grid, block, 0, stream>>>(shared_contexts, input_ids, batch_size, input_seq_len);
}
generate_dups_indices<<<1, DUPS_INDICES_BLOCK_SIZE, 0, stream>>>(
batch_to_compact, compact_to_batch, compact_size, shared_contexts, batch_size, input_seq_len);
}
#endif
template<typename T>
__global__ void compact_inputs(T* compact_input,
T* compact_attention_mask,
int* compact_input_lengths,
const T* decoder_input,
const T* decoder_mask,
const int* input_lengths,
const int* compact_idx,
size_t compact_size,
size_t seq_len,
size_t hidden_dimension)
{
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (global_idx < compact_size * seq_len * hidden_dimension) {
const int h_id = global_idx % hidden_dimension;
const int seq_id = (global_idx / hidden_dimension) % seq_len;
const int batch_id = global_idx / (hidden_dimension * seq_len);
compact_input[global_idx] = decoder_input[(compact_idx[batch_id] * seq_len + seq_id) * hidden_dimension + h_id];
}
if (global_idx < compact_size * seq_len * seq_len) {
const int seq1_id = global_idx % seq_len;
const int seq2_id = (global_idx / seq_len) % seq_len;
const int batch_id = global_idx / (seq_len * seq_len);
compact_attention_mask[global_idx] =
decoder_mask[(compact_idx[batch_id] * seq_len + seq2_id) * seq_len + seq1_id];
}
if (global_idx < compact_size) {
compact_input_lengths[global_idx] = input_lengths[compact_idx[global_idx]];
}
}
template<typename T>
void invokeCompactInputs(T* compact_input,
T* compact_attention_mask,
int* compact_input_lengths,
const T* decoder_input,
const T* decoder_mask,
const int* input_lengths,
const int* compact_idx,
size_t compact_size,
size_t seq_len,
size_t hidden_dimension,
cudaStream_t stream)
{
/* Compact relevant decoder_layer inputs based on the identical contexts.
* For example, decoder_input is [batch_size, seq_len, H]. It's compacted
* into compact_input [compact_size, seq_len, H] such that
* compact_input[i, ...] = decoder_input[compact_idx[i], ...] */
const size_t elems_n = compact_size * seq_len * max(hidden_dimension, seq_len);
const dim3 blockDim(512);
const dim3 gridDim((elems_n + 512 - 1) / 512);
compact_inputs<T><<<gridDim, blockDim, 0, stream>>>(compact_input,
compact_attention_mask,
compact_input_lengths,
decoder_input,
decoder_mask,
input_lengths,
compact_idx,
compact_size,
seq_len,
hidden_dimension);
}
#define INSTANTIATE_INVOKE_COMPACT_INPUTS(T) \
template void invokeCompactInputs<T>(T * compact_input, \
T * compact_attention_mask, \
int* compact_input_lengths, \
const T* decoder_input, \
const T* decoder_mask, \
const int* input_lengths, \
const int* compact_idx, \
size_t compact_size, \
size_t seq_len, \
size_t hidden_dimension, \
cudaStream_t stream)
INSTANTIATE_INVOKE_COMPACT_INPUTS(half);
INSTANTIATE_INVOKE_COMPACT_INPUTS(float);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_COMPACT_INPUTS(__nv_bfloat16);
#endif
#undef INSTANTIATE_INVOKE_COMPACT_INPUTS
template<typename T>
__global__ void uncompact_outputs(T* uncompact_buffer,
const T* compact_buffer,
const int* batch_to_compact_idx,
size_t batch_size,
size_t buffer_stride)
{
/* Uncompact a buffer IN of size [Compact, Stride] into OUT of size [Batch, Stride]
* so that \forall i, OUT[i, :] = IN[batch_to_compact_idx[i], :]
*/
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (global_idx >= batch_size * buffer_stride) {
return;
}
const int stride_idx = global_idx % buffer_stride;
const int batch_idx = global_idx / buffer_stride;
const int src = batch_to_compact_idx[batch_idx];
uncompact_buffer[global_idx] = compact_buffer[src * buffer_stride + stride_idx];
}
template<typename T>
void invokeUnCompactOutputs(T* uncompact_buffer,
const T* compact_buffer,
const int* batch_to_compact_idx,
size_t batch_size,
size_t buffer_stride,
cudaStream_t stream)
{
const size_t num_elems = batch_size * buffer_stride;
const dim3 blockDim(1024);
const dim3 gridDim((num_elems + blockDim.x - 1) / blockDim.x);
uncompact_outputs<T><<<gridDim, blockDim, 0, stream>>>(
uncompact_buffer, compact_buffer, batch_to_compact_idx, batch_size, buffer_stride);
}
#define INSTANTIATE_INVOKE_UNCOMPACT_OUTPUTS(T) \
template void invokeUnCompactOutputs(T* uncompact_buffer, \
const T* compact_buffer, \
const int* batch_to_compact_idx, \
size_t batch_size, \
size_t buffer_stride, \
cudaStream_t stream)
INSTANTIATE_INVOKE_UNCOMPACT_OUTPUTS(half);
INSTANTIATE_INVOKE_UNCOMPACT_OUTPUTS(float);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_UNCOMPACT_OUTPUTS(__nv_bfloat16);
#endif
#undef INSTANTIATE_INVOKE_UNCOMPACT_OUTPUTS
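// Descriptive note on the cache layouts handled below (inferred from the index math, so
// treat the shapes as assumptions): per batch entry the compact K cache is laid out as
// [num_heads * size_per_head / x_size, seq_len, x_size] with x_size = 16 / sizeof(T), and
// the compact V cache as [num_heads, seq_len, size_per_head]. Uncompacting copies each
// slice into a destination that reserves max_seq_len (instead of seq_len) along the
// sequence axis, taking the source batch row from batch_to_compact_idx shifted by
// ite * local_batch_size for pipelined sub-batches.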
template<typename T>
__global__ void uncompact_caches(T* uncompact_k_cache,
T* uncompact_v_cache,
const T* compact_k_cache,
const T* compact_v_cache,
const int* batch_to_compact_idx,
size_t batch_size,
size_t num_heads,
size_t max_seq_len,
size_t seq_len,
size_t size_per_head,
size_t local_batch_size,
size_t ite)
{
const int hidden_dimension = num_heads * size_per_head;
const int num_elems_per_batch = seq_len * hidden_dimension;
const int num_elems_cache = batch_size * num_elems_per_batch;
const int x_size = 16 / sizeof(T);
for (int global_idx = blockIdx.x * blockDim.x + threadIdx.x; global_idx < 2 * num_elems_cache;
global_idx += blockDim.x * gridDim.x) {
const bool handle_k = global_idx < num_elems_cache;
const T* const cache_src = handle_k ? compact_k_cache : compact_v_cache;
T* const cache_dst = handle_k ? uncompact_k_cache : uncompact_v_cache;
const int idx = handle_k ? global_idx : global_idx - num_elems_cache;
const int src_offset = idx % num_elems_per_batch;
const int batch_idx = idx / num_elems_per_batch;
const int batch_src = batch_to_compact_idx[batch_idx] - ite * local_batch_size;
if (batch_src < 0 || batch_src >= local_batch_size) {
continue;
}
int dst_offset;
if (handle_k) {
const int i0 = idx % (x_size * seq_len);
const int i1 = (idx / (x_size * seq_len)) % (num_heads * size_per_head / x_size);
dst_offset = i1 * max_seq_len * x_size + i0;
}
else {
const int i0 = idx % (size_per_head * seq_len);
const int i1 = (idx / (size_per_head * seq_len)) % (num_heads);
dst_offset = i1 * max_seq_len * size_per_head + i0;
}
cache_dst[batch_idx * max_seq_len * hidden_dimension + dst_offset] =
cache_src[batch_src * num_elems_per_batch + src_offset];
}
}
template<typename T>
void invokeUnCompactCaches(T* uncompact_k_cache,
T* uncompact_v_cache,
const T* compact_k_cache,
const T* compact_v_cache,
const int* batch_to_compact_idx,
size_t batch_size,
size_t num_heads,
size_t max_seq_len,
size_t seq_len,
size_t size_per_head,
size_t local_batch_size,
size_t ite,
cudaStream_t stream)
{
const dim3 blockDim(512);
const dim3 gridDim(1024);
uncompact_caches<T><<<gridDim, blockDim, 0, stream>>>(uncompact_k_cache,
uncompact_v_cache,
compact_k_cache,
compact_v_cache,
batch_to_compact_idx,
batch_size,
num_heads,
max_seq_len,
seq_len,
size_per_head,
local_batch_size,
ite);
}
#define INSTANTIATE_INVOKE_UNCOMPACT_CACHES(T) \
template void invokeUnCompactCaches(T* uncompact_k_cache, \
T* uncompact_v_cache, \
const T* compact_k_cache, \
const T* compact_v_cache, \
const int* batch_to_compact_idx, \
size_t batch_size, \
size_t num_heads, \
size_t max_seq_len, \
size_t seq_len, \
size_t size_per_head, \
size_t local_batch_size, \
size_t ite, \
cudaStream_t stream)
INSTANTIATE_INVOKE_UNCOMPACT_CACHES(half);
INSTANTIATE_INVOKE_UNCOMPACT_CACHES(float);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_UNCOMPACT_CACHES(__nv_bfloat16);
#endif
#undef INSTANTIATE_INVOKE_UNCOMPACT_CACHES
template<bool PREFIX_PROMPT>
__global__ void update_padding_count(int* total_padding_count,
const int* input_lengths,
const int* tiled_prompt_lengths,
size_t max_input_length,
size_t max_prompt_length,
size_t batch_size,
size_t beam_width)
{
const int gidx = blockIdx.x * blockDim.x + threadIdx.x;
if (gidx >= batch_size * beam_width) {
return;
}
const int batch_idx = gidx / beam_width;
total_padding_count[gidx] +=
PREFIX_PROMPT ? (max_input_length + max_prompt_length - input_lengths[batch_idx] - tiled_prompt_lengths[gidx]) :
(max_input_length - input_lengths[batch_idx]);
}
void invokeUpdatePaddingCount(int* total_padding_count,
const int* input_lengths,
const int* tiled_prompt_lengths,
size_t max_input_length,
size_t max_prompt_length,
size_t batch_size,
size_t beam_width,
cudaStream_t stream)
{
dim3 blockSize(256);
dim3 gridSize((batch_size * beam_width + blockSize.x - 1) / blockSize.x);
if (tiled_prompt_lengths != nullptr) {
update_padding_count<true><<<gridSize, blockSize, 0, stream>>>(total_padding_count,
input_lengths,
tiled_prompt_lengths,
max_input_length,
max_prompt_length,
batch_size,
beam_width);
}
else {
update_padding_count<false><<<gridSize, blockSize, 0, stream>>>(total_padding_count,
input_lengths,
tiled_prompt_lengths,
max_input_length,
max_prompt_length,
batch_size,
beam_width);
}
}
template<bool PREFIX_PROMPT>
__global__ void mask_padding_tokens(bool* masked_tokens,
const int* input_lengths,
const int* tiled_prefix_prompt_lengths,
const size_t memory_len,
const size_t max_input_length,
const size_t initial_step,
size_t beam_width)
{
const int seq_len = PREFIX_PROMPT ?
(input_lengths[blockIdx.x / beam_width] + tiled_prefix_prompt_lengths[blockIdx.x]) :
input_lengths[blockIdx.x / beam_width];
for (int step = initial_step + seq_len + threadIdx.x; step < initial_step + max_input_length; step += blockDim.x) {
masked_tokens[blockIdx.x * memory_len + step % memory_len] = true;
}
}
void invokeMaskPaddingTokens(bool* masked_tokens,
const int* input_lengths,
const int* tiled_prefix_prompt_lengths,
const size_t memory_len,
const size_t max_input_length,
const size_t initial_step,
size_t batch_size,
size_t beam_width,
cudaStream_t stream)
{
dim3 blockSize(128);
dim3 gridSize(batch_size * beam_width);
if (tiled_prefix_prompt_lengths != nullptr) {
mask_padding_tokens<true><<<gridSize, blockSize, 0, stream>>>(masked_tokens,
input_lengths,
tiled_prefix_prompt_lengths,
memory_len,
max_input_length,
initial_step,
beam_width);
}
else {
mask_padding_tokens<false><<<gridSize, blockSize, 0, stream>>>(masked_tokens,
input_lengths,
tiled_prefix_prompt_lengths,
memory_len,
max_input_length,
initial_step,
beam_width);
}
}
template<typename T>
__global__ void sum_length_dimension(
float* out_buf, const T* in_buf, const size_t batch_size, const size_t input_length, const size_t hidden_dim)
{
const int bidx = blockIdx.x;
for (int hidx = threadIdx.x; hidx < hidden_dim; hidx += blockDim.x) {
float accum = 0.0f;
for (int step = 0; step < input_length; step++) {
accum += static_cast<float>(in_buf[(bidx * input_length + step) * hidden_dim + hidx]);
}
out_buf[bidx * hidden_dim + hidx] = accum;
}
}
template<typename T>
void invokeSumLengthDimension(float* out_buf,
const T* in_buf,
const size_t batch_size,
const size_t input_length,
const size_t hidden_dim,
cudaStream_t stream)
{
dim3 gridSize(batch_size);
dim3 blockSize(256);
sum_length_dimension<<<gridSize, blockSize, 0, stream>>>(out_buf, in_buf, batch_size, input_length, hidden_dim);
}
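// Illustrative layout for the block-address conversion below: block_addr is viewed as
// [layer_num, batch_size, 2, max_block_num]. For a given (layer l, batch b, block m),
// the K pointer goes to block_addr[l][b][0][m] and the V pointer to block_addr[l][b][1][m]
// (written below as index + max_block_num), each computed as the layer's base address plus
// offset[b][m] * block_size.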
__global__ void ConvertOffsetToAddr(uint64_t* block_addr, // [l, b, 2, m]
const uint64_t* k_cache_base_addr, // [l]
const uint64_t* v_cache_base_addr,
const int* offset, // [b, m]
int layer_num,
int batch_size,
int max_block_num,
int block_size)
{
const int layer_stride = batch_size * 2 * max_block_num;
const int batch_stride = 2 * max_block_num;
for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < layer_num * batch_size * max_block_num;
index += blockDim.x * gridDim.x) {
const int layer_index = index / max_block_num / batch_size;
const int batch_index = (index / max_block_num) % batch_size;
const int col_index = index % max_block_num;
const size_t block_offset = (size_t)offset[batch_index * max_block_num + col_index] * block_size;
const size_t block_addr_index = (size_t)layer_index * layer_stride + batch_index * batch_stride + col_index;
block_addr[block_addr_index] = k_cache_base_addr[layer_index] + block_offset;
block_addr[block_addr_index + max_block_num] = v_cache_base_addr[layer_index] + block_offset;
}
}
void invokeConvertOffsetToAddr(uint64_t* block_addr, // [l, b, 2, m]
const uint64_t* k_cache_base_addr, // [l]
const uint64_t* v_cache_base_addr,
const int* offset, // [b, m]
int layer_num,
int batch_size,
int max_block_num,
int block_size,
cudaStream_t stream) {
dim3 grid(min(batch_size * layer_num, 65536));
dim3 block(min(max_block_num, 1024));
ConvertOffsetToAddr<<<grid, block, 0, stream>>>(block_addr, // [l, b, 2, m]
k_cache_base_addr, // [l]
v_cache_base_addr,
offset, // [b, m]
layer_num,
batch_size,
max_block_num,
block_size);
}
__global__ void ConvertOffsetToAddrOneLayer(uint64_t* block_addr, // [b, 2, m]
const uint64_t k_cache_base_addr,
const uint64_t v_cache_base_addr,
const int* offset, // [b, m]
int batch_size,
int max_block_num,
int block_size)
{
const int batch_stride = 2 * max_block_num;
for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * max_block_num;
index += blockDim.x * gridDim.x) {
const int batch_index = index / max_block_num;
const int col_index = index % max_block_num;
const size_t block_offset = (size_t)offset[batch_index * max_block_num + col_index] * block_size;
const size_t block_addr_index = (size_t)batch_index * batch_stride + col_index;
block_addr[block_addr_index] = k_cache_base_addr + block_offset;
block_addr[block_addr_index + max_block_num] = v_cache_base_addr + block_offset;
}
}
__global__ void ConvertOffsetToBlockArrayData(int32_t* offset_addr,
const int* offset, // [b, m]
int batch_size,
int max_block_num,
int kv_block_offset)
{
const int batch_stride = 2 * max_block_num;
for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * max_block_num;
index += blockDim.x * gridDim.x) {
const int batch_index = index / max_block_num;
const int col_index = index % max_block_num;
const int32_t block_offset = (int32_t)offset[batch_index * max_block_num + col_index];
const int32_t block_addr_index = (int32_t)batch_index * batch_stride + col_index;
offset_addr[block_addr_index] = block_offset;
offset_addr[block_addr_index + max_block_num] = block_offset + kv_block_offset;
}
}
void invokeConvertOffsetToAddrOneLayer(uint64_t* block_addr, // [b, 2, m]
const uint64_t k_cache_base_addr,
const uint64_t v_cache_base_addr,
const int* offset, // [b, m]
int batch_size,
int max_block_num,
int block_size,
cudaStream_t stream) {
dim3 grid(min(batch_size, 65536));
dim3 block(min(max_block_num, 1024));
ConvertOffsetToAddrOneLayer<<<grid, block, 0, stream>>>(block_addr, // [b, 2, m]
k_cache_base_addr,
v_cache_base_addr,
offset, // [b, m]
batch_size,
max_block_num,
block_size);
}
void invokeConvertOffsetToBlockArrayData(int32_t* offset_addr, // [b, 2, m]
const int* offset, // [b, m]
int batch_size,
int max_block_num,
int kv_block_offset,
cudaStream_t stream) {
dim3 grid(min(batch_size, 65536));
dim3 block(min(max_block_num, 1024));
ConvertOffsetToBlockArrayData<<<grid, block, 0, stream>>>(offset_addr, // [b, 2, m]
offset, // [b, m]
batch_size,
max_block_num,
kv_block_offset);
}
#define INSTANTIATE_INVOKE_SUM_LENGTH_DIMENSION(T) \
template void invokeSumLengthDimension(float* out_buf, \
const T* in_buf, \
const size_t batch_size, \
const size_t input_length, \
const size_t hidden_dim, \
cudaStream_t stream)
INSTANTIATE_INVOKE_SUM_LENGTH_DIMENSION(half);
INSTANTIATE_INVOKE_SUM_LENGTH_DIMENSION(float);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_SUM_LENGTH_DIMENSION(__nv_bfloat16);
#endif
#undef INSTANTIATE_INVOKE_SUM_LENGTH_DIMENSION
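// Worked example for the single-thread prefix-sum kernel below: with batch_size = 2,
// max_seq_len = 4 and sequence_length = {2, 3}, the outputs are
//   cu_seqlens     = {0, 2, 5}
//   padding_offset = {0, 0, 2, 2, 2}
// i.e. token j of batch i is shifted left by the padding accumulated in earlier sequences.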
__global__ void getPaddingOffsetAndCuSeqLensKernel(int* padding_offset,
int* cu_seqlens,
const int* sequence_length,
const int batch_size,
const int max_seq_len)
{
// cumulative sum over the sequence lengths
int total_seq_len = 0;
int cum_offset = 0;
int index = 0;
const bool calculate_cu_seqlens = cu_seqlens != nullptr;
for (int i = 0; i < batch_size; i++) {
const int seq_len = sequence_length[i];
if (calculate_cu_seqlens) {
cu_seqlens[i] = total_seq_len;
}
for (int j = 0; j < seq_len; j++) {
padding_offset[index] = cum_offset;
index++;
}
cum_offset += max_seq_len - seq_len;
total_seq_len += seq_len;
}
if (calculate_cu_seqlens) {
cu_seqlens[batch_size] = total_seq_len;
}
}
__global__ void getCuSeqLensKernel(int* cu_seqlens,
const int* sequence_length,
const int* prefix_length,
const int batch_size) {
// cumulative sum over the sequence lengths
int total_seq_len = 0;
const bool has_prefix_length = prefix_length != nullptr;
for (int i = 0; i < batch_size; i++) {
int seq_len = sequence_length[i];
if (has_prefix_length) {
seq_len += prefix_length[i];
}
cu_seqlens[i] = total_seq_len;
total_seq_len += seq_len;
}
cu_seqlens[batch_size] = total_seq_len;
}
void invokeGetPaddingOffsetAndCuSeqLens(int* padding_offset,
int* cu_seqlens,
const int* sequence_lengths,
const int batch_size,
const int max_seq_len,
cudaStream_t stream) {
getPaddingOffsetAndCuSeqLensKernel<<<1, 1, 0, stream>>>(
padding_offset, cu_seqlens, sequence_lengths, batch_size, max_seq_len);
sync_check_cuda_error();
}
void invokeGetCuSeqLens(int* cu_seqlens,
const int* sequence_length,
const int* prefix_length,
const int batch_size,
cudaStream_t stream) {
getCuSeqLensKernel<<<1, 1, 0, stream>>>(
cu_seqlens, sequence_length, prefix_length, batch_size);
sync_check_cuda_error();
}
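// Note on the two scatter-add variants below: the "stable" kernel parallelizes over output
// rows and, for each output element, scans all N input rows and accumulates those whose
// index maps there. This avoids atomics entirely, so the result is bitwise deterministic
// across runs, at the cost of an O(N) scan per output element. The non-stable kernel
// further below parallelizes over inputs and uses atomicAdd instead.
// E.g. (K = 1): with index = {1, 0, 1}, the stable kernel computes out[0] += src[1] and
// out[1] += src[0] + src[2], always in that order.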
template<typename T, int ELEM_PER_THREAD>
__global__ void scatter_add_stable_kernel(T const* src, int N, int K, int32_t const* index, T* out) {
// Parallelize over output positions: each thread owns the accumulation for ELEM_PER_THREAD output elements.
int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x;
out_idx *= ELEM_PER_THREAD;
// Compute the row/column coordinates of this thread's first output element
const int k = out_idx % K;
const int out_n = out_idx / K;
if(out_n >= N) return;
// For every input row, check whether it maps to the current output row and accumulate if so
#pragma unroll
for(int i = 0; i < ELEM_PER_THREAD; i++) {
if(out_idx + i < (size_t)N * K) {
T sum = out[out_idx + i];
// Scan all input rows and add those whose index maps to this output row
for(int in_n = 0; in_n < N; in_n++) {
if(index[in_n] == out_n) {
sum = sum + src[in_n * K + k + i];
}
}
out[out_idx + i] = sum;
}
}
}
template<typename T>
void invokeScatterAddStable(T const* src, int N, int K, int32_t const* index, T* out, cudaStream_t stream) {
const int num_threads = 256;
const int elem_per_thread = 4;
const dim3 block(num_threads);
RTP_LLM_CHECK(K % (elem_per_thread * 2) == 0);
auto h_index = std::shared_ptr<int32_t[]>(new int32_t[N], std::default_delete<int32_t[]>());
cudaMemcpy(h_index.get(), index, N * sizeof(int32_t), cudaMemcpyDeviceToHost);
int32_t max_out_n = h_index[0];
for(int i = 1; i < N; i++) {
max_out_n = max(max_out_n, h_index[i]);
}
max_out_n++;
if constexpr (std::is_same<T, float>::value) {
const dim3 grid(((size_t)max_out_n * K + num_threads * elem_per_thread - 1) / (num_threads * elem_per_thread));
scatter_add_stable_kernel<float, elem_per_thread><<<grid, block, 0, stream>>>(src, N, K, index, out);
} else if (K % 2 == 0) {
#if USING_ROCM
using Tp = typename rocm::packed_type_2<T>::type;
#else
using Tp = typename packed_type_2<T>::type;
#endif
const dim3 grid(((size_t)max_out_n * K / 2 + num_threads * elem_per_thread - 1) / (num_threads * elem_per_thread));
scatter_add_stable_kernel<Tp, elem_per_thread><<<grid, block, 0, stream>>>((Tp*)src, N, K / 2, index, (Tp*)out);
} else {
throw std::invalid_argument("scatter add: unsupported type or K [" + std::to_string(K) + "]");
}
}
template<typename T, int ELEM_PER_THREAD>
__global__ void scatter_add_kernel(T const* src, int N, int K, int32_t const* index, T* out) {
int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
thread_idx *= ELEM_PER_THREAD;
// int offset = blockDim.x * gridDim.x;
int k = thread_idx % K;
int64_t new_idx = (int64_t)index[thread_idx / K] * K;
#pragma unroll
for (int i = 0; i < ELEM_PER_THREAD; ++i) {
if (thread_idx + i < (size_t)N * K) {
#if USING_ROCM
#ifdef ENABLE_BF16
if constexpr (std::is_same<T, __nv_bfloat162>::value) {
unsafeAtomicAdd(reinterpret_cast<__hip_bfloat162*>(out) + new_idx + k + i, (__hip_bfloat162)src[thread_idx + i]);
} else {
unsafeAtomicAdd(out + new_idx + k + i, src[thread_idx + i]);
}
#else
unsafeAtomicAdd(out + new_idx + k + i, src[thread_idx + i]);
#endif
#else
atomicAdd(out + new_idx + k + i, src[thread_idx + i]);
#endif
}
}
}
template<typename T>
void invokeScatterAdd(T const* src, int N, int K, int32_t const* index, T* out, bool use_stable_scatter_add, cudaStream_t stream) {
RTP_LLM_CHECK_WITH_INFO(N > 0 && K > 0, "N and K must be greater than 0");
if (use_stable_scatter_add) {
invokeScatterAddStable(src, N, K, index, out, stream);
return;
}
const int num_threads = 256;
const int elem_per_thread = 4;
const dim3 block(num_threads);
RTP_LLM_CHECK(K % (elem_per_thread * 2) == 0);
if constexpr (std::is_same<T, float>::value) {
const dim3 grid(((size_t)N * K + num_threads * elem_per_thread - 1) / (num_threads * elem_per_thread));
scatter_add_kernel<float, elem_per_thread><<<grid, block, 0, stream>>>(src, N, K, index, out);
} else if (K % 2 == 0) {
#if USING_ROCM
using Tp = typename rocm::packed_type_2<T>::type;
#else
using Tp = typename packed_type_2<T>::type;
#endif
const dim3 grid(((size_t)N * K / 2 + num_threads * elem_per_thread - 1) / (num_threads * elem_per_thread));
scatter_add_kernel<Tp, elem_per_thread><<<grid, block, 0, stream>>>((Tp*)src, N, K / 2, index, (Tp*)out);
} else {
throw std::invalid_argument("scatter add: unsupported type or K [" + std::to_string(K) + "]");
}
}
#define INSTANTIATE_INVOKE_SCATTER_ADD(T) \
template void invokeScatterAdd(T const* src, int N, int K, int32_t const* index, T* out, bool use_stable_scatter_add, cudaStream_t stream)
INSTANTIATE_INVOKE_SCATTER_ADD(half);
INSTANTIATE_INVOKE_SCATTER_ADD(float);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_SCATTER_ADD(__nv_bfloat16);
#endif
#undef INSTANTIATE_INVOKE_SCATTER_ADD
template<typename T>
__global__ void sliceDim1CopyKernel(T const* src, int dim0, int dim1, int dim1_start, int dim1_size, T* out) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < dim0 * dim1_size; index += blockDim.x * gridDim.x) {
const int col_index = index % dim1_size;
const size_t batch_id = index / dim1_size;
out[index] = src[batch_id * dim1 + dim1_start + col_index];
}
}
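// Dispatch note for invokeSliceDim1Copy below: the copy slices a [dim0, dim1] tensor along
// dim1, taking dim1_size columns starting at dim1_start. For uint8_t inputs, whenever dim1,
// dim1_start and dim1_size are all divisible by 16 / 8 / 4 the rows are reinterpreted as
// uint4 / uint2 / uint so that each thread moves a wider chunk; otherwise the launcher falls
// back to the element-wise kernel.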
template<typename T>
void invokeSliceDim1Copy(T const* src, int dim0, int dim1, int dim1_start, int dim1_size, T* out, cudaStream_t stream) {
if constexpr (std::is_same<uint8_t, T>::value) {
if (dim1 % 16 == 0 && dim1_start % 16 == 0 && dim1_size % 16 == 0) {
dim1 /= 16;
dim1_start /= 16;
dim1_size /= 16;
const int grid_size = (int)(ceil((size_t)dim0 * dim1_size / 512.));
dim3 grid(min(grid_size, 65536));
dim3 block(512);
sliceDim1CopyKernel<uint4>
<<<grid, block, 0, stream>>>((uint4 const*)src, dim0, dim1, dim1_start, dim1_size, (uint4*)out);
} else if (dim1 % 8 == 0 && dim1_start % 8 == 0 && dim1_size % 8 == 0) {
dim1 /= 8;
dim1_start /= 8;
dim1_size /= 8;
const int grid_size = (int)(ceil((size_t)dim0 * dim1_size / 512.));
dim3 grid(min(grid_size, 65536));
dim3 block(512);
sliceDim1CopyKernel<uint2>
<<<grid, block, 0, stream>>>((uint2 const*)src, dim0, dim1, dim1_start, dim1_size, (uint2*)out);
} else if (dim1 % 4 == 0 && dim1_start % 4 == 0 && dim1_size % 4 == 0) {
dim1 /= 4;
dim1_start /= 4;
dim1_size /= 4;
const int grid_size = (int)(ceil((size_t)dim0 * dim1_size / 512.));
dim3 grid(min(grid_size, 65536));
dim3 block(512);
sliceDim1CopyKernel<uint>
<<<grid, block, 0, stream>>>((uint const*)src, dim0, dim1, dim1_start, dim1_size, (uint*)out);
} else {
const int grid_size = (int)(ceil((size_t)dim0 * dim1_size / 512.));
dim3 grid(min(grid_size, 65536));
dim3 block(512);
sliceDim1CopyKernel<T><<<grid, block, 0, stream>>>(src, dim0, dim1, dim1_start, dim1_size, out);
}
} else {
const int grid_size = (int)(ceil((size_t)dim0 * dim1_size / 512.));
dim3 grid(min(grid_size, 65536));
dim3 block(512);
sliceDim1CopyKernel<T><<<grid, block, 0, stream>>>(src, dim0, dim1, dim1_start, dim1_size, out);
}
}
#define INSTANTIATE_INVOKE_SLICE_DIM1_COPY(T) \
template void invokeSliceDim1Copy( \
T const* src, int dim0, int dim1, int dim1_start, int dim1_size, T* out, cudaStream_t stream)
INSTANTIATE_INVOKE_SLICE_DIM1_COPY(float);
INSTANTIATE_INVOKE_SLICE_DIM1_COPY(half);
INSTANTIATE_INVOKE_SLICE_DIM1_COPY(int32_t);
INSTANTIATE_INVOKE_SLICE_DIM1_COPY(int8_t);
INSTANTIATE_INVOKE_SLICE_DIM1_COPY(uint8_t);
INSTANTIATE_INVOKE_SLICE_DIM1_COPY(uint32_t);
INSTANTIATE_INVOKE_SLICE_DIM1_COPY(int64_t);
INSTANTIATE_INVOKE_SLICE_DIM1_COPY(uint64_t);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_SLICE_DIM1_COPY(__nv_bfloat16);
#endif
#ifdef ENABLE_FP8
INSTANTIATE_INVOKE_SLICE_DIM1_COPY(__nv_fp8_e4m3);
#endif
__global__ void fakeBalanceExpertKernel(int* expert, float* expert_scales, int start, int expert_num, int size) {
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < size) {
expert[index] = (start + index) % expert_num;
expert_scales[index] = 1.0f;
}
}
void fake_balance_expert(int* expert, float* expert_scales, int start, int expert_num, int size, cudaStream_t stream) {
fakeBalanceExpertKernel<<<(size + 255) / 256, 256, 0, stream>>>(expert, expert_scales, start, expert_num, size);
}
} // namespace rtp_llm