in maga_transformer/cpp/devices/cpu_impl/CpuSampleOp.cc [403:556]
GreedyOutput CpuDevice::sampleGreedy(const GreedyParams& params) {
const auto& logits = params.logits;
const auto batch_size = logits.shape()[0];
const auto vocab_size_padded = logits.shape()[1];
const auto step = params.step;
auto& token_ids = params.token_ids;
RUNTIME_ASSERT_OP_ARG(batch_size == params.token_ids.shape()[0],
"logits.shape[0] should equal to token_ids.shape[0], but %ld vs %ld",
batch_size,
params.token_ids.shape()[0]);
RUNTIME_ASSERT_OP_ARG((step == params.token_ids.shape()[1] - 1),
"step should equal to token_ids.shape[1] - 1, but %ld vs %ld",
step,
params.token_ids.shape()[1] - 1);
// 1. prepare
auto& top_k = params.top_k;
auto& top_p = params.top_p;
auto& temperature = params.temperature;
auto& random_seed = params.random_seed;
auto& cum_log_prob = params.cum_log_probs;
auto default_top_k = top_k.data<uint32_t>()[0];
auto default_top_p = top_p.data<float>()[0];
auto max_top_k = *std::max_element(top_k.data<uint32_t>(), top_k.dataWithOffset<uint32_t>(top_k.size()));
if (max_top_k == 0) {
max_top_k = 1;
}
auto max_top_p = *std::max_element(top_p.data<float>(), top_p.dataWithOffset<float>(top_p.size()));
bool* skip_top_k_decode = static_cast<bool*>(aligned_alloc(64, batch_size * sizeof(bool)));
bool* skip_top_p_decode = static_cast<bool*>(aligned_alloc(64, batch_size * sizeof(bool)));
uint32_t* runtime_top_k = static_cast<uint32_t*>(aligned_alloc(64, batch_size * sizeof(uint32_t)));
std::memcpy(runtime_top_k, top_k.data(), batch_size * sizeof(uint32_t));
float* runtime_top_p = static_cast<float*>(aligned_alloc(64, batch_size * sizeof(float)));
std::memcpy(runtime_top_p, top_p.data(), batch_size * sizeof(float));
auto cum_log_probs = cum_log_prob.has_value() ? params.cum_log_probs.value().get().data<float>() : nullptr;
auto output_log_probs =
params.output_log_probs.has_value() ? params.output_log_probs.value().get().data<float>() : nullptr;
// 3.1 setup random seeds
auto seeds = random_seed.has_value() ? random_seed.value().get().data<uint64_t>() : nullptr;
float* rand_nums = static_cast<float*>(aligned_alloc(64, batch_size * sizeof(float)));
std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
std::default_random_engine generator;
for (int i = 0; i < batch_size; i++) {
if (seeds != nullptr) {
generator.seed(seeds[i]);
} else {
generator.seed(std::random_device{}());
}
rand_nums[i] = distribution(generator);
}
// 3.2. compute logits penalty
if (std::any_of(
temperature.data<float>(), temperature.data<float>() + batch_size, [&](auto t) { return t != 1.0f; })) {
applyTemperaturePenalty(logits.data<float>(),
temperature.data<float>(),
batch_size,
vocab_size_padded,
vocab_size_padded);
}
const auto decoder_batch_size = params.sequence_lengths.shape()[0];
if (decoder_batch_size) {
if (step > 1 && params.repetition_penalty && decoder_batch_size) {
auto& repetition_penalty = params.repetition_penalty->get();
if (std::any_of(repetition_penalty.data<float>(),
repetition_penalty.data<float>() + batch_size,
[&](auto t) { return t != 1.0f; })) {
repetitionPenalty(logits.data<float>(),
repetition_penalty.data<float>(),
token_ids.data<int32_t>(),
batch_size,
vocab_size_padded,
params.sequence_lengths.data<int32_t>(),
step + 1,
step);
}
}
if (params.min_lengths.has_value())
if (params.min_lengths && params.eos_ids) {
minLengthPenalty(logits.data<float>(),
params.min_lengths.value().get().data<int32_t>(),
params.eos_ids.value().get().data<int32_t>(),
params.sequence_lengths.data<int32_t>(),
params.input_lengths.data<int32_t>(),
batch_size,
vocab_size_padded);
}
}
// 4. run sampling
// 4.1 run top_k
setup_topk(batch_size,
default_top_k,
runtime_top_k,
batch_size,
default_top_p,
runtime_top_p,
batch_size,
skip_top_k_decode);
if (std::any_of(skip_top_k_decode, skip_top_k_decode + batch_size, [](auto s) { return !s; })) {
batchTopKSampling(logits.data<float>(),
token_ids.data<int>(),
step,
cum_log_probs,
output_log_probs,
max_top_k,
runtime_top_k, // top_ks,
vocab_size_padded,
batch_size,
skip_top_k_decode,
rand_nums);
}
// 4.2. run top_p
setup_topp(batch_size,
default_top_k,
runtime_top_k,
batch_size,
default_top_p,
runtime_top_p,
batch_size,
skip_top_p_decode);
for (int i = 0; i < batch_size; ++i) {
computeSoftMax(logits.data<float>() + i * vocab_size_padded, vocab_size_padded);
}
batchTopPSampling(token_ids.data<int>(),
cum_log_probs,
output_log_probs,
logits.data<float>(),
step,
batch_size,
vocab_size_padded,
max_top_p,
runtime_top_p,
skip_top_p_decode,
rand_nums);
return GreedyOutput{};
}