// maga_transformer/cpp/kernels/sampling_topk_kernels.cu

/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <stdexcept>

#if USING_CUDA
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif
#include "maga_transformer/cpp/cuda/cuda_utils.h"
#endif

#if USING_ROCM
#include <hipcub/hipcub.hpp>
#include "maga_transformer/cpp/rocm/hip_utils.h"
using namespace rtp_llm::rocm;
#endif

#include "maga_transformer/cpp/cuda/reduce_kernel_utils.cuh"
#include "maga_transformer/cpp/kernels/sampling_topk_kernels.h"

namespace rtp_llm {

__global__ void curandInitialize(curandState_t* state, const int size, const unsigned long long random_seed) {
    if (threadIdx.x + blockIdx.x * blockDim.x < size) {
        curand_init(random_seed, 0, 0, &state[blockIdx.x * blockDim.x + threadIdx.x]);
    }
}

void invokeCurandInitialize(curandState_t*           state,
                            const size_t             batch_size,
                            const unsigned long long random_seed,
                            cudaStream_t             stream) {
    dim3 block(256);
    dim3 grid((int)(ceil(batch_size * 1.0 / 256)));
    curandInitialize<<<grid, block, 0, stream>>>(state, batch_size, random_seed);
}

__global__ void curandBatchInitialize(curandState_t* states, const int size, const unsigned long long* random_seeds) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < size) {
        curand_init(random_seeds[idx], 0, 0, &states[idx]);
    }
}

void invokeCurandBatchInitialize(curandState_t*            states,
                                 const size_t              batch_size,
                                 const unsigned long long* random_seeds,
                                 cudaStream_t              stream) {
    dim3 block(256);
    dim3 grid((int)(ceil(batch_size * 1.0 / 256)));
    curandBatchInitialize<<<grid, block, 0, stream>>>(states, batch_size, random_seeds);
}

template<typename T>
__global__ void addBiasEndMask(T*          logits,
                               const T*    bias,
                               const int*  end_ids,
                               const bool* finished,
                               const int   vocab_size,
                               const int   vocab_size_padded) {
    int  bid    = blockIdx.x;
    bool finish = finished != nullptr ? finished[bid] : false;
    int  offset = bid * vocab_size_padded;

    const bool IS_FP16   = std::is_same<T, half>::value;
    const T    MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
    for (int tid = threadIdx.x; tid < vocab_size_padded; tid += blockDim.x) {
        if (tid >= vocab_size) {
            // Padded vocab slots must never be sampled.
            logits[offset + tid] = -MAX_T_VAL;
        } else if (finish) {
            // A finished sequence can only emit its end token.
            logits[offset + tid] = (tid == end_ids[bid]) ? MAX_T_VAL : -MAX_T_VAL;
        } else {
            if (bias != nullptr) {
                logits[offset + tid] += bias[tid];
            }
        }
    }
}

template<typename T>
void invokeAddBiasEndMask(T*           logits,
                          const T*     bias,
                          const int*   end_ids,
                          const bool*  finished,
                          const int    batch_size,
                          const int    vocab_size,
                          const int    vocab_size_padded,
                          cudaStream_t stream) {
    dim3 grid(batch_size);
    dim3 block(min(vocab_size_padded, 1024));
    /* vocab_size is usually very large (e.g., 30000, 7000, ...), so each thread of
       the (at most 1024-wide) block strides over several vocab entries. */
    addBiasEndMask<<<grid, block, 0, stream>>>(logits, bias, end_ids, finished, vocab_size, vocab_size_padded);
}

template void invokeAddBiasEndMask(float*       logits,
                                   const float* bias,
                                   const int*   end_ids,
                                   const bool*  finished,
                                   const int    batch_size,
                                   const int    vocab_size,
                                   const int    vocab_size_padded,
                                   cudaStream_t stream);

template void invokeAddBiasEndMask(half*        logits,
                                   const half*  bias,
                                   const int*   end_ids,
                                   const bool*  finished,
                                   const int    batch_size,
                                   const int    vocab_size,
                                   const int    vocab_size_padded,
                                   cudaStream_t stream);
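// Top-k sampling is implemented in two stages:
//   stage 1 (topk_stage1): BLOCKS_PER_BEAM_ blocks cooperate on each batch row;
//     each block runs k rounds of a cub BlockReduce over its grid-stride slice of
//     the vocabulary and writes its k local maxima into topk_tmp_id_buf /
//     topk_tmp_val_buf.
//   stage 2 (topk_stage2_sampling): one block per batch row merges the
//     k * BLOCKS_PER_BEAM_ candidates into the global top-k, softmaxes them, and
//     draws one token with an optional top-p truncation.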
template<typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
__global__ void topk_stage1(const T* __restrict log_probs,
                            T*                  tmp_log_probs,
                            int*                topk_tmp_id_buf,
                            T*                  topk_tmp_val_buf,
                            const bool*         finished,
                            const int           max_top_k,
                            const int*          top_ks,
                            const int           vocab_size,
                            const int*          end_ids,
                            const bool*         skip_decode) {
    typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
    __shared__ typename BlockReduce::TempStorage     temp_storage;

    const int tid      = threadIdx.x;
    const int bid      = blockIdx.x;
    const int batch_id = bid / BLOCKS_PER_BEAM_;  // row id for log_probs
    if (skip_decode != nullptr && skip_decode[batch_id]) {
        return;
    }
    const int block_lane = bid % BLOCKS_PER_BEAM_;                              // block id for a beam
    const int k          = (top_ks != nullptr) ? top_ks[batch_id] : max_top_k;  // batch_id = batch index
    const int tmp_log_buf_index  = batch_id * vocab_size;
    const int tmp_topk_buf_index = batch_id * BLOCKS_PER_BEAM_ * max_top_k + block_lane * k;

    TopK_2<T>  partial;
    const bool IS_FP16   = std::is_same<T, half>::value;
    const T    MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;

    if (finished != nullptr && finished[batch_id] == true) {
        if (tid < k) {
            const int index = tmp_topk_buf_index + tid;
            if (block_lane == 0 && tid == 0) {
                const int end_id        = end_ids[batch_id];
                topk_tmp_id_buf[index]  = tmp_log_buf_index + end_id;
                topk_tmp_val_buf[index] = log_probs[tmp_log_buf_index + end_id];
            } else {
                topk_tmp_id_buf[index]  = -1;
                topk_tmp_val_buf[index] = -MAX_T_VAL;
            }
        }
        return;
    }

    for (int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size;
         elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) {
        int index            = elem_id + tmp_log_buf_index;
        tmp_log_probs[index] = log_probs[index];
    }

    for (int ite = 0; ite < k; ite++) {
        partial.init();
#pragma unroll
        for (int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size;
             elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) {
            int index = elem_id + tmp_log_buf_index;
            partial.insert(tmp_log_probs[index], index);
        }

        TopK_2<T> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<T>);

        if (tid == 0) {
            const int index         = tmp_topk_buf_index + ite;
            topk_tmp_id_buf[index]  = total.p;
            topk_tmp_val_buf[index] = total.u;
            if (total.p >= 0) {
                // Knock out the current maximum so the next round finds the runner-up.
                tmp_log_probs[total.p] = -MAX_T_VAL;
            }
        }
        __syncthreads();
    }
}
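// Stage 2 below samples from the merged candidates with a top-p-scaled threshold:
// rand_num = U(0,1] * top_p * s_sum, and the exp'd candidate masses s_val2[i] are
// subtracted in descending order until the threshold is exhausted (the final
// `i == k - 1` clause guarantees a pick). A worked example with illustrative
// numbers, not from this file: for k = 3, softmax numerators s_val2 = {2.0, 1.5,
// 0.5} (s_sum = 4.0) and top_p = 0.5, the threshold lies in (0, 2.0], so the
// first candidate (mass 2.0 = 0.5 * s_sum) is always chosen and the tail is
// truncated exactly as nucleus sampling prescribes.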
template<typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_, bool RECORD_PROB>
__global__ void topk_stage2_sampling(const int* __restrict topk_tmp_id_buf,
                                     T*                    topk_tmp_val_buf,
                                     int*                  ids,
                                     int*                  sequence_length,
                                     bool*                 finished,
                                     float*                cum_log_probs,
                                     float*                output_log_probs,
                                     float*                output_all_probs,
                                     const int             max_top_k,
                                     const int*            top_ks,
                                     const float           top_p,
                                     const float*          top_ps,
                                     curandState_t*        curandstate,
                                     const int*            end_ids,
                                     const int             vocab_size,
                                     const bool*           skip_decode) {
    const bool IS_FP16   = std::is_same<T, half>::value;
    const T    MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;

    const int tid      = threadIdx.x;
    const int batch_id = blockIdx.x;
    if (skip_decode != nullptr && skip_decode[batch_id]) {
        return;
    }

    const int   k              = (top_ks != nullptr) ? top_ks[batch_id] : max_top_k;
    const float prob_threshold = (top_ps != nullptr) ? top_ps[batch_id] : top_p;
    const int   size           = k * BLOCKS_PER_BEAM_;
    const int   stride         = max_top_k * BLOCKS_PER_BEAM_;

    typedef cub::BlockReduce<TopK_2<float>, BLOCK_SIZE_> BlockReduce;
    __shared__ typename BlockReduce::TempStorage         temp_storage;
    extern __shared__ char                               array[];
    __shared__ float                                     rand_num;
    __shared__ float                                     s_sum;
    __shared__ float                                     s_max;
    T*   s_val = topk_tmp_val_buf + batch_id * stride;
    int* s_id  = reinterpret_cast<int*>(array);
    if (tid == 0) {
        s_sum = 0.0f;
    }
    TopK_2<float> partial;

    if (finished != nullptr && finished[batch_id] == true) {
        ids[batch_id] = end_ids[batch_id];
        return;
    }

    float* s_val2 = reinterpret_cast<float*>(s_id + k);
    for (int ite = 0; ite < k; ite++) {
        partial.init();
#pragma unroll
        for (int i = tid; i < size; i += BLOCK_SIZE_) {
            partial.insert((float)s_val[i], i);
        }

        TopK_2<float> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<float>);

        if (tid == 0) {
            if (ite == 0) {
                s_max = total.u;
            }
            s_id[ite]      = total.p;
            s_val[total.p] = -MAX_T_VAL;

            // when cum_log_probs are computed, topk_tmp_val_buf (logits_buf_) are already
            // pre-processed by softmax_kernel
            total.u     = __expf(total.u - s_max);
            s_val2[ite] = total.u;
            s_sum += total.u;
        }
        __syncthreads();
    }

    // @miji TODO: use block sum to make it faster
    if constexpr (RECORD_PROB) {
        float prob_sum = 0;
        if (threadIdx.x == 0) {
            for (int i = 0; i < k; i++) {
                int   token_idx   = topk_tmp_id_buf[batch_id * stride + s_id[i]] % vocab_size;
                float origin_prob = __expf(logf(s_val2[i]) - logf(s_sum));
                prob_sum += origin_prob;
                output_all_probs[batch_id * vocab_size + token_idx] =
                    max(0.0, origin_prob - max(0.0, prob_sum - prob_threshold)) / prob_threshold;
                if (prob_sum >= prob_threshold) {
                    break;
                }
            }
        }
    }

    if (tid == 0) {
        rand_num = (float)curand_uniform(curandstate + blockIdx.x) * prob_threshold * s_sum;
        for (int i = 0; i < k; i++) {
            float exp_logit = s_val2[i];
            rand_num        = rand_num - exp_logit;
            if (rand_num <= 0.0f || i == k - 1) {
                ids[batch_id] = topk_tmp_id_buf[batch_id * stride + s_id[i]] % vocab_size;
                if (cum_log_probs != nullptr || output_log_probs != nullptr) {
                    float log_prob = logf(exp_logit) - logf(s_sum);
                    if (cum_log_probs != nullptr) {
                        cum_log_probs[batch_id] += log_prob;
                    }
                    if (output_log_probs != nullptr) {
                        // 'output_log_probs' is the probability induced by the top-k sampling.
                        // We normalize the probability 'exp_logit' of the selected token by
                        // the probability 's_sum' of the set of top-k tokens, meaning the
                        // log_prob is the probability of the selected token, conditioned on
                        // the event that it is selected, i.e.,
                        //   log_prob = log P(i | i is in top-k) = log(exp_logit / s_sum).
                        output_log_probs[batch_id] = log_prob;
                    }
                }
                break;
            }
        }
        if (sequence_length != nullptr && finished != nullptr) {
            sequence_length[batch_id] =
                finished[batch_id] ? sequence_length[batch_id] : sequence_length[batch_id] + 1;
            finished[batch_id] = ids[batch_id] == end_ids[batch_id] ? true : false;
        }
    }
}
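// The CASE_K macro below dispatches on max_top_k: each bucket selects the block
// sizes for the two stages, and the stage-2 launch reserves
// K_MAX * (sizeof(int) + sizeof(float)) bytes of dynamic shared memory for the
// s_id / s_val2 arrays indexed in topk_stage2_sampling above.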
#define CASE_K(K_MIN, K_MAX, BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_, RECORD_PROB)                              \
    case K_MIN ... K_MAX:                                                                                              \
        topk_stage1<T, BLOCK_SIZE_1_, BLOCKS_PER_BEAM_>                                                                \
            <<<batch_size * BLOCKS_PER_BEAM_, BLOCK_SIZE_1_, 0, stream>>>(log_probs,                                   \
                                                                          temp_log_probs,                              \
                                                                          topk_tmp_id_buf,                             \
                                                                          topk_tmp_val_buf,                            \
                                                                          finished,                                    \
                                                                          max_top_k,                                   \
                                                                          top_ks,                                      \
                                                                          vocab_size,                                  \
                                                                          end_ids,                                     \
                                                                          skip_decode);                                \
        topk_stage2_sampling<T, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_, RECORD_PROB>                                          \
            <<<batch_size, BLOCK_SIZE_2_, K_MAX * sizeof(int) + K_MAX * sizeof(float), stream>>>(topk_tmp_id_buf,      \
                                                                                                 topk_tmp_val_buf,     \
                                                                                                 ids,                  \
                                                                                                 sequence_length,      \
                                                                                                 finished,             \
                                                                                                 cum_log_probs,        \
                                                                                                 output_log_probs,     \
                                                                                                 output_all_probs,     \
                                                                                                 max_top_k,            \
                                                                                                 top_ks,               \
                                                                                                 top_p,                \
                                                                                                 top_ps,               \
                                                                                                 curandstate,          \
                                                                                                 end_ids,              \
                                                                                                 vocab_size,           \
                                                                                                 skip_decode);         \
        break;

template<typename T>
void invokeBatchTopKSampling(void*          workspace,
                             size_t&        workspace_size,
                             const T*       log_probs,
                             int*           ids,
                             int*           sequence_length,
                             bool*          finished,
                             float*         cum_log_probs,
                             float*         output_log_probs,
                             curandState_t* curandstate,
                             const int      max_top_k,
                             const int*     top_ks,
                             const float    top_p,
                             const float*   top_ps,
                             const int      vocab_size_padded,
                             const int*     end_ids,
                             float*         output_all_probs,
                             cudaStream_t   stream,
                             const int      batch_size,
                             const bool*    skip_decode) {
    // Do not allow ambiguous inputs: a non-default scalar top_p and a per-batch
    // top_ps must not both be given.
    assert(top_p == 1.0f || top_ps == nullptr);
    const int vocab_size              = vocab_size_padded;
    const int max_block_per_beam      = 8;
    int       temp_log_probs_buf_size = batch_size * vocab_size;                      // type T
    int       topk_tmp_ids_buf_size   = batch_size * max_top_k * max_block_per_beam;  // type int
    int       topk_tmp_val_buf_size   = batch_size * max_top_k * max_block_per_beam;  // type T

    // prevent memory misaligned address
    temp_log_probs_buf_size = (int)(ceil(temp_log_probs_buf_size / 4.)) * 4;
    topk_tmp_ids_buf_size   = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4;
    topk_tmp_val_buf_size   = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4;

    if (workspace == nullptr) {
        // First call with a null workspace only reports the required size.
        workspace_size = sizeof(T) * temp_log_probs_buf_size + sizeof(int) * topk_tmp_ids_buf_size
                         + sizeof(T) * topk_tmp_val_buf_size;
        return;
    }

    T*   temp_log_probs   = (T*)workspace;
    int* topk_tmp_id_buf  = (int*)(temp_log_probs + temp_log_probs_buf_size);
    T*   topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size);

#define SWITCH_MAX_K(LOG_PROB)                                                                                         \
    switch (max_top_k) {                                                                                               \
        CASE_K(1, 16, 128, 128, 8, LOG_PROB);                                                                          \
        CASE_K(17, 32, 256, 128, 8, LOG_PROB);                                                                         \
        CASE_K(33, 64, 256, 256, 8, LOG_PROB);                                                                         \
        CASE_K(65, 1024, 256, 256, 8, LOG_PROB);                                                                       \
        default:                                                                                                       \
            throw std::domain_error(rtp_llm::fmtstr("top-k kernel supports 1<=k<=1024 but got k=%d", max_top_k));      \
    }

    if (output_all_probs) {
        SWITCH_MAX_K(true);
    } else {
        SWITCH_MAX_K(false);
    }
}
#undef CASE_K

template void invokeBatchTopKSampling(void*          workspace,
                                      size_t&        workspace_size,
                                      const float*   log_probs,
                                      int*           ids,
                                      int*           sequence_length,
                                      bool*          finished_buf,
                                      float*         cum_log_probs,
                                      float*         output_log_probs,
                                      curandState_t* curandstate,
                                      const int      max_top_k,
                                      const int*     top_ks,
                                      const float    top_p,
                                      const float*   top_ps,
                                      const int      vocab_size_padded,
                                      const int*     end_ids,
                                      float*         output_all_probs,
                                      cudaStream_t   stream,
                                      const int      batch_size,
                                      const bool*    skip_decode);

template void invokeBatchTopKSampling(void*          workspace,
                                      size_t&        workspace_size,
                                      const half*    log_probs,
                                      int*           ids,
                                      int*           sequence_length,
                                      bool*          finished_buf,
                                      float*         cum_log_probs,
                                      float*         output_log_probs,
                                      curandState_t* curandstate,
                                      const int      max_top_k,
                                      const int*     top_ks,
                                      const float    top_p,
                                      const float*   top_ps,
                                      const int      vocab_size_padded,
                                      const int*     end_ids,
                                      float*         output_all_probs,
                                      cudaStream_t   stream,
                                      const int      batch_size,
                                      const bool*    skip_decode);
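// Typical two-call pattern for the launcher above (a minimal sketch; buffer
// names such as `d_logits` and `ws_ptr` are illustrative, not part of this file):
//
//   size_t ws_bytes = 0;
//   // 1) query: a null workspace only fills in ws_bytes and returns
//   invokeBatchTopKSampling<float>(nullptr, ws_bytes, d_logits, d_ids, d_seq_len,
//                                  d_finished, nullptr, nullptr, d_curand_states,
//                                  max_top_k, d_top_ks, 1.0f, d_top_ps,
//                                  vocab_size_padded, d_end_ids, nullptr, stream,
//                                  batch_size, d_skip_decode);
//   // 2) launch: pass a device buffer of at least ws_bytes
//   invokeBatchTopKSampling<float>(ws_ptr, ws_bytes, d_logits, d_ids, d_seq_len,
//                                  d_finished, nullptr, nullptr, d_curand_states,
//                                  max_top_k, d_top_ks, 1.0f, d_top_ps,
//                                  vocab_size_padded, d_end_ids, nullptr, stream,
//                                  batch_size, d_skip_decode);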
template<typename T>
void invokeTopKSampling(void*          workspace,
                        size_t&        workspace_size,
                        const T*       log_probs,
                        int*           ids,
                        int*           sequence_length,
                        bool*          finished_buf,
                        float*         cum_log_probs,
                        float*         output_log_probs,
                        curandState_t* curandstate,
                        const int      top_k,
                        const float    top_p,
                        const int      vocab_size_padded,
                        const int*     end_ids,
                        float*         output_all_probs,
                        cudaStream_t   stream,
                        const int      batch_size,
                        const bool*    skip_decode) {
    invokeBatchTopKSampling(workspace,
                            workspace_size,
                            log_probs,
                            ids,
                            sequence_length,
                            finished_buf,
                            cum_log_probs,
                            output_log_probs,
                            curandstate,
                            top_k,
                            nullptr,
                            top_p,
                            nullptr,
                            vocab_size_padded,
                            end_ids,
                            output_all_probs,
                            stream,
                            batch_size,
                            skip_decode);
}

template void invokeTopKSampling(void*          workspace,
                                 size_t&        workspace_size,
                                 const float*   log_probs,
                                 int*           ids,
                                 int*           sequence_length,
                                 bool*          finished_buf,
                                 float*         cum_log_probs,
                                 float*         output_log_probs,
                                 curandState_t* curandstate,
                                 const int      top_k,
                                 const float    top_p,
                                 const int      vocab_size_padded,
                                 const int*     end_ids,
                                 float*         output_all_probs,
                                 cudaStream_t   stream,
                                 const int      batch_size,
                                 const bool*    skip_decode);

template void invokeTopKSampling(void*          workspace,
                                 size_t&        workspace_size,
                                 const half*    log_probs,
                                 int*           ids,
                                 int*           sequence_length,
                                 bool*          finished_buf,
                                 float*         cum_log_probs,
                                 float*         output_log_probs,
                                 curandState_t* curandstate,
                                 const int      top_k,
                                 const float    top_p,
                                 const int      vocab_size_padded,
                                 const int*     end_ids,
                                 float*         output_all_probs,
                                 cudaStream_t   stream,
                                 const int      batch_size,
                                 const bool*    skip_decode);

template<uint TOP_K_MAX>
__global__ void setup_topk_runtime_args(int    batch_size,
                                        uint   top_k,
                                        uint*  top_ks,
                                        int    top_ks_size,
                                        float  top_p,
                                        float* top_ps,
                                        int    top_ps_size,
                                        bool*  skip_decode) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    for (int i = index; i < batch_size; i += gridDim.x * blockDim.x) {
        uint  k = top_ks_size > 1 ? top_ks[i] : top_k;
        float p = top_ps_size > 1 ? top_ps[i] : top_p;
        if (k == 0 && p == 0.0f) {
            // FT's top-p implementation does not support top_p = 0.0f, but that setting
            // is equivalent to greedy search, so we set top_k = 1 as an alternative.
            k = 1;
        }
        if (k > 0 && p == 0.0f) {
            // For compatibility with FT <= 5.0.
            // This case corresponds to the old top-k sampling, which is equivalent to
            // the old topk_topp sampling with top_p = 1.0f. TopKSamplingLayer and
            // TopKTopPSamplingLayer have been merged into TopKSamplingLayer, so we
            // replace the case top_k > 0 and top_p = 0.0f by top_k > 0 and top_p = 1.0f.
            p = 1.0f;
        }
        // Clip the k value: the top-k sampling kernels support at most TOP_K_MAX.
        top_ks[i] = k > TOP_K_MAX ? TOP_K_MAX : k;
        if (k > TOP_K_MAX) {
            printf("[WARNING] topk (%d) is larger than max supported number (%d) for token %d,"
                   " clip to max supported number %d.\n",
                   k, TOP_K_MAX, i, top_ks[i]);
        }
        // Clip the p value if it is out of the range [0.0, 1.0].
        top_ps[i] = p < 0.0f ? 0.0f : (p > 1.0f ? 1.0f : p);
        if (p < 0.0f || p > 1.0f) {
            printf("[WARNING] topp (%f) is out of range ([0.0, 1.0f]) for token %d,"
                   " clip to closest number %f.\n",
                   p, i, top_ps[i]);
        }
        // k == 0 (with p > 0) means this row is handled by the top-p sampling path,
        // so the top-k kernels skip it.
        skip_decode[i] = k == 0;
    }
}

void invokeSetupTopKRuntimeArgs(int          batch_size,
                                uint         top_k,
                                uint*        top_ks,
                                int          top_ks_size,
                                float        top_p,
                                float*       top_ps,
                                int          top_ps_size,
                                bool*        skip_decode,
                                cudaStream_t stream) {
    dim3 block(std::min((int)batch_size, 256));
    dim3 grid(div_up((int)batch_size, (int)block.x));
    // support top_k up to 1024.
    setup_topk_runtime_args<1024><<<grid, block, 0, stream>>>(
        batch_size, top_k, top_ks, top_ks_size, top_p, top_ps, top_ps_size, skip_decode);
}

}  // namespace rtp_llm
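// End-to-end usage (a hedged sketch of the host-side call sequence; the ordering
// reflects typical FT-style sampling layers and is an assumption, not something
// this file enforces):
//   1. invokeSetupTopKRuntimeArgs(...)  -- normalize per-request top_k / top_p,
//                                          fill skip_decode
//   2. invokeCurandInitialize(...)      -- seed one curandState_t per batch row
//   3. invokeAddBiasEndMask(...)        -- apply bias, mask padding and finished rows
//   4. invokeBatchTopKSampling(...)     -- two-call workspace pattern shown above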