// maga_transformer/cpp/kernels/sampling_topp_kernels.cu
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if USING_CUDA
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif
#include "maga_transformer/cpp/cuda/cuda_utils.h"
#endif
#if USING_ROCM
#include <hipcub/hipcub.hpp>
#include "maga_transformer/cpp/rocm/hip_utils.h"
using namespace rtp_llm::rocm;
#endif
#include "maga_transformer/cpp/cuda/reduce_kernel_utils.cuh"
#include "maga_transformer/cpp/kernels/sampling_topp_kernels.h"
#include "maga_transformer/cpp/kernels/sampling_penalty_kernels.h"
constexpr int ENABLE_SINGLE_PASS_TOP_P = 0;
constexpr float SINGLE_PASS_THRESHOLD = 0.9;
namespace rtp_llm {
namespace segmented_topp_impl {
template<int HALF_ELEMENTS_PER_WARP_LOAD>
using Copy_half_t = typename std::conditional<
HALF_ELEMENTS_PER_WARP_LOAD == 32,
half,
typename std::conditional<HALF_ELEMENTS_PER_WARP_LOAD == 64,
int,
typename std::conditional<HALF_ELEMENTS_PER_WARP_LOAD == 128, int2, int4>::type>::type>::
type;
template<typename T, int ELEMENTS_PER_WARP_LOAD>
using Copy_t = Copy_half_t<sizeof(T) / sizeof(half) * ELEMENTS_PER_WARP_LOAD>;
template<typename T>
struct Float_as_int_ {
};
template<>
struct Float_as_int_<float> {
using Type = uint32_t;
};
template<>
struct Float_as_int_<__half> {
using Type = uint16_t;
};
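// For non-negative IEEE-754 values (the inputs here are probabilities in [0, 1]), the raw bit
// pattern interpreted as an unsigned integer orders the same way as the float value itself,
// which is what lets the single-pass kernel below select keys bit by bit from the MSB down.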
using kernel_params_float = Segmented_topk_kernel_params<float, int32_t, 256, 2>;
using kernel_params_float_1 = Segmented_topk_kernel_params<float, int32_t, 256, 1>;
using kernel_params_half = Segmented_topk_kernel_params<__half, int32_t, 256, 4>;
using kernel_params_half_1 = Segmented_topk_kernel_params<__half, int32_t, 256, 1>;
///////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float to_float(uint32_t src)
{
return __int_as_float(src);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float to_float(uint16_t src)
{
__half dst = __ushort_as_half(src);
return __half2float(dst);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
// sort one segment per cta
template<typename T_SCORE, int BLOCK_THREADS, int ELEMENTS_PER_THREAD>
__global__ void blockSortKernel(const T_SCORE* d_keys_in,
T_SCORE* d_keys_out,
const int32_t* d_values_in,
int32_t* d_values_out,
const int32_t* active_counts,
int num_items_,
int stride_items,
int num_segments)
{
// Specialize BlockRadixSort for a 1D block
typedef cub::BlockRadixSort<T_SCORE, BLOCK_THREADS, ELEMENTS_PER_THREAD, int32_t> BlockRadixSort;
// Allocate shared memory for BlockRadixSort
__shared__ typename BlockRadixSort::TempStorage temp_storage;
if (blockIdx.x >= num_segments) {
return;
}
int num_items = active_counts[blockIdx.x]; // > num_items_ ? num_items_ : active_counts[blockIdx.x];
if (num_items == 0) {
return;
}
// Obtain a segment of consecutive items that are blocked across threads
T_SCORE thread_keys[ELEMENTS_PER_THREAD];
int32_t thread_values[ELEMENTS_PER_THREAD];
int32_t block_offset = blockIdx.x * stride_items;
cub::LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_keys_out + block_offset, thread_keys, num_items, 0);
cub::LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_values_out + block_offset, thread_values, num_items, -1);
__syncthreads();
// Collectively sort the keys and values among block threads
BlockRadixSort(temp_storage).SortDescendingBlockedToStriped(thread_keys, thread_values);
// Store output in striped fashion
cub::StoreDirectStriped<BLOCK_THREADS>(threadIdx.x, d_keys_out + block_offset, thread_keys, num_items);
cub::StoreDirectStriped<BLOCK_THREADS>(threadIdx.x, d_values_out + block_offset, thread_values, num_items);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// block sort kernel
template<typename T_SCORE>
void blockSort(const T_SCORE* d_keys_in,
T_SCORE* d_keys_out,
const int32_t* d_values_in,
int32_t* d_values_out,
const int32_t* active_counts,
int num_items,
int stride_items,
int num_segments,
cudaStream_t stream)
{
if (num_items == 0) {
return;
}
int kernel_index = div_up(num_items, 128) - 1;
int warps_per_cta = (kernel_index + 1) * 128 / 32;
if (kernel_index > 7) {
kernel_index = 7 + div_up(num_items, 1024) - 1;
warps_per_cta = 1024 / 32;
}
assert(warps_per_cta <= 32);
dim3 block(warps_per_cta * 32);
dim3 grid(num_segments);
using kernel_func = void (*)(const T_SCORE* d_keys_in,
T_SCORE* d_keys_out,
const int32_t* d_values_in,
int32_t* d_values_out,
const int32_t* active_counts,
int num_items,
int stride_items,
int num_segments);
static const kernel_func kernel_funcs[] = {
&blockSortKernel<T_SCORE, 128, 1>,
&blockSortKernel<T_SCORE, 256, 1>,
&blockSortKernel<T_SCORE, 384, 1>,
&blockSortKernel<T_SCORE, 512, 1>,
&blockSortKernel<T_SCORE, 640, 1>,
&blockSortKernel<T_SCORE, 768, 1>,
&blockSortKernel<T_SCORE, 896, 1>,
&blockSortKernel<T_SCORE, 1024, 1>,
&blockSortKernel<T_SCORE, 1024, 2>,
&blockSortKernel<T_SCORE, 1024, 4>,
//&blockSortKernel<T_SCORE, 1024, 6>,
};
kernel_funcs[kernel_index]<<<grid, block, 0, stream>>>(
d_keys_in, d_keys_out, d_values_in, d_values_out, active_counts, num_items, stride_items, num_segments);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
struct BlockPrefixCallbackOp {
// Running prefix
int running_total;
// Constructor
__device__ BlockPrefixCallbackOp(uint32_t running_total): running_total(running_total) {}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
__device__ int operator()(uint32_t block_aggregate)
{
uint32_t old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
#define DO_DEBUG_PRINT 0
// governs the split between regs and smem
constexpr float SMEM_FRACTION = 0.5F;
constexpr float P_EPSILON = 0.01F;
constexpr int MAX_TOP_K = 3072;
constexpr int WARP_SZ = 32;
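// Single-pass segmented top-p selection: each block owns one segment (batch row), keeps the
// whole probability row split between registers and shared memory, and performs a radix-select
// over the key bits, growing a bit mask from the most significant bit down until the selected
// probability mass reaches top_p (within P_EPSILON) or the selection would exceed MAX_TOP_K.
// The selected (key, index) pairs are written out for blockSort to order and topp_sampling to
// consume; segments needing more than MAX_TOP_K keys are handled by the radix sort fallback.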
template<typename Kernel_params, int ITEMS_PER_THREAD>
__global__ __launch_bounds__(Kernel_params::BLOCK_THREADS,
1) void segmented_top_p_single_pass(TopKPerSegmentParams params)
{
#if DO_DEBUG_PRINT
constexpr int debug_block_id = 26;
#endif
using Key_Data_Type = typename Kernel_params::Key_Data_Type;
using Int_Key_Data_Type = typename Float_as_int_<Key_Data_Type>::Type;
// 4 fp16 keys or 2 fp32 keys
constexpr int KEYS_PER_LDG = Kernel_params::KEYS_PER_LDG;
typedef Copy_t<Key_Data_Type, WARP_SZ * KEYS_PER_LDG> copy_t;
union access_t {
copy_t v;
Int_Key_Data_Type x[KEYS_PER_LDG]; // supported size 1,2,4
};
constexpr int BLOCK_THREADS = Kernel_params::BLOCK_THREADS;
constexpr int ITEMS_PER_THREAD_IN_REGS = ITEMS_PER_THREAD * (1.0F - SMEM_FRACTION);
constexpr int ITEMS_PER_THREAD_IN_SMEM = ITEMS_PER_THREAD - ITEMS_PER_THREAD_IN_REGS;
#if DO_DEBUG_PRINT == 1
if (blockIdx.x == 0 && threadIdx.x == 0) {
printf("ITEMS_PER_THREAD, ITEMS_PER_THREAD_IN_REGS, ITEMS_PER_THREAD_IN_SMEM = %d, %d, %d\n",
ITEMS_PER_THREAD,
ITEMS_PER_THREAD_IN_REGS,
ITEMS_PER_THREAD_IN_SMEM);
}
#endif
constexpr int MIN_KEY = 0;
constexpr int ENABLED_PER_THREAD = (ITEMS_PER_THREAD + 32 - 1) / 32;
extern __shared__ int2 dynamic_smem[];
int2* smem_selected_elements = dynamic_smem;
Int_Key_Data_Type* smem_thread_items = reinterpret_cast<Int_Key_Data_Type*>(smem_selected_elements + MAX_TOP_K);
__shared__ unsigned int smem_selected_count;
// Specialize BlockScan type for our thread block
typedef cub::BlockScan<uint32_t, BLOCK_THREADS> BlockScan;
    // Specialize BlockReduce type for our thread block
typedef cub::BlockReduce<float, BLOCK_THREADS> BlockReduce;
__shared__ float smem_p_sum_total;
__shared__ union {
typename BlockScan::TempStorage scan;
typename BlockReduce::TempStorage reduce;
} temp_storage;
// Initialize running total
BlockPrefixCallbackOp prefix_op(0);
unsigned int old_selected_count;
uint32_t segment = blockIdx.y * gridDim.x + blockIdx.x;
    // The preceding top-k pass short-circuited this segment
if (params.gmem_begin_offsets[segment] == params.gmem_end_offsets[segment]) {
if (threadIdx.x == 0) {
params.gmem_active_count_per_segment[segment] = 1;
atomicMax(params.gmem_active_count_total, 1);
}
return;
}
Int_Key_Data_Type* gmem_src_keys = reinterpret_cast<Int_Key_Data_Type*>(params.gmem_src_keys);
Int_Key_Data_Type* gmem_dst_keys = reinterpret_cast<Int_Key_Data_Type*>(params.gmem_dst_keys);
int32_t* gmem_dst_vals = reinterpret_cast<int32_t*>(params.gmem_dst_vals);
constexpr int BITS_IN_KEY = sizeof(Key_Data_Type) * 8;
int items = params.num_items / params.num_segments;
int first_index = segment * items;
gmem_src_keys += first_index;
gmem_dst_keys += first_index;
gmem_dst_vals += first_index;
int index_limit = items;
Int_Key_Data_Type thread_items[ITEMS_PER_THREAD_IN_REGS] = {0};
// Load all keys into registers and smem
const int lane_id = threadIdx.x % WARP_SZ;
const int warp_id = threadIdx.x / WARP_SZ;
constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SZ;
access_t ZERO;
for (int i = 0; i < KEYS_PER_LDG; i++) {
ZERO.x[i] = MIN_KEY;
}
// registers
for (int iter = 0; iter < ITEMS_PER_THREAD_IN_REGS; iter++) {
int offset = (iter + threadIdx.x * ITEMS_PER_THREAD);
thread_items[iter] = (offset < index_limit) ? gmem_src_keys[offset] : MIN_KEY;
}
// shared memory
for (int c = warp_id; c < BLOCK_THREADS; c += NUM_WARPS) {
for (int iter = lane_id * KEYS_PER_LDG; iter < ITEMS_PER_THREAD_IN_SMEM; iter += WARP_SZ * KEYS_PER_LDG) {
int offset = iter + c * ITEMS_PER_THREAD + ITEMS_PER_THREAD_IN_REGS;
access_t val;
val.v = (offset < index_limit) ? *reinterpret_cast<copy_t*>(&gmem_src_keys[offset]) : ZERO.v;
for (int i = 0; i < KEYS_PER_LDG; i++) {
smem_thread_items[c + (iter + i) * BLOCK_THREADS] = val.x[i];
}
// smem_thread_items[c + iter * BLOCK_THREADS] = (offset < index_limit)? gmem_src_keys[offset] : MIN_KEY;
}
}
Int_Key_Data_Type select_mask = 0;
Int_Key_Data_Type save_mask = 0;
// Int_Key_Data_Type save_bit = 0;
    // set to true when we finish with too few keys, so we go back to save_mask one more time
bool is_last_iter = false;
if (threadIdx.x == 0) {
smem_selected_count = 0;
old_selected_count = 0;
}
    // iterate over bits.
    // skip the first two bits:
    // * the top bit is the sign bit; all values are non-negative
    // * the next bit (the exponent MSB) is only set for values >= 2, but the input
    //   consists only of values in the range [0, 1]
constexpr int START_BIT = BITS_IN_KEY - 1;
constexpr int SKIP_BITS = 2;
constexpr Int_Key_Data_Type ONE = (Int_Key_Data_Type)1;
uint32_t selected;
uint32_t sc;
float p_sum_total = 0.0F;
float old_p_sum_total = 0.0F;
uint32_t offset = 0;
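    // Radix-select loop over the key bits, from the most significant payload bit down. Each
    // iteration tentatively adds the current bit to the selection mask and measures, via
    // BlockReduce/BlockScan, how much probability mass and how many keys the mask selects.
    // On overshoot (mass exceeds top_p by more than P_EPSILON while too many keys are selected)
    // the bit is kept in select_mask and the running totals roll back to the last checkpoint;
    // on undershoot the newly selected keys are stored to smem_selected_elements and the
    // checkpoint advances. The loop stops once the mass lands within P_EPSILON of top_p, the
    // selection exceeds top_p while still fitting in MAX_TOP_K, or bit 0 is reached.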
for (Int_Key_Data_Type bit = START_BIT - SKIP_BITS; true; --bit) {
__syncthreads();
Int_Key_Data_Type bit_mask = select_mask | (ONE << bit);
uint32_t enabled[ENABLED_PER_THREAD] = {0};
float thread_sum = 0.0F;
for (int item = 0; item < ITEMS_PER_THREAD_IN_REGS; ++item) {
            // check whether all bits of bit_mask are set in the thread item; if so, set the
            // corresponding bit of enabled
auto val = thread_items[item];
uint32_t is_enabled = uint32_t(((val ^ bit_mask) & bit_mask) == 0);
// thread_sum += (is_enabled)? to_float(val) : 0.0F;
thread_sum += is_enabled * to_float(val);
enabled[item / 32] |= is_enabled << (item % 32);
}
for (int item = 0; item < ITEMS_PER_THREAD_IN_SMEM; ++item) {
int idx = threadIdx.x + item * BLOCK_THREADS;
// int idx = item + ITEMS_PER_THREAD_IN_SMEM * threadIdx.x;
auto val = smem_thread_items[idx];
uint32_t is_enabled = uint32_t(((val ^ bit_mask) & bit_mask) == 0);
// thread_sum += (is_enabled)? to_float(val) : 0.0F;
thread_sum += is_enabled * to_float(val);
enabled[(ITEMS_PER_THREAD_IN_REGS + item) / 32] |= is_enabled << ((ITEMS_PER_THREAD_IN_REGS + item) % 32);
}
selected = 0;
#pragma unroll
for (int i = 0; i < ENABLED_PER_THREAD; i++) {
selected += __popc(enabled[i]);
}
float p_sum = BlockReduce(temp_storage.reduce).Sum(thread_sum);
if (threadIdx.x == 0) {
p_sum_total += p_sum;
smem_p_sum_total = p_sum_total;
}
__syncthreads();
p_sum_total = smem_p_sum_total;
__syncthreads();
BlockScan(temp_storage.scan).ExclusiveSum(selected, offset, prefix_op);
if (threadIdx.x == 0) {
smem_selected_count = prefix_op.running_total;
}
__syncthreads();
sc = smem_selected_count;
__syncthreads();
// float p_diff = params.top_p - p_sum_total;
float p_diff = p_sum_total - params.top_p;
if ((p_sum_total <= params.top_p + P_EPSILON && p_sum_total > 0)
|| (p_sum_total > params.top_p && sc <= MAX_TOP_K) || (bit == 0 && p_sum_total > 0) || is_last_iter) {
#if DO_DEBUG_PRINT == 1
__syncthreads();
if (threadIdx.x == 0 && blockIdx.x == debug_block_id) {
sc = smem_selected_count;
printf("bit %d bit_mask %d offset %d (%d, %d), sc = %d, p_sum = %f, p_sum_total = %f\n",
bit,
bit_mask,
offset,
blockIdx.x,
threadIdx.x,
sc,
p_sum,
p_sum_total);
}
__syncthreads();
#endif
for (int item = 0; item < ITEMS_PER_THREAD_IN_REGS; ++item) {
                // the offset < MAX_TOP_K guard should not trigger with well trained weights, but
                // without it we would get an illegal memory access in those rare cases
if (enabled[item / 32] & (ONE << (item % 32)) && offset < MAX_TOP_K) {
smem_selected_elements[offset] =
make_int2(thread_items[item], item + threadIdx.x * ITEMS_PER_THREAD);
++offset;
thread_items[item] = MIN_KEY;
}
}
for (int item = 0; item < ITEMS_PER_THREAD_IN_SMEM; ++item) {
if (enabled[(item + ITEMS_PER_THREAD_IN_REGS) / 32] & (ONE << ((item + ITEMS_PER_THREAD_IN_REGS) % 32))
&& offset < MAX_TOP_K) {
int idx = threadIdx.x + item * BLOCK_THREADS;
// int idx = item + ITEMS_PER_THREAD_IN_SMEM * threadIdx.x;
// if (idx < params.num_items_per_segment_in_smem)
{
smem_selected_elements[offset] = make_int2(
smem_thread_items[idx], item + threadIdx.x * ITEMS_PER_THREAD + ITEMS_PER_THREAD_IN_REGS);
++offset;
smem_thread_items[idx] = MIN_KEY;
}
}
}
}
#if DO_DEBUG_PRINT == 1
if (threadIdx.x == 0 && blockIdx.x == debug_block_id) {
printf("!!!! bit %d bit_mask %d offset %d (%d, %d), sc = %d, p_sum = %f, p_sum_total = %f\n",
bit,
bit_mask,
offset,
blockIdx.x,
threadIdx.x,
sc,
p_sum,
p_sum_total);
}
#endif
        if ((p_diff <= P_EPSILON && p_diff >= 0) || (p_sum_total > params.top_p && sc <= MAX_TOP_K) || bit == 0) {
break;
}
// p > top_p
else if (p_diff > P_EPSILON) {
            // The selection with the current bit set covers too much probability mass
// Save the current state and go to the next bit
// If there are not enough items left using the next bit
// it's necessary to restart here with the current bit not set
save_mask = bit_mask;
select_mask |= bit_mask;
if (threadIdx.x == 0) {
smem_selected_count = old_selected_count;
p_sum_total = old_p_sum_total;
prefix_op.running_total = old_selected_count;
}
}
else {
            // p_sum_total < top_p branch: not enough probability mass selected yet
if (save_mask) {
select_mask = save_mask;
save_mask = 0;
}
if (threadIdx.x == 0) {
old_selected_count = smem_selected_count;
old_p_sum_total = p_sum_total;
}
}
}
__syncthreads();
// store data to global memory
sc = (p_sum_total < params.top_p) ? params.num_items / params.num_segments : smem_selected_count;
if (threadIdx.x == 0) {
params.gmem_active_count_per_segment[segment] = sc;
atomicMax(params.gmem_active_count_total, sc);
}
if (sc >= MAX_TOP_K) {
return;
}
for (int i = threadIdx.x; i < sc; i += blockDim.x) {
int2 selected_element = smem_selected_elements[i];
gmem_dst_keys[i] = selected_element.x;
gmem_dst_vals[i] = selected_element.y;
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
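// Returns the dynamic shared memory required by the single-pass kernel, or -1 when that path
// cannot be used for this problem (not enough shared memory on the SM, top_p too close to 1.0,
// or item counts not divisible by the vectorized load width); callers then fall back to the
// CUB radix sort path.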
template<typename Kernel_params>
int getSmemSizeAndCheck(const TopKPerSegmentContext& context, const TopKPerSegmentParams& params)
{
constexpr int BLOCK_THREADS = Kernel_params::BLOCK_THREADS;
using Key_Data_Type = typename Kernel_params::Key_Data_Type;
int num_items_per_segment = params.num_items / params.num_segments;
constexpr int ITEMS_INCREMENT = Kernel_params::ITEMS_INCREMENT;
int kernel_index = div_up(num_items_per_segment, BLOCK_THREADS * ITEMS_INCREMENT) - 1;
int smem_size = MAX_TOP_K * sizeof(int2);
const int items_per_thread = (kernel_index + 1) * ITEMS_INCREMENT;
const int items_per_thread_in_regs = items_per_thread * (1.0F - SMEM_FRACTION);
const int items_per_thread_in_smem = items_per_thread - items_per_thread_in_regs;
smem_size += items_per_thread_in_smem * BLOCK_THREADS * sizeof(typename Float_as_int_<Key_Data_Type>::Type);
int keys_per_ldg = 2 * sizeof(Key_Data_Type) / 2;
if (smem_size + BLOCK_THREADS * sizeof(float) > (size_t)context.sm_shared_size || // dynamic + static memory
items_per_thread_in_regs + items_per_thread_in_smem != items_per_thread || params.top_p + P_EPSILON > 1.0F
|| items_per_thread_in_regs % keys_per_ldg != 0 || items_per_thread_in_smem % keys_per_ldg != 0
|| num_items_per_segment % keys_per_ldg != 0) {
return -1;
}
return smem_size;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
int getSmemSizeAndCheck(const TopKPerSegmentContext& context,
const TopKPerSegmentParams& params,
const DType_t DT_SCORE)
{
int num_items_per_segment = params.num_items / params.num_segments;
if (DT_SCORE == kFLOAT) {
if (num_items_per_segment % 2 == 0) {
return getSmemSizeAndCheck<kernel_params_float>(context, params);
}
else {
return getSmemSizeAndCheck<kernel_params_float_1>(context, params);
}
}
else {
if (num_items_per_segment % 4 == 0) {
return getSmemSizeAndCheck<kernel_params_half>(context, params);
}
else {
return getSmemSizeAndCheck<kernel_params_half_1>(context, params);
}
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_params>
void segmentedTopPSinglePass_dispatch(const TopKPerSegmentParams& params,
const TopKPerSegmentContext& context,
cudaStream_t stream)
{
constexpr int BLOCK_THREADS = Kernel_params::BLOCK_THREADS;
using Key_Data_Type = typename Kernel_params::Key_Data_Type;
using Value_Data_Type = typename Kernel_params::Value_Data_Type;
int num_items_per_segment = params.num_items / params.num_segments;
constexpr int ITEMS_INCREMENT = Kernel_params::ITEMS_INCREMENT;
int kernel_index = div_up(num_items_per_segment, BLOCK_THREADS * ITEMS_INCREMENT) - 1;
#define KERNEL_RUN(INDEX) \
{ \
if (smem_size > 0) \
check_cuda_error( \
cudaFuncSetAttribute((const void*)segmented_top_p_single_pass<Kernel_params, ITEMS_INCREMENT*(INDEX + 1)>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
smem_size)); \
segmented_top_p_single_pass<Kernel_params, ITEMS_INCREMENT*(INDEX + 1)> \
<<<grid_dim, Kernel_params::BLOCK_THREADS, smem_size, stream>>>(params); \
}
int smem_size = getSmemSizeAndCheck<Kernel_params>(context, params);
dim3 grid_dim(params.num_segments, 1);
switch (kernel_index) {
case 0:
KERNEL_RUN(0) break;
case 1:
KERNEL_RUN(1) break;
case 2:
KERNEL_RUN(2) break;
case 3:
KERNEL_RUN(3) break;
case 4:
KERNEL_RUN(4) break;
case 5:
KERNEL_RUN(5) break;
case 6:
KERNEL_RUN(6) break;
case 7:
KERNEL_RUN(7) break;
default:
exit(1);
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_params>
void topPPerSegment_dispatch(const TopKPerSegmentContext& context,
TopKPerSegmentParams& params,
void* temp_storage,
size_t& temp_storage_bytes,
cudaStream_t stream)
{
using Key_Data_Type = typename Kernel_params::Key_Data_Type;
using Value_Data_Type = typename Kernel_params::Value_Data_Type;
if (temp_storage == nullptr) {
if (params.num_segments > 1) {
cub::DeviceSegmentedRadixSort::SortPairsDescending(temp_storage,
temp_storage_bytes,
reinterpret_cast<Key_Data_Type*>(params.gmem_src_keys),
reinterpret_cast<Key_Data_Type*>(params.gmem_dst_keys),
reinterpret_cast<Value_Data_Type*>(params.gmem_src_vals),
reinterpret_cast<Value_Data_Type*>(params.gmem_dst_vals),
params.num_items,
params.num_segments,
params.gmem_begin_offsets,
params.gmem_end_offsets,
0,
sizeof(Key_Data_Type) * 8,
stream);
}
else {
cub::DeviceRadixSort::SortPairsDescending(temp_storage,
temp_storage_bytes,
reinterpret_cast<Key_Data_Type*>(params.gmem_src_keys),
reinterpret_cast<Key_Data_Type*>(params.gmem_dst_keys),
reinterpret_cast<Value_Data_Type*>(params.gmem_src_vals),
reinterpret_cast<Value_Data_Type*>(params.gmem_dst_vals),
params.num_items,
0,
sizeof(Key_Data_Type) * 8,
stream);
}
temp_storage_bytes = div_up(temp_storage_bytes, 256) * 256;
// total active counts
temp_storage_bytes += div_up(sizeof(int), 256) * 256;
// storage for gmem_end_offsets
temp_storage_bytes += div_up(sizeof(int) * params.num_segments, 256) * 256;
return;
}
size_t cub_temp_storage_bytes =
temp_storage_bytes - div_up(sizeof(int), 256) * 256 - div_up(sizeof(int) * params.num_segments, 256) * 256;
void* cub_temp_storage = temp_storage;
params.gmem_active_count_total = reinterpret_cast<int*>((char*)temp_storage + cub_temp_storage_bytes);
params.gmem_active_count_per_segment =
reinterpret_cast<int*>((char*)params.gmem_active_count_total + div_up(sizeof(int), 256) * 256);
int num_items_per_segment = params.num_items / params.num_segments;
cudaMemsetAsync(params.gmem_active_count_total, 0, sizeof(int), stream);
cudaMemsetAsync(params.gmem_dst_keys, 0, params.num_items * sizeof(Key_Data_Type), stream);
segmentedTopPSinglePass_dispatch<Kernel_params>(params, context, stream);
int max_num_items = 0;
cudaMemcpyAsync(&max_num_items, params.gmem_active_count_total, sizeof(int), cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
if (max_num_items >= MAX_TOP_K || max_num_items == 0) {
if (params.num_segments > 1) {
cub::DeviceSegmentedRadixSort::SortPairsDescending(cub_temp_storage,
cub_temp_storage_bytes,
reinterpret_cast<Key_Data_Type*>(params.gmem_src_keys),
reinterpret_cast<Key_Data_Type*>(params.gmem_dst_keys),
reinterpret_cast<Value_Data_Type*>(params.gmem_src_vals),
reinterpret_cast<Value_Data_Type*>(params.gmem_dst_vals),
params.num_items,
params.num_segments,
params.gmem_begin_offsets,
params.gmem_end_offsets,
0,
sizeof(Key_Data_Type) * 8,
stream);
}
else {
cub::DeviceRadixSort::SortPairsDescending(cub_temp_storage,
cub_temp_storage_bytes,
reinterpret_cast<Key_Data_Type*>(params.gmem_src_keys),
reinterpret_cast<Key_Data_Type*>(params.gmem_dst_keys),
reinterpret_cast<Value_Data_Type*>(params.gmem_src_vals),
reinterpret_cast<Value_Data_Type*>(params.gmem_dst_vals),
params.num_items,
0,
sizeof(Key_Data_Type) * 8,
stream);
}
}
else {
// run at max supported value
blockSort<Key_Data_Type>((const Key_Data_Type*)(params.gmem_dst_keys),
(Key_Data_Type*)(params.gmem_dst_keys),
(const Value_Data_Type*)(params.gmem_dst_vals),
(Value_Data_Type*)(params.gmem_dst_vals),
params.gmem_active_count_per_segment,
max_num_items,
num_items_per_segment,
params.num_segments,
stream);
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
int topPPerSegment(const TopKPerSegmentContext& context,
TopKPerSegmentParams& params,
const DType_t DT_SCORE,
void* temp_storage,
size_t& temp_storage_bytes,
cudaStream_t stream)
{
int num_items_per_segment = params.num_items / params.num_segments;
if (DT_SCORE == kFLOAT) {
if (num_items_per_segment % 2 == 0) {
topPPerSegment_dispatch<kernel_params_float>(context, params, temp_storage, temp_storage_bytes, stream);
}
else {
topPPerSegment_dispatch<kernel_params_float_1>(context, params, temp_storage, temp_storage_bytes, stream);
}
}
else {
if (num_items_per_segment % 4 == 0) {
topPPerSegment_dispatch<kernel_params_half>(context, params, temp_storage, temp_storage_bytes, stream);
}
else {
topPPerSegment_dispatch<kernel_params_half_1>(context, params, temp_storage, temp_storage_bytes, stream);
}
}
return 0;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace segmented_topp_impl
__global__ void topPInitialize(
int* topp_id_val_buf, int* topp_offset_buf, int* begin_topp_offset_buf_, const int batch_size, const int n)
{
int tid = threadIdx.x;
int bid = blockIdx.x;
if (bid == 0) {
for (int i = tid; i < batch_size + 1; i += blockDim.x) {
topp_offset_buf[i] = i * n;
begin_topp_offset_buf_[i] = topp_offset_buf[i];
}
}
int index = tid + bid * blockDim.x;
while (index < batch_size * n) {
topp_id_val_buf[index] = index % n;
index += blockDim.x * gridDim.x;
}
}
void invokeTopPInitialize(int* topp_id_val_buf,
int* topp_offset_buf,
int* begin_topp_offset_buf_,
const size_t batch_size,
const int n,
cudaStream_t stream)
{
    // n: the number of columns of logits_buffer used for top-p sampling
topPInitialize<<<32, 512, 0, stream>>>(topp_id_val_buf, topp_offset_buf, begin_topp_offset_buf_, batch_size, n);
}
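// Cheap pre-pass: each block computes the top MAX_K (instantiated with MAX_K = 1) probabilities
// of one row. If their sum already reaches the per-batch top-p threshold, the winners are
// written to the temporary buffers and begin_offset_buf is bumped past the row, so the radix
// sort sees an empty segment and topp_sampling can take the single-token shortcut.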
template<typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void topp_beam_topk_kernel(const T* log_probs, // prob.
int* topk_tmp_id_buf,
T* topk_tmp_val_buf,
const int vocab_size,
int* offset_buf,
int* begin_offset_buf,
const float top_p,
const float* top_ps,
const bool* skip_decode)
{
int thread_id = threadIdx.x;
int batch_id = blockIdx.x;
if (skip_decode != nullptr && skip_decode[batch_id]) {
return;
}
float p_threshold = (top_ps != nullptr) ? top_ps[batch_id] : top_p;
typedef cub::BlockReduce<TopK<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
TopK<T, MAX_K> partial;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
#pragma unroll
for (int i = 0; i < MAX_K; ++i) {
partial.p[i] = -1;
partial.u[i] = -MAX_T_VAL;
}
#pragma unroll
for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) {
int index = elem_id + batch_id * vocab_size;
partial.insert(log_probs[index], index);
}
TopK<T, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<T, MAX_K>);
if (thread_id == 0) {
begin_offset_buf[batch_id] = offset_buf[batch_id];
T sum_prob = (T)(0.0f);
#pragma unroll
for (int i = 0; i < MAX_K; i++) {
sum_prob += total.u[i];
}
if ((float)sum_prob >= p_threshold) {
begin_offset_buf[batch_id] += vocab_size;
int index = batch_id * vocab_size;
#pragma unroll
for (int i = 0; i < MAX_K; ++i) {
topk_tmp_id_buf[index + i] = total.p[i] % vocab_size;
topk_tmp_val_buf[index + i] = total.u[i];
}
}
}
}
struct BlockPrefixCallbackOp {
// Running prefix
float running_total;
// Constructor
__device__ BlockPrefixCallbackOp(float running_total): running_total(running_total) {}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
__device__ float operator()(float block_aggregate)
{
float old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
};
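// Writes the renormalized nucleus distribution back in vocabulary order: each token's sorted
// probability is clipped so the cumulative mass does not exceed the top-p threshold, divided
// by the threshold, and scattered to out_prob[token_id]; the scan stops once the threshold
// has been reached.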
template<typename T, int BLOCK_SIZE>
__global__ void set_out_prob_kernel(float* out_prob, const T* sorted_log_probs, const int* sorted_id_vals, const bool* skip_decode, const float* top_ps, const float top_p, const int batch_size, const int vocab_size) {
__shared__ int stop_shared;
const int tid = threadIdx.x;
const int batch_id = blockIdx.x;
if (skip_decode != nullptr && skip_decode[batch_id]) {
return;
}
constexpr int WARP_SIZE = 32;
constexpr int NUM_WARPS = BLOCK_SIZE / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = threadIdx.x / WARP_SIZE;
const float prob_threshold = (top_ps != nullptr) ? top_ps[batch_id] : top_p;
typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
__shared__ uint32_t selected_shared[NUM_WARPS];
// Initialize running total
BlockPrefixCallbackOp prefix_op(0);
    if (lane_id == 0) {
        selected_shared[warp_id] = 0;
    }
    if (tid == 0) {
        stop_shared = 0;
    }
__syncthreads();
int offset = batch_id * vocab_size;
int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
int total_num = 0;
float thread_offset = 0;
for (int i = tid; i < end; i += BLOCK_SIZE) {
total_num += BLOCK_SIZE;
float thread_count = (i < vocab_size) ? (float)sorted_log_probs[offset + i] : 0.f;
BlockScan(temp_storage).InclusiveSum(thread_count, thread_offset, prefix_op);
uint32_t active_mask = __ballot_sync(0xFFFFFFFF, prob_threshold <= thread_offset);
if (active_mask != 0) {
if (lane_id == 0) {
atomicAdd(&stop_shared, 1);
}
}
if (i < vocab_size) {
int token_idx = sorted_id_vals[offset + i];
out_prob[offset + token_idx] = max(0.0, thread_count - max(0.0, thread_offset - prob_threshold)) / prob_threshold;
}
__syncthreads();
if (stop_shared > 0) {
break;
}
}
}
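// Nucleus sampling over the sorted probabilities: thread 0 draws u ~ Uniform(0, top_p), the
// block computes an inclusive prefix sum of the sorted probabilities, and the first position
// whose running sum reaches u (located via __ballot_sync within the first active warp) gives
// the sampled token id. Rows short-circuited by topp_beam_topk_kernel skip the scan and emit
// the single stored candidate directly.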
template<typename T, int BLOCK_SIZE>
__global__ void topp_sampling(T* sorted_log_probs,
int* sorted_id_vals,
int* ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const int* begin_offset_buf,
const int* offset_buf,
const int vocab_size,
curandState_t* curandstate,
const float top_p,
const float* top_ps,
const int* end_ids,
const int batch_size,
const bool* skip_decode)
{
__shared__ int stop_shared;
__shared__ float rand_num_s;
const int tid = threadIdx.x;
const int batch_id = blockIdx.x;
if (skip_decode != nullptr && skip_decode[batch_id]) {
return;
}
constexpr int WARP_SIZE = 32;
constexpr int NUM_WARPS = BLOCK_SIZE / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = threadIdx.x / WARP_SIZE;
const float prob_threshold = (top_ps != nullptr) ? top_ps[batch_id] : top_p;
if (threadIdx.x == 0) {
stop_shared = 0;
rand_num_s = curand_uniform(curandstate + blockIdx.x) * prob_threshold;
}
    // If begin_offset_buf and offset_buf hold the same value for this row, the best candidate
    // was already found in topp_beam_topk_kernel and the sorting was skipped, so we can also
    // take the shortcut here during sampling.
if (begin_offset_buf[batch_id] == offset_buf[batch_id]) {
if (tid == 0) {
int offset = batch_id * vocab_size;
ids[batch_id] = sorted_id_vals[offset];
if (cum_log_probs != nullptr || output_log_probs != nullptr) {
float lprob = logf(sorted_log_probs[offset]);
if (cum_log_probs != nullptr) {
cum_log_probs[batch_id] += lprob;
}
if (output_log_probs != nullptr) {
output_log_probs[batch_id] = lprob;
}
}
if (sequence_length != nullptr && finished_buf != nullptr) {
sequence_length[batch_id] =
finished_buf[batch_id] ? sequence_length[batch_id] : sequence_length[batch_id] + 1;
finished_buf[batch_id] = ids[batch_id] == end_ids[batch_id] ? 1 : 0;
}
}
return;
}
typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
__shared__ uint32_t selected_shared[NUM_WARPS];
// Initialize running total
BlockPrefixCallbackOp prefix_op(0);
if (lane_id == 0) {
selected_shared[warp_id] = 0;
}
__syncthreads();
int offset = batch_id * vocab_size;
ids[batch_id] = sorted_id_vals[offset];
int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
int i_active = 0;
float thread_offset = 0;
for (int i = tid; i < end; i += BLOCK_SIZE) {
float thread_count = (i < vocab_size) ? (float)sorted_log_probs[offset + i] : 0.f;
BlockScan(temp_storage).InclusiveSum(thread_count, thread_offset, prefix_op);
uint32_t active_mask = __ballot_sync(0xFFFFFFFF, rand_num_s <= thread_offset);
i_active = i;
if (active_mask != 0) {
if (lane_id == 0) {
atomicAdd(&stop_shared, 1);
selected_shared[warp_id] = active_mask;
}
}
__syncthreads();
if (stop_shared > 0) {
break;
}
    }
// select first active warp
bool skip = (selected_shared[warp_id] > 0) ? false : true;
for (int i = 0; i < warp_id; i++) {
if (selected_shared[i] != 0) {
skip = true;
}
}
if (!skip) {
int active_lane_id = WARP_SIZE - __popc(selected_shared[warp_id]);
if (lane_id == active_lane_id) {
ids[batch_id] = sorted_id_vals[offset + i_active];
if (cum_log_probs != nullptr || output_log_probs != nullptr) {
float lprob = logf(sorted_log_probs[offset + i_active]);
if (cum_log_probs != nullptr) {
cum_log_probs[batch_id] += lprob;
}
if (output_log_probs != nullptr) {
output_log_probs[batch_id] = lprob;
}
}
if (sequence_length != nullptr && finished_buf != nullptr) {
sequence_length[batch_id] =
finished_buf[batch_id] ? sequence_length[batch_id] : sequence_length[batch_id] + 1;
finished_buf[batch_id] = ids[batch_id] == end_ids[batch_id] ? 1 : 0;
}
}
}
}
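// invokeBatchTopPSampling follows the usual CUB-style two-call pattern: when workspace is
// nullptr it only computes workspace_size / cub_temp_storage_size and returns; the caller then
// allocates the workspace and calls again to run the kernels. A minimal host-side sketch
// (illustration only; buffer arguments elided):
//
//   size_t ws_bytes = 0, cub_bytes = 0;
//   invokeBatchTopPSampling<float>(nullptr, ws_bytes, cub_bytes, /* ...buffers... */,
//                                  stream, &device_prop, skip_decode);  // size query
//   void* ws = nullptr;
//   cudaMalloc(&ws, ws_bytes);
//   invokeBatchTopPSampling<float>(ws, ws_bytes, cub_bytes, /* ...buffers... */,
//                                  stream, &device_prop, skip_decode);  // actual sampling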
template<typename T>
void invokeBatchTopPSampling(void* workspace,
size_t& workspace_size,
size_t& cub_temp_storage_size,
int* output_ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const T* log_probs,
const int* id_vals,
int* offset_buf,
int* begin_offset_buf,
curandState_t* curandstate,
const int batch_size,
const size_t vocab_size_padded,
const int* end_ids,
const float max_top_p,
const float* top_ps,
float* output_all_probs,
cudaStream_t stream,
cudaDeviceProp* cuda_device_prop,
const bool* skip_decode)
{
    // batch_size is taken as an argument because the batch sizes used at initialization
    // and at inference time may differ due to pipeline parallelism.
const int vocab_size = vocab_size_padded;
const int block_size = 256;
size_t sorted_log_prob_buf_size = batch_size * vocab_size * sizeof(T); // type T
size_t sorted_id_vals_buf_size = batch_size * vocab_size * sizeof(int); // type int
sorted_log_prob_buf_size = div_up(sorted_log_prob_buf_size, 256) * 256;
sorted_id_vals_buf_size = div_up(sorted_id_vals_buf_size, 256) * 256;
void* cub_temp_storage = workspace;
T* sorted_log_probs = (T*)((char*)cub_temp_storage + cub_temp_storage_size);
int* sorted_id_vals = (int*)((char*)sorted_log_probs + sorted_log_prob_buf_size);
bool do_radix_sort = (ENABLE_SINGLE_PASS_TOP_P == 0 || max_top_p >= SINGLE_PASS_THRESHOLD);
int smem_size = -1;
segmented_topp_impl::TopKPerSegmentContext context;
segmented_topp_impl::TopKPerSegmentParams params;
segmented_topp_impl::DType_t dataTypeKind =
(std::is_same<T, float>::value) ? segmented_topp_impl::kFLOAT : segmented_topp_impl::kHALF;
if (!do_radix_sort) {
RTP_LLM_CHECK(cuda_device_prop != nullptr);
memset(&context, 0, sizeof(context));
context.sm_count = cuda_device_prop->multiProcessorCount;
context.sm_shared_size = cuda_device_prop->sharedMemPerMultiprocessor;
context.sm_version = cuda_device_prop->major * 100 + cuda_device_prop->minor * 10;
        memset(&params, 0, sizeof(params));
params.gmem_src_keys = reinterpret_cast<void*>(const_cast<T*>(log_probs));
params.gmem_dst_keys = sorted_log_probs;
params.gmem_src_vals = reinterpret_cast<void*>(const_cast<int*>(id_vals));
params.gmem_dst_vals = reinterpret_cast<void*>(sorted_id_vals);
params.gmem_begin_offsets = begin_offset_buf;
params.gmem_end_offsets = offset_buf + 1;
params.workspace = nullptr;
params.num_items = vocab_size * batch_size;
params.num_segments = batch_size;
params.top_p = max_top_p;
params.confidence_threshold = 0.0F;
smem_size = getSmemSizeAndCheck(context, params, dataTypeKind);
do_radix_sort = smem_size < 0;
}
if (do_radix_sort) {
if (workspace == nullptr) {
check_cuda_error(
cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
cub_temp_storage_size,
log_probs,
(T*)nullptr,
id_vals,
(int*)nullptr,
vocab_size * batch_size,
batch_size,
begin_offset_buf,
offset_buf + 1,
0, // begin_bit
sizeof(T) * 8, // end_bit = sizeof(KeyT) * 8
stream)); // cudaStream_t
cub_temp_storage_size = div_up(cub_temp_storage_size, 256) * 256;
workspace_size = sorted_log_prob_buf_size + sorted_id_vals_buf_size + cub_temp_storage_size;
return;
}
//
if (!output_all_probs) {
topp_beam_topk_kernel<T, 1, block_size><<<batch_size, block_size, 0, stream>>>(log_probs,
sorted_id_vals,
sorted_log_probs,
vocab_size,
offset_buf,
begin_offset_buf,
max_top_p,
top_ps,
skip_decode);
}
check_cuda_error(
cub::DeviceSegmentedRadixSort::SortPairsDescending(cub_temp_storage,
cub_temp_storage_size,
log_probs,
sorted_log_probs,
id_vals,
sorted_id_vals,
vocab_size * batch_size,
batch_size,
begin_offset_buf,
offset_buf + 1,
0, // begin_bit
sizeof(T) * 8, // end_bit = sizeof(KeyT) * 8
stream)); // cudaStream_t
}
else {
if (workspace == nullptr) {
segmented_topp_impl::topPPerSegment(
context, params, dataTypeKind, cub_temp_storage, cub_temp_storage_size, stream);
workspace_size = sorted_log_prob_buf_size + sorted_id_vals_buf_size + cub_temp_storage_size;
return;
}
else {
if (!output_all_probs) {
topp_beam_topk_kernel<T, 1, block_size><<<batch_size, block_size, 0, stream>>>(log_probs,
sorted_id_vals,
sorted_log_probs,
vocab_size,
offset_buf,
begin_offset_buf,
max_top_p,
top_ps,
skip_decode);
}
segmented_topp_impl::topPPerSegment(
context, params, dataTypeKind, cub_temp_storage, cub_temp_storage_size, stream);
}
}
if (output_all_probs) {
set_out_prob_kernel<T, 256><<<batch_size, 256, 0, stream>>>(output_all_probs,
sorted_log_probs,
sorted_id_vals,
skip_decode,
top_ps,
max_top_p,
batch_size,
vocab_size);
}
constexpr int SAMPLING_BLOCK_SIZE = 256;
dim3 grid(batch_size);
topp_sampling<T, SAMPLING_BLOCK_SIZE><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(sorted_log_probs,
sorted_id_vals,
output_ids,
sequence_length,
finished_buf,
cum_log_probs,
output_log_probs,
begin_offset_buf,
offset_buf + 1,
vocab_size,
curandstate,
max_top_p,
top_ps,
end_ids,
batch_size,
skip_decode);
}
template void invokeBatchTopPSampling(void* workspace,
size_t& workspace_size,
size_t& cub_temp_storage_size,
int* output_ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const float* log_probs,
const int* id_vals,
int* offset_buf,
int* begin_offset_buf,
curandState_t* curandstate,
const int batch_size,
const size_t vocab_size_padded,
const int* end_ids,
const float max_top_p,
const float* top_ps,
float* output_all_probs,
cudaStream_t stream,
cudaDeviceProp* cuda_device_prop,
const bool* skip_decode);
template void invokeBatchTopPSampling(void* workspace,
size_t& workspace_size,
size_t& cub_temp_storage_size,
int* output_ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const half* log_probs,
const int* id_vals,
int* offset_buf,
int* begin_offset_buf,
curandState_t* curandstate,
const int batch_size,
const size_t vocab_size_padded,
const int* end_ids,
const float max_top_p,
const float* top_ps,
float* output_all_probs,
cudaStream_t stream,
cudaDeviceProp* cuda_device_prop,
const bool* skip_decode);
template<typename T>
void invokeTopPSampling(void* workspace,
size_t& workspace_size,
size_t& cub_temp_storage_size,
int* output_ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const T* log_probs,
const int* id_vals,
int* offset_buf,
int* begin_offset_buf,
curandState_t* curandstate,
const int batch_size,
const size_t vocab_size_padded,
const int* end_ids,
const float top_p,
float* output_all_probs,
cudaStream_t stream,
cudaDeviceProp* cuda_device_prop,
const bool* skip_decode)
{
invokeBatchTopPSampling(workspace,
workspace_size,
cub_temp_storage_size,
output_ids,
sequence_length,
finished_buf,
cum_log_probs,
output_log_probs,
log_probs,
id_vals,
offset_buf,
begin_offset_buf,
curandstate,
batch_size,
vocab_size_padded,
end_ids,
top_p,
nullptr,
output_all_probs,
stream,
cuda_device_prop,
skip_decode);
}
template void invokeTopPSampling(void* workspace,
size_t& workspace_size,
size_t& cub_temp_storage_size,
int* output_ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const float* log_probs,
const int* id_vals,
int* offset_buf,
int* begin_offset_buf,
curandState_t* curandstate,
const int batch_size,
const size_t vocab_size_padded,
const int* end_ids,
const float top_p,
float* output_all_probs,
cudaStream_t stream,
cudaDeviceProp* cuda_device_prop,
const bool* skip_decode);
template void invokeTopPSampling(void* workspace,
size_t& workspace_size,
size_t& cub_temp_storage_size,
int* output_ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const half* log_probs,
const int* id_vals,
int* offset_buf,
int* begin_offset_buf,
curandState_t* curandstate,
const int batch_size,
const size_t vocab_size_padded,
const int* end_ids,
const float top_p,
float* output_all_probs,
cudaStream_t stream,
cudaDeviceProp* cuda_device_prop,
const bool* skip_decode);
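// Fused bias-add + softmax over each logits row: finished rows are forced to a one-hot
// distribution on their end_id, padded positions beyond n are masked to -MAX_T_VAL, and the
// softmax is computed in the numerically stable way (subtract the block max, exponentiate,
// normalize by the block sum).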
template<typename T>
__global__ void
addBiasSoftMax(T* logits, const T* bias, const int* end_ids, const bool* finished, const int n_padded, const int n)
{
int bid = blockIdx.x;
bool finish = (finished != nullptr) ? finished[bid] : false;
int offset = bid * n_padded;
float max_val = -1 * FLT_MAX;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
__shared__ float s_max_val;
__shared__ float s_sum_val;
for (int tid = threadIdx.x; tid < n_padded; tid += blockDim.x) {
if (tid < n) {
if (finish) {
logits[offset + tid] = (tid == end_ids[bid]) ? static_cast<T>(MAX_T_VAL) : static_cast<T>(-MAX_T_VAL);
}
else {
T bias_val = (bias != nullptr) ? bias[tid] : static_cast<T>(0.0f);
logits[offset + tid] += bias_val;
}
}
else {
logits[offset + tid] = static_cast<T>(-MAX_T_VAL);
}
max_val = max(max_val, (float)logits[offset + tid]);
}
max_val = blockReduceMax<float>((float)max_val);
if (threadIdx.x == 0) {
s_max_val = max_val;
}
__syncthreads();
float sum_val = 0.0f;
for (int tid = threadIdx.x; tid < n_padded; tid += blockDim.x) {
logits[offset + tid] = __expf((float)logits[offset + tid] - s_max_val);
sum_val += (float)logits[offset + tid];
}
sum_val = blockReduceSum<float>(sum_val);
if (threadIdx.x == 0) {
s_sum_val = sum_val;
}
__syncthreads();
for (int tid = threadIdx.x; tid < n_padded; tid += blockDim.x) {
logits[offset + tid] = ((float)logits[offset + tid] / (s_sum_val + 1e-6f));
}
}
template<typename T>
void invokeAddBiasSoftMax(T* logits,
const T* bias,
const int* end_ids,
const bool* finished,
const int m,
const int n_padded,
const int n,
cudaStream_t stream)
{
dim3 grid(m);
dim3 block(min(n, 1024));
    /* n is the vocab_size, e.g., 30000 or 7000; the vocab size is usually very large. */
addBiasSoftMax<<<grid, block, 0, stream>>>(logits, bias, end_ids, finished, n_padded, n);
}
template void invokeAddBiasSoftMax(float* logits,
const float* bias,
const int* end_ids,
const bool* finished,
const int m,
const int n_padded,
const int n,
cudaStream_t stream);
template void invokeAddBiasSoftMax(half* logits,
const half* bias,
const int* end_ids,
const bool* finished,
const int m,
const int n_padded,
const int n,
cudaStream_t stream);
template void invokeAddBiasSoftMax(__nv_bfloat16* logits,
const __nv_bfloat16* bias,
const int* end_ids,
const bool* finished,
const int m,
const int n_padded,
const int n,
cudaStream_t stream);
__global__ void computeToppDecay(float* runtime_top_p,
const float* runtime_initial_top_p,
const int* output_ids,
const float* top_p_decay,
const float* top_p_min,
const int32_t* top_p_reset_ids,
const int local_batch_size)
{
/**
     * @brief Compute the top-p decay following https://arxiv.org/pdf/2206.04624.pdf
     * In short, the formula is
     *   runtime_top_p = max(runtime_top_p * top_p_decay, top_p_min)
     * If the generated token equals top_p_reset_ids, reset runtime_top_p to its initial value.
*
* \param runtime_top_p [local_batch_size]
* \param runtime_initial_top_p [local_batch_size]
* \param output_ids [local_batch_size]
* \param top_p_decay [local_batch_size]
* \param top_p_min [local_batch_size]
* \param top_p_reset_ids [local_batch_size]
* \param local_batch_size
*
*/
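    // Example: with runtime_top_p = 0.9, top_p_decay = 0.9 and top_p_min = 0.5, successive
    // decode steps give 0.81, 0.729, 0.6561, ... until the value is clamped at 0.5; generating
    // the token in top_p_reset_ids restores the initial 0.9.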
int idx = blockDim.x * blockIdx.x + threadIdx.x;
if (output_ids[idx] == top_p_reset_ids[idx]) {
runtime_top_p[idx] = runtime_initial_top_p[idx];
}
else {
runtime_top_p[idx] = max(runtime_top_p[idx] * top_p_decay[idx], top_p_min[idx]);
}
}
void invokeComputeToppDecay(float* runtime_top_p,
const float* runtime_initial_top_p,
const int* output_ids,
const float* top_p_decay,
const float* top_p_min,
const int32_t* top_p_reset_ids,
const int local_batch_size,
cudaStream_t stream)
{
dim3 block(min(local_batch_size, 512));
dim3 grid((local_batch_size + block.x - 1) / block.x);
computeToppDecay<<<grid, block, 0, stream>>>(
runtime_top_p, runtime_initial_top_p, output_ids, top_p_decay, top_p_min, top_p_reset_ids, local_batch_size);
}
static __global__ void set_topp_runtime_args(int batch_size,
uint top_k,
uint* top_ks,
int top_ks_size,
float top_p,
float* top_ps,
int top_ps_size,
bool* skip_decode,
float* initial_top_p_buf,
float* top_p_decay_buf,
const float* top_p_decay,
float* top_p_min_buf,
const float* top_p_min,
int32_t* top_p_reset_ids_buf,
const uint32_t* top_p_reset_ids)
{
/**
     * @brief Set up the runtime arguments for top-p sampling: broadcast top_p to top_ps
     *        and top_k to top_ks, and copy top_p_decay/top_p_min/top_p_reset_ids
     *        into internal buffers.
*
* \param batch_size [batch_size]
     * \param top_k
* \param top_ks [batch_size]
* \param top_ks_size [batch_size]
* \param top_p [batch_size]
* \param top_ps [batch_size]
* \param top_ps_size [batch_size]
* \param skip_decode [batch_size]
* \param initial_top_p_buf [batch_size]
* \param top_p_decay_buf [batch_size]
     * \param top_p_decay [batch_size], optional, must be in [0, 1]
* \param top_p_min_buf [batch_size]
     * \param top_p_min [batch_size], optional, must be in [0, 1]
* \param top_p_reset_ids_buf [batch_size]
* \param top_p_reset_ids [batch_size], optional
*
*/
int index = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = index; i < batch_size; i += gridDim.x * blockDim.x) {
uint k = top_ks_size > 1 ? top_ks[i] : top_k;
float p = top_ps_size > 1 ? top_ps[i] : top_p;
if (k == 0 && p == 0.0f) {
            // FT's top-p implementation does not support top_p = 0.0f, but that setting is
            // equivalent to greedy search, so we set top_k = 1 as an alternative.
k = 1;
}
top_ks[i] = k;
        // Clamp p to the valid range [0.0, 1.0].
        top_ps[i] = p < 0.0f ? 0.0f : (p > 1.0f ? 1.0f : p);
        if (p < 0.0f || p > 1.0f) {
            printf("[WARNING] topp (%f) is out of range ([0.0, 1.0f]) for token %d,"
                   " clipping to the closest number %f.\n",
                   p,
                   i,
                   top_ps[i]);
        }
skip_decode[i] = k > 0;
if (initial_top_p_buf) {
initial_top_p_buf[i] = top_ps[i];
}
if (top_p_decay_buf) {
top_p_decay_buf[i] = top_p_decay == nullptr ? 1.0f : top_p_decay[i];
if (top_p_decay_buf[i] > 1.0f || top_p_decay_buf[i] <= 0.0f) {
printf("[WARNING] top_p_decay_buf (%f) is out of range ([0.0, 1.0f]) for token %d,"
" change to 1.0f.\n",
top_p_decay_buf[i],
i);
top_p_decay_buf[i] = 1.0f;
}
            top_p_min_buf[i] = top_p_min == nullptr ? 1e-6f : top_p_min[i];  // prevent top_p from becoming 0.0
if (top_p_min_buf[i] > 1.0f || top_p_min_buf[i] <= 0.0f) {
printf("[WARNING] top_p_min_buf (%f) is out of range ([0.0, 1.0f]) for token %d,"
" change to 0.5f.\n",
top_p_min_buf[i],
i);
top_p_min_buf[i] = 0.5f;
}
if (top_p_reset_ids == nullptr) {
top_p_reset_ids_buf[i] = -1;
} else {
top_p_reset_ids_buf[i] = top_p_reset_ids[i];
}
// top_p_reset_ids_buf[i] = (int32_t)(top_p_reset_ids == nullptr ? -1 : top_p_reset_ids[i]);
}
}
}
void invokeSetupTopPRuntimeArgs(int batch_size,
uint top_k,
uint* top_ks,
int top_ks_size,
float top_p,
float* top_ps,
int top_ps_size,
bool* skip_decode,
float* initial_top_p_buf,
float* top_p_decay_buf,
const float* top_p_decay,
float* top_p_min_buf,
const float* top_p_min,
int32_t* top_p_reset_ids_buf,
const uint32_t* top_p_reset_ids,
cudaStream_t stream)
{
dim3 block(std::min((int)batch_size, 256));
dim3 grid(div_up((int)batch_size, (int)block.x));
set_topp_runtime_args<<<grid, block, 0, stream>>>(batch_size,
top_k,
top_ks,
top_ks_size,
top_p,
top_ps,
top_ps_size,
skip_decode,
initial_top_p_buf,
top_p_decay_buf,
top_p_decay,
top_p_min_buf,
top_p_min,
top_p_reset_ids_buf,
top_p_reset_ids);
}
} // namespace rtp_llm