in maga_transformer/cpp/devices/arm_impl/ArmSampleOp.cc [206:358]
/// Runs one step of token sampling (temperature / repetition / min-length
/// penalties, then top-k + top-p sampling) for a whole batch on the ARM CPU
/// device.
///
/// @param params  Batched sampling inputs: logits [batch, vocab_padded],
///                token_ids [batch, step+1], per-request top_k/top_p/
///                temperature (each of size batch), plus optional
///                random_seed, repetition_penalty, min_lengths, eos_ids.
/// @return GreedyOutput — currently default-constructed; sampled tokens are
///         written in place into params.token_ids (column `step`).
/// @throws via RUNTIME_ASSERT_OP_ARG / RTP_LLM_CHECK on shape or argument
///         mismatches.
GreedyOutput ArmCpuDevice::sampleGreedy(const GreedyParams& params) {
    const auto& logits     = params.logits;
    const auto  batch_size = logits.shape()[0];
    // NOTE(review): strict '<' rejects batch_size == max_batch_size, while the
    // message says "exceeded" — confirm whether the limit is meant inclusive.
    RUNTIME_ASSERT_OP_ARG(batch_size < init_params_.max_batch_size,
                          "batch_size exceeded device limit %ld: %ld",
                          init_params_.max_batch_size, batch_size);
    const auto vocab_size_padded = logits.shape()[1];
    const auto step = params.step;
    RUNTIME_ASSERT_OP_ARG(batch_size == params.token_ids.shape()[0],
                          "logits.shape[0] should equal to token_ids.shape[0], but %ld vs %ld",
                          batch_size, params.token_ids.shape()[0]);
    RUNTIME_ASSERT_OP_ARG((step == params.token_ids.shape()[1] - 1),
                          "step should equal to token_ids.shape[1] - 1, but %ld vs %ld",
                          step, params.token_ids.shape()[1] - 1);
    auto& tokens = params.token_ids;

    // 1. prepare buffers: per-request sampling knobs must match batch size.
    auto& top_k       = params.top_k;
    auto& top_p       = params.top_p;
    auto& temperature = params.temperature;
    auto& random_seed = params.random_seed;
    RTP_LLM_CHECK(top_k.size() == batch_size);
    RTP_LLM_CHECK(top_p.size() == batch_size);
    RTP_LLM_CHECK(temperature.size() == batch_size);
    auto default_top_k = top_k.data<uint32_t>()[0];
    auto default_top_p = top_p.data<float>()[0];
    if (default_top_k == 0) {
        // top_k == 0 is treated as greedy decode, i.e. top_k = 1.
        default_top_k = 1;
    }
    auto max_top_k = *std::max_element(top_k.data<uint32_t>(), top_k.dataWithOffset<uint32_t>(top_k.size()));
    if (max_top_k == 0) {
        // for safety. TopKSamplingLayer handles a case of top_k=0 and top_p=0 as
        // a greedy decode, i.e. top_k=1, although such case has max_top_k=0.
        max_top_k = 1;
    }
    auto max_top_p = *std::max_element(top_p.data<float>(), top_p.dataWithOffset<float>(top_p.size()));
    RTP_LLM_LOG_DEBUG("max_top_k: %d, max_top_p: %f", max_top_k, max_top_p);
    // Scratch space sized for the widest top-k in the batch.
    auto skip_top_k_decode_buf = allocateBuffer({DataType::TYPE_BOOL, {batch_size}});
    auto topk_tmp_val_buf      = allocateBuffer({DataType::TYPE_FP32, {batch_size * max_top_k}});
    auto topk_tmp_id_buf       = allocateBuffer({DataType::TYPE_INT32, {batch_size * max_top_k}});

    // 2. seed one RNG per batch entry. Default: a single random_device draw,
    // offset by the batch index so entries differ; explicit seeds override.
    std::vector<std::mt19937> generator_lists(batch_size);
    unsigned one_seed = std::random_device{}();
    for (size_t i = 0; i < batch_size; i++) {
        generator_lists[i] = std::mt19937(one_seed + i);
    }
    if (random_seed) {
        auto&     seeds    = random_seed.value().get();
        uint64_t* seedsPtr = seeds.data<uint64_t>();
        if (seeds.size() == 1) {
            // A single seed is broadcast: every entry gets the same RNG state.
            // fixed: loop index was `int` vs size_t batch_size (signed/unsigned mix).
            for (size_t i = 0; i < batch_size; i++) {
                generator_lists[i] = std::mt19937(seedsPtr[0]);
            }
        } else {
            RUNTIME_ASSERT_OP_ARG((seeds.size() == batch_size),
                                  "random_seed.size() should equal to batch_size, but %ld vs %ld",
                                  seeds.size(), batch_size);
            for (size_t i = 0; i < batch_size; i++) {
                generator_lists[i] = std::mt19937(seedsPtr[i]);
            }
        }
    }

    // 3. compute logits penalties (in place on params.logits).
    if (std::any_of(temperature.data<float>(),
                    temperature.data<float>() + batch_size,
                    [&](auto t) { return t != 1.0f; }))
    {
        temperaturePenalty(
            logits.data<float>(),
            temperature.data<float>(),
            batch_size,
            vocab_size_padded
        );
    }
    const auto decoder_batch_size = params.sequence_lengths.shape()[0];
    if (decoder_batch_size) {
        if (step > 1 && params.repetition_penalty && decoder_batch_size) {
            auto& repetition_penalty = params.repetition_penalty->get();
            // Skip the pass entirely when every penalty is the neutral 1.0.
            if (std::any_of(repetition_penalty.data<float>(),
                            repetition_penalty.data<float>() + batch_size,
                            [&](auto t) { return t != 1.0f; }))
            {
                repetitionPenalty(logits.data<float>(),
                                  repetition_penalty.data<float>(),
                                  tokens.data<int32_t>(),
                                  batch_size,
                                  vocab_size_padded,
                                  params.sequence_lengths.data<int32_t>(),
                                  step + 1,
                                  step);
            }
        }
        // fixed: original had a redundant outer `if (params.min_lengths.has_value())`
        // guarding the identical check below; collapsed into one condition.
        if (params.min_lengths && params.eos_ids) {
            minLengthPenaltyNew(logits.data<float>(),
                                params.min_lengths.value().get().data<int32_t>(),
                                params.eos_ids.value().get().data<int32_t>(),
                                params.sequence_lengths.data<int32_t>(),
                                params.input_lengths.data<int32_t>(),
                                batch_size,
                                vocab_size_padded);
        }
    }

    // 4. run sampling
    // 4.1 decide, per batch entry, whether top-k decoding can be skipped.
    setup_topk_runtime_args(batch_size,
                            default_top_k,
                            top_k.data<uint32_t>(),
                            batch_size,
                            default_top_p,
                            top_p.data<float>(),
                            batch_size,
                            skip_top_k_decode_buf->data<bool>());
    // 4.2 gather each entry's top-k candidate logits and indices...
    float* topk_logs         = topk_tmp_val_buf->data<float>();
    int*   topk_logs_indices = topk_tmp_id_buf->data<int>();
    topKKernel(topk_logs,
               topk_logs_indices,
               logits.data<float>(), batch_size,
               vocab_size_padded,
               max_top_k,
               top_k.data<uint32_t>());
    // fixed: eos_ids was dereferenced below without a presence check, which
    // would surface as an opaque std::bad_optional_access; fail with a clear
    // argument error instead, matching the validation style above.
    RUNTIME_ASSERT_OP_ARG(params.eos_ids.has_value(),
                          "eos_ids must be provided for sampleGreedy");
    // ...then sample from those candidates, writing the chosen token ids back
    // into the token buffer at column `step`.
    topk_sampling(batch_size, topk_logs_indices, topk_logs, generator_lists,
                  tokens.data<int32_t>(),
                  nullptr, // sequence_length
                  nullptr, // finished
                  nullptr, // cum_log_probs,
                  nullptr, // output_log_probs,
                  nullptr, // token_id_for_index_prob,
                  max_top_k,
                  top_k.data<uint32_t>(),
                  1.0f,
                  top_p.data<float>(),
                  params.eos_ids.value().get().data<int32_t>(),
                  vocab_size_padded,
                  skip_top_k_decode_buf->data<bool>(),
                  step + 1);
    return GreedyOutput{};
}