// maga_transformer/cpp/devices/cuda_impl/CudaSampleOp.cc
#include "maga_transformer/cpp/devices/cuda_impl/CudaDevice.h"
#include "maga_transformer/cpp//core/BufferHelper.h"
#include "maga_transformer/cpp/devices/CommonDefines.h"
#include "maga_transformer/cpp/kernels/sampling_topk_kernels.h"
#include "maga_transformer/cpp/kernels/sampling_topp_kernels.h"
#include "maga_transformer/cpp/kernels/sampling_penalty_kernels.h"
#include "maga_transformer/cpp/kernels/banRepeatNgram.h"
#include "maga_transformer/cpp/cuda/memory_utils.h"
#include "maga_transformer/cpp/devices/utils/DebugUtils.h"
#include "maga_transformer/cpp/core/torch_utils/BufferTorchUtils.h"
#include "3rdparty/flashinfer/flashinfer.h"
#include <cstddef>
#include <random>
#include <memory>
using namespace std;

namespace rtp_llm {

using SamplerT = float;
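
// Applies the logits post-processing passes that run before sampling:
// temperature scaling, repetition penalty, min-length penalty and
// no-repeat-ngram banning. Every pass mutates `params.logits` in place.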
void CudaDevice::processLogits(const GreedyParams& params,
                               const BufferPtr&    device_tokens,
                               const BufferPtr&    transposed_tokens) {
    auto&      logits             = params.logits;
    const auto vocab_size_padded  = params.logits.shape()[1];
    const auto decoder_batch_size = params.sequence_lengths.shape()[0];
    const auto batch_size         = logits.shape()[0];
    const auto step               = params.step;
    if (std::any_of(params.temperature.data<float>(),
                    params.temperature.data<float>() + batch_size,
                    [&](auto t) { return t != 1.0f; })) {
        BufferPtr temperature_buf = allocateBuffer({DataType::TYPE_FP32, {batch_size}});
        copy({*temperature_buf, params.temperature});
        invokeBatchApplyTemperaturePenalty(logits.data<float>(),
                                           (float*)nullptr,  // embedding_bias
                                           temperature_buf->data<float>(),
                                           batch_size,
                                           vocab_size_padded,
                                           vocab_size_padded,
                                           stream_);
    }
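    // Multiplicative repetition penalty: only launched when at least one
    // request uses a penalty != 1.0. Decoder (incremental) requests take their
    // current sequence_lengths; context requests fall back to input_lengths.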
    if (params.repetition_penalty) {
        auto& repetition_penalty = params.repetition_penalty->get();
        if (std::any_of(repetition_penalty.data<float>(),
                        repetition_penalty.data<float>() + batch_size,
                        [&](auto t) { return t != 1.0f; })) {
            auto sequence_lengths = clone({params.input_lengths});
            if (decoder_batch_size) {
                copy({sequence_lengths->view(0, decoder_batch_size), params.sequence_lengths});
            }
            const auto repetition_penalty_type = RepetitionPenaltyType::Multiplicative;
            auto repetition_penalty_buf = allocateBuffer({DataType::TYPE_FP32, {batch_size}});
            auto penalty_logits = allocateBuffer({DataType::TYPE_FP32, {batch_size * 64 * 1024}});
            copy({*repetition_penalty_buf, repetition_penalty});
            // NOTE: here step is max_len - 1
            invokeBatchApplyRepetitionPenalty(logits.data<float>(),
                                              penalty_logits->data<float>(),
                                              repetition_penalty_buf->data<float>(),
                                              transposed_tokens->data<int32_t>(),
                                              batch_size,
                                              batch_size,  // local_batch_size
                                              vocab_size_padded,
                                              sequence_lengths->data<int32_t>(),
                                              step + 1,  // max_input_length
                                              step + 1,  // step
                                              repetition_penalty_type,
                                              stream_);
        }
    }
    /*
        logits:           [decoder_batch_size; context_batch_size]
        input_lengths:    [decoder_batch_size; context_batch_size]
        sequence_lengths: [decoder_batch_size]
    */
    if (params.min_lengths && params.eos_ids) {
        auto min_lengths_buf  = clone({params.min_lengths.value().get()});
        auto sequence_lengths = clone({params.sequence_lengths});
        auto input_lengths    = clone({params.input_lengths});
        invokeMinLengthPenaltyNew(logits.data<float>(),
                                  min_lengths_buf->data<int32_t>(),
                                  params.eos_ids.value().get().data<int32_t>(),
                                  sequence_lengths->data<int32_t>(),
                                  input_lengths->data<int32_t>(),
                                  decoder_batch_size,
                                  batch_size,
                                  vocab_size_padded,
                                  stream_);
    }
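    // Ban repeated n-grams for decoder requests: build a host-side array of
    // per-request output-token pointers, upload it to the device, and let the
    // TensorRT-LLM kernel mask any token that would complete a banned n-gram.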
if (decoder_batch_size && params.no_repeat_ngram_size) {
const auto& no_repeat_ngram_size = params.no_repeat_ngram_size.value().get();
if (any_of(no_repeat_ngram_size.data<int32_t>(),
no_repeat_ngram_size.data<int32_t>() + decoder_batch_size,
[](auto s) { return s != 0; }))
{
auto no_repeat_ngram_size_buf = clone({no_repeat_ngram_size});
auto output_ids_ptrs = allocateBuffer({DataType::TYPE_UINT64, {decoder_batch_size}, AllocationType::HOST});
for (int i = 0; i < decoder_batch_size; i++) {
output_ids_ptrs->data<uint64_t>()[i] = (uint64_t)(device_tokens->data<int32_t>() + i * (step + 1));
}
auto output_ids_ptrs_device = clone({*output_ids_ptrs, AllocationType::DEVICE});
auto sequence_lengths = clone({params.sequence_lengths});
tensorrt_llm::kernels::invokeBanRepeatNgram(
logits.data<float>(),
(int32_t const**)(output_ids_ptrs_device->data()),
nullptr, // finished_buf
nullptr, // parent_ids_buf
nullptr, // batch_slot
sequence_lengths->data<int32_t>(),
decoder_batch_size,
1, // beam_width
step + 1,
no_repeat_ngram_size_buf->data<int32_t>(),
vocab_size_padded,
step + 1,
stream_);
}
}
}
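
// The flashinfer fast path has no curand state and does not produce log
// probs, so requests that ask for an explicit random seed or for cumulative /
// per-token log probs must go through the complete sampling path instead.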
bool CudaDevice::checkUseFlashinferSampleGreedy(const GreedyParams& params) {
    if ((!use_flashinfer_sample_kernel) ||
        params.random_seed.has_value() ||
        params.cum_log_probs.has_value() ||
        params.output_log_probs.has_value()) {
        return false;
    }
    return true;
}
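
// Samples the next token for every request with flashinfer's sampling
// kernels. The logits are first converted to probabilities, then one of four
// paths is taken depending on the batch-wide top_k/top_p configuration.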
GreedyOutput CudaDevice::flashinferSampleGreedy(const GreedyParams& params, const BufferPtr& transposed_tokens) {
    const auto batch_size = params.logits.shape()[0];
    auto& top_k = params.top_k;
    auto& top_p = params.top_p;
    auto  logits_ref = params.logits.slice(0, params.logits.shape()[0]);
    auto  probs = softmax({logits_ref, std::nullopt, std::nullopt, 1.0f, DataType::TYPE_INVALID, std::nullopt});
    BufferPtr success = allocateBuffer({DataType::TYPE_BOOL, {batch_size}});
    auto samples = transposed_tokens->view(transposed_tokens->shape()[0] - 1, 1);
    torch::TensorOptions options =
        torch::TensorOptions(dataTypeToTorchType(probs->type())).device(torch::Device(torch::kCUDA));
    bool deterministic = true;
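    // Outside of tests (SAMPLE_TEST unset), reseed torch's RNG from
    // std::random_device so the uniform samples drawn below differ between
    // runs; tests keep the default seed for reproducible output.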
if (!std::getenv("SAMPLE_TEST")) {
std::random_device rd;
std::mt19937_64 gen(rd());
std::uniform_int_distribution<std::int64_t> distrib(0, std::numeric_limits<std::int64_t>::max());
torch::manual_seed(distrib(gen));
deterministic = false;
}
bool need_output_all_probs = params.output_all_probs.has_value();
auto uniform_samples = torch::rand({32, (int)batch_size}, options);
torch::Tensor probs_t = Buffer2torchTensor(probs, false);
torch::Tensor samples_t = Buffer2torchTensor(samples, false);
torch::Tensor success_t = Buffer2torchTensor(success, false);
torch::Tensor top_k_t = Buffer2torchTensor(top_k, false);
torch::Tensor top_p_t = Buffer2torchTensor(top_p, false);
torch::Tensor output_all_probs_t;
if (need_output_all_probs) {
output_all_probs_t = Buffer2torchTensor(params.output_all_probs.value().get(), false);
}
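    // Dispatch on the batch-wide configuration: top_p == 0 means "top_p
    // disabled" (mapped to 1.0), top_k == 0 means "top_k disabled" (mapped to
    // a huge k). The four branches are: all-greedy (argmax), pure top-p,
    // pure top-k, and mixed top-k + top-p.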
std::transform(top_p.data<float>(), top_p.data<float>() + batch_size, top_p.data<float>(), [&](auto t) { return std::abs(t) < 1e-7 ? 1.0 : t;});
if (std::all_of(top_k.data<uint32_t>(),
top_k.data<uint32_t>() + batch_size,
[&](auto t) { return t == 1; })) {
torch::Tensor selected_tokens = torch::argmax(probs_t, -1, /*keepdim=*/false);
samples_t.copy_(selected_tokens);
success.reset();
if (need_output_all_probs) {
top_k_renorm_probs(probs_t, output_all_probs_t, top_k_t, 0, (int64_t)stream_);
}
} else if (std::all_of(top_k.data<uint32_t>(),
top_k.data<uint32_t>() + batch_size,
[&](auto t) { return t <= 0; })) {
top_p_sampling_from_probs(probs_t, uniform_samples, samples_t, success_t, top_p_t, 1.0, deterministic, (int64_t)stream_);
if (need_output_all_probs) {
top_p_renorm_probs(probs_t, output_all_probs_t, top_p_t, 1.0, (int64_t)stream_);
}
}
else if (std::all_of(top_p.data<float>(),
top_p.data<float>() + batch_size,
[&](auto t) { return std::abs(t - 1.0f) < 1e-7; })) {
std::transform(top_k.data<uint32_t>(), top_k.data<uint32_t>() + batch_size, top_k.data<uint32_t>(), [&](auto t) { return t <= 0 ? 1 << 30 : t;});
top_k_sampling_from_probs(probs_t, uniform_samples, samples_t, success_t, top_k_t, 0, deterministic, (int64_t)stream_);
if (need_output_all_probs) {
top_k_renorm_probs(probs_t, output_all_probs_t, top_k_t, 0, (int64_t)stream_);
}
} else {
std::transform(top_k.data<uint32_t>(), top_k.data<uint32_t>() + batch_size, top_k.data<uint32_t>(), [&](auto t) { return t <= 0 ? 1 << 30 : t;});
top_k_top_p_sampling_from_probs(probs_t, uniform_samples, samples_t, success_t, top_k_t, 1.0, top_p_t, 1.0, deterministic, (int64_t)stream_);
if (need_output_all_probs) {
torch::Tensor temp_t = torch::zeros_like(output_all_probs_t);
top_k_renorm_probs(probs_t, temp_t, top_k_t, 1.0, (int64_t)stream_);
top_p_renorm_probs(temp_t, output_all_probs_t, top_p_t, 1.0, (int64_t)stream_);
}
}
    auto output_tokens = transpose({*transposed_tokens});
    copy({params.token_ids, *output_tokens});
    sync_check_cuda_error();
    return {success};
}
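
// Top-level sampling entry point. Tries the cheapest path first: a pure
// argmax when every request has top_k == 1, then the flashinfer kernels,
// and finally the complete top-k/top-p sampling pipeline.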
GreedyOutput CudaDevice::sampleGreedy(const GreedyParams& params) {
    auto device_tokens     = clone({params.token_ids});
    auto transposed_tokens = transpose({*device_tokens});
    processLogits(params, device_tokens, transposed_tokens);
    // fast path for topk = 1
    const auto batch_size = params.logits.shape()[0];
    auto& top_k = params.top_k;
    if (std::all_of(top_k.data<uint32_t>(),
                    top_k.data<uint32_t>() + batch_size,
                    [&](auto t) { return t == 1; })
        && !params.output_all_probs.has_value()) {
        BufferPtr logits_ref = params.logits.slice(0, params.logits.shape()[0]);
        Buffer samples = transposed_tokens->view(transposed_tokens->shape()[0] - 1, 1);
        torch::Tensor samples_t = Buffer2torchTensor(samples, false);
        torch::Tensor probs_t   = Buffer2torchTensor(*logits_ref, false);
        torch::Tensor selected_tokens = torch::argmax(probs_t, -1, /*keepdim=*/false);
        samples_t.copy_(selected_tokens);
        auto output_tokens = transpose({*transposed_tokens});
        copy({params.token_ids, *output_tokens});
        return GreedyOutput{};
    }
    if (checkUseFlashinferSampleGreedy(params)) {
        return flashinferSampleGreedy(params, transposed_tokens);
    }
    completeSampleGreedy(params, transposed_tokens);
    return GreedyOutput{};
}

// batch sampling explained:
// topk = [4, 0, 4], topp = [0.0, 0.5, 0.5]
// then topk_decode handles [4, x, 4 + 0.5]
// topp_decode handles [x, 0.5, x]
// where "x" are skipped.
// topk has higher priority than topp.
void CudaDevice::completeSampleGreedy(const GreedyParams& params, const BufferPtr& transposed_tokens) {
    const auto& logits     = params.logits;
    const auto  batch_size = logits.shape()[0];
    RUNTIME_ASSERT_OP_ARG(batch_size < init_params_.max_batch_size,
                          "batch_size exceeded device limit %ld: %ld",
                          init_params_.max_batch_size, batch_size);
    const auto vocab_size_padded = logits.shape()[1];
    const auto step = params.step;
    RUNTIME_ASSERT_OP_ARG(batch_size == params.token_ids.shape()[0],
                          "logits.shape[0] should equal token_ids.shape[0], but %ld vs %ld",
                          batch_size, params.token_ids.shape()[0]);
    RUNTIME_ASSERT_OP_ARG((step == params.token_ids.shape()[1] - 1),
                          "step should equal token_ids.shape[1] - 1, but %ld vs %ld",
                          step, params.token_ids.shape()[1] - 1);
    // 1. prepare buffers
    auto& top_k       = params.top_k;
    auto& top_p       = params.top_p;
    auto& temperature = params.temperature;
    auto& random_seed = params.random_seed;
    RTP_LLM_CHECK(top_k.size() == batch_size);
    RTP_LLM_CHECK(top_p.size() == batch_size);
    RTP_LLM_CHECK(temperature.size() == batch_size);
    auto default_top_k = top_k.data<uint32_t>()[0];
    auto default_top_p = top_p.data<float>()[0];
    auto max_top_k = *max_element(top_k.data<uint32_t>(), top_k.dataWithOffset<uint32_t>(top_k.size()));
    if (max_top_k == 0) {
        // for safety. TopKSamplingLayer handles a case of top_k=0 and top_p=0 as
        // a greedy decode, i.e. top_k=1, although such case has max_top_k=0.
        max_top_k = 1;
    }
    auto max_top_p = *max_element(top_p.data<SamplerT>(), top_p.dataWithOffset<SamplerT>(top_p.size()));
    RTP_LLM_LOG_DEBUG("max_top_k: %d, max_top_p: %f", max_top_k, max_top_p);
// see BaseSamplingLayer<T>::allocateBuffer ------------------
auto skip_top_k_decode_buf = allocateBuffer({DataType::TYPE_BOOL, {batch_size}});
auto skip_top_p_decode_buf = allocateBuffer({DataType::TYPE_BOOL, {batch_size}});
auto topp_id_vals_buf = allocateBuffer({DataType::TYPE_INT32, {batch_size * vocab_size_padded}});
auto topp_offset_buf = allocateBuffer({DataType::TYPE_INT32, {batch_size + 1}});
auto begin_topp_offset_buf = allocateBuffer({DataType::TYPE_INT32, {batch_size + 1}});
auto runtime_top_k_buf = allocateBuffer({DataType::TYPE_UINT32, {batch_size}});
copy({*runtime_top_k_buf, top_k});
auto runtime_top_p_buf = allocateBuffer({DataType::TYPE_FP32, {batch_size}});
copy({*runtime_top_p_buf, top_p});
auto cum_log_probs = GET_TYPED_VALUE_FROM_OPT_REF(params.cum_log_probs, float);
auto output_log_probs = GET_TYPED_VALUE_FROM_OPT_REF(params.output_log_probs, float);
auto output_all_probs = GET_TYPED_VALUE_FROM_OPT_REF(params.output_all_probs, float);
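    // Initialize curand states: a single seed covers the whole batch via
    // invokeCurandInitialize; otherwise each request supplies its own seed
    // through the batch variant.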
    if (random_seed) {
        auto& seeds = random_seed.value().get();
        if (seeds.size() == 1) {
            invokeCurandInitialize(
                (curandState_t*)curandstate_buf_->data(), batch_size,
                seeds.data<uint64_t>()[0], stream_);
        } else {
            auto random_seeds_buf = allocateBuffer({DataType::TYPE_UINT64, {batch_size}});
            RUNTIME_ASSERT_OP_ARG((seeds.size() == batch_size),
                                  "random_seed.size() should equal batch_size, but %ld vs %ld",
                                  seeds.size(), batch_size);
            copy({*random_seeds_buf, seeds});
            invokeCurandBatchInitialize(
                (curandState_t*)curandstate_buf_->data(), batch_size,
                (unsigned long long*)random_seeds_buf->data(), stream_);
        }
    }
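    // The setup kernel normalizes the runtime top-k/top-p buffers and marks,
    // per request, whether the top-k pass can be skipped entirely.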
    // 4. run sampling
    // 4.1 run top_k
    invokeSetupTopKRuntimeArgs(batch_size,
                               default_top_k,
                               runtime_top_k_buf->data<uint>(),
                               batch_size,
                               default_top_p,
                               runtime_top_p_buf->data<float>(),
                               batch_size,
                               skip_top_k_decode_buf->data<bool>(),
                               stream_);
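    // Launch the top-k kernels only if at least one request actually needs
    // them. The first invokeTopKSampling call passes a null workspace and
    // merely queries the required workspace size into topk_ws_size.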
    auto skip_top_k = clone({*skip_top_k_decode_buf, AllocationType::HOST});
    if (std::any_of(skip_top_k->data<bool>(),
                    skip_top_k->dataWithOffset<bool>(batch_size),
                    [](auto s) { return !s; })) {
        size_t topk_ws_size;
        invokeTopKSampling<SamplerT>(nullptr,  // workspace
                                     topk_ws_size,
                                     nullptr,  // log_probs
                                     nullptr,  // ids
                                     nullptr,  // sequence_length
                                     nullptr,  // finished_buf
                                     nullptr,  // cum_log_probs
                                     nullptr,  // output_log_probs
                                     nullptr,  // curandstate
                                     max_top_k,
                                     max_top_p,
                                     vocab_size_padded,
                                     nullptr,  // end ids
                                     nullptr,
                                     stream_,
                                     batch_size,
                                     nullptr);
        auto top_k_workspace = allocateBuffer({topk_ws_size});
        invokeBatchTopKSampling(
            top_k_workspace->data(),
            topk_ws_size,
            logits.data<float>(),
            transposed_tokens->dataWithOffset<int32_t>(step * batch_size),
            nullptr,  // sequence_length
            nullptr,  // finished
            cum_log_probs,
            output_log_probs,
            (curandState_t*)curandstate_buf_->data(),
            max_top_k,  // unused because runtime_top_k_buf is never nullptr; kept for legacy.
            (int32_t*)runtime_top_k_buf->data<uint32_t>(),
            1.0f,  // unused because runtime_top_p_buf is never nullptr; kept for legacy.
            runtime_top_p_buf->data<float>(),
            vocab_size_padded,
            nullptr,  // end_id
            output_all_probs,
            stream_,
            batch_size,
            skip_top_k_decode_buf->data<bool>());
    }
    // 4.2 run top_p
    // NOTE: running top_k could write values to runtime bufs, so need to copy again.
    copy({*runtime_top_k_buf, top_k});
    copy({*runtime_top_p_buf, top_p});
    invokeSetupTopPRuntimeArgs(batch_size,
                               default_top_k,
                               runtime_top_k_buf->data<uint>(),
                               batch_size,
                               default_top_p,
                               runtime_top_p_buf->data<float>(),
                               batch_size,
                               skip_top_p_decode_buf->data<bool>(),
                               nullptr,  // initial_top_p_buf
                               nullptr,  // top_p_decay_buf
                               nullptr,
                               nullptr,  // top_p_min_buf
                               nullptr,
                               nullptr,  // top_p_reset_ids_buf
                               nullptr,
                               stream_);
    auto skip_top_p = clone({*skip_top_p_decode_buf, AllocationType::HOST});
    if (std::any_of(skip_top_p->data<bool>(),
                    skip_top_p->dataWithOffset<bool>(batch_size),
                    [](auto s) { return !s; })) {
        invokeTopPInitialize(topp_id_vals_buf->data<int32_t>(),
                             topp_offset_buf->data<int32_t>(),
                             begin_topp_offset_buf->data<int32_t>(),
                             batch_size,
                             vocab_size_padded,
                             stream_);
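        // Top-p sampling operates on probabilities, so softmax the logits in
        // place first; the workspace-size query below mirrors the top-k path.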
        invokeAddBiasSoftMax(logits.data<SamplerT>(),
                             (SamplerT*)nullptr,  // bias
                             nullptr,  // end_id
                             nullptr,  // finished
                             batch_size,
                             vocab_size_padded,
                             vocab_size_padded,
                             stream_);
        size_t topp_ws_size;
        size_t cub_temp_storage_size;
        invokeTopPSampling<SamplerT>(nullptr,  // workspace
                                     topp_ws_size,
                                     cub_temp_storage_size,
                                     nullptr,  // output_ids
                                     nullptr,  // sequence_length
                                     nullptr,  // finished_buf
                                     nullptr,  // cum_log_probs
                                     nullptr,  // output_log_probs
                                     nullptr,  // log_probs
                                     nullptr,  // id_vals
                                     nullptr,  // offsets_buf
                                     nullptr,  // begin_offset_buf
                                     nullptr,  // curandstate
                                     batch_size,
                                     vocab_size_padded,
                                     nullptr,
                                     max_top_p,
                                     nullptr,  // output_all_probs
                                     stream_,
                                     &device_prop_,
                                     nullptr);
        auto top_p_workspace = allocateBuffer({topp_ws_size});
        invokeBatchTopPSampling(
            top_p_workspace->data(),
            topp_ws_size,
            cub_temp_storage_size,
            transposed_tokens->dataWithOffset<int32_t>(step * batch_size),
            nullptr,  // sequence_length
            nullptr,  // finished
            cum_log_probs,
            output_log_probs,
            logits.data<float>(),
            topp_id_vals_buf->data<int32_t>(),
            topp_offset_buf->data<int32_t>(),
            begin_topp_offset_buf->data<int32_t>(),
            (curandState_t*)curandstate_buf_->data(),
            batch_size,
            vocab_size_padded,
            nullptr,  // end_id
            max_top_p,
            runtime_top_p_buf->data<float>(),
            output_all_probs,
            stream_,
            &device_prop_,
            skip_top_p_decode_buf->data<bool>());
    }
    auto output_tokens = transpose({*transposed_tokens});
    copy({params.token_ids, *output_tokens});
    sync_check_cuda_error();
}
} // namespace rtp_llm