GreedyOutput CpuDevice::sampleGreedy()

in maga_transformer/cpp/devices/cpu_impl/CpuSampleOp.cc [403:556]


GreedyOutput CpuDevice::sampleGreedy(const GreedyParams& params) {
    const auto& logits            = params.logits;
    const auto  batch_size        = logits.shape()[0];
    const auto  vocab_size_padded = logits.shape()[1];
    const auto  step              = params.step;

    auto& token_ids = params.token_ids;
    RUNTIME_ASSERT_OP_ARG(batch_size == params.token_ids.shape()[0],
                          "logits.shape[0] should equal to token_ids.shape[0], but %ld vs %ld",
                          batch_size,
                          params.token_ids.shape()[0]);
    RUNTIME_ASSERT_OP_ARG((step == params.token_ids.shape()[1] - 1),
                          "step should equal to token_ids.shape[1] - 1, but %ld vs %ld",
                          step,
                          params.token_ids.shape()[1] - 1);

    // 1. prepare
    auto& top_k        = params.top_k;
    auto& top_p        = params.top_p;
    auto& temperature  = params.temperature;
    auto& random_seed  = params.random_seed;
    auto& cum_log_prob = params.cum_log_probs;

    auto default_top_k = top_k.data<uint32_t>()[0];
    auto default_top_p = top_p.data<float>()[0];

    auto max_top_k = *std::max_element(top_k.data<uint32_t>(), top_k.dataWithOffset<uint32_t>(top_k.size()));
    if (max_top_k == 0) {
        max_top_k = 1;
    }
    auto max_top_p = *std::max_element(top_p.data<float>(), top_p.dataWithOffset<float>(top_p.size()));
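    // Batch-wide maxima handed to the sampling kernels; max_top_k is clamped to at least 1 so the top-k pass always keeps a candidate.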

    bool* skip_top_k_decode = static_cast<bool*>(aligned_alloc(64, batch_size * sizeof(bool)));
    bool* skip_top_p_decode = static_cast<bool*>(aligned_alloc(64, batch_size * sizeof(bool)));

    uint32_t* runtime_top_k = static_cast<uint32_t*>(aligned_alloc(64, batch_size * sizeof(uint32_t)));
    std::memcpy(runtime_top_k, top_k.data(), batch_size * sizeof(uint32_t));

    float* runtime_top_p = static_cast<float*>(aligned_alloc(64, batch_size * sizeof(float)));
    std::memcpy(runtime_top_p, top_p.data(), batch_size * sizeof(float));

    auto cum_log_probs = cum_log_prob.has_value() ? cum_log_prob.value().get().data<float>() : nullptr;
    auto output_log_probs =
        params.output_log_probs.has_value() ? params.output_log_probs.value().get().data<float>() : nullptr;
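    // Optional log-prob outputs stay nullptr when the caller did not request them and are forwarded to the sampling kernels as-is.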

    // 2. setup random seeds
    auto   seeds     = random_seed.has_value() ? random_seed.value().get().data<uint64_t>() : nullptr;
    float* rand_nums = static_cast<float*>(aligned_alloc(64, batch_size * sizeof(float)));

    std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
    std::default_random_engine            generator;

    for (int i = 0; i < batch_size; i++) {
        if (seeds != nullptr) {
            generator.seed(seeds[i]);
        } else {
            generator.seed(std::random_device{}());
        }
        rand_nums[i] = distribution(generator);
    }
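    // One uniform draw in [0, 1) per request, seeded from random_seed when provided; the same draw is reused by both the top-k and top-p passes below.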

    // 3. compute logits penalty
    if (std::any_of(
            temperature.data<float>(), temperature.data<float>() + batch_size, [&](auto t) { return t != 1.0f; })) {
        applyTemperaturePenalty(logits.data<float>(),
                                temperature.data<float>(),
                                batch_size,
                                vocab_size_padded,
                                vocab_size_padded);
    }

    const auto decoder_batch_size = params.sequence_lengths.shape()[0];
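    // Repetition and min-length penalties only apply once the batch contains decode-phase requests (decoder_batch_size > 0).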

    if (decoder_batch_size) {
        if (step > 1 && params.repetition_penalty && decoder_batch_size) {
            auto& repetition_penalty = params.repetition_penalty->get();
            if (std::any_of(repetition_penalty.data<float>(),
                            repetition_penalty.data<float>() + batch_size,
                            [&](auto t) { return t != 1.0f; })) {
                repetitionPenalty(logits.data<float>(),
                                  repetition_penalty.data<float>(),
                                  token_ids.data<int32_t>(),
                                  batch_size,
                                  vocab_size_padded,
                                  params.sequence_lengths.data<int32_t>(),
                                  step + 1,
                                  step);
            }
        }

        if (params.min_lengths && params.eos_ids) {
            minLengthPenalty(logits.data<float>(),
                             params.min_lengths.value().get().data<int32_t>(),
                             params.eos_ids.value().get().data<int32_t>(),
                             params.sequence_lengths.data<int32_t>(),
                             params.input_lengths.data<int32_t>(),
                             batch_size,
                             vocab_size_padded);
        }
    }

    // 4. run sampling
    // 4.1 run top_k
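    // setup_topk resolves each request's top_k/top_p against the batch defaults and fills skip_top_k_decode for rows that can bypass the top-k pass.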
    setup_topk(batch_size,
               default_top_k,
               runtime_top_k,
               batch_size,
               default_top_p,
               runtime_top_p,
               batch_size,
               skip_top_k_decode);

    if (std::any_of(skip_top_k_decode, skip_top_k_decode + batch_size, [](auto s) { return !s; })) {
        batchTopKSampling(logits.data<float>(),
                          token_ids.data<int>(),
                          step,
                          cum_log_probs,
                          output_log_probs,
                          max_top_k,
                          runtime_top_k,  // top_ks,
                          vocab_size_padded,
                          batch_size,
                          skip_top_k_decode,
                          rand_nums);
    }

    // 4.2 run top_p
    setup_topp(batch_size,
               default_top_k,
               runtime_top_k,
               batch_size,
               default_top_p,
               runtime_top_p,
               batch_size,
               skip_top_p_decode);

    for (int i = 0; i < batch_size; ++i) {
        computeSoftMax(logits.data<float>() + i * vocab_size_padded, vocab_size_padded);
    }
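    // Each row of logits is now softmax-normalized in place; top-p sampling below consumes probabilities rather than raw logits.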

    batchTopPSampling(token_ids.data<int>(),
                      cum_log_probs,
                      output_log_probs,
                      logits.data<float>(),
                      step,
                      batch_size,
                      vocab_size_padded,
                      max_top_p,
                      runtime_top_p,
                      skip_top_p_decode,
                      rand_nums);
    // Release the scratch buffers allocated with aligned_alloc above.
    free(skip_top_k_decode);
    free(skip_top_p_decode);
    free(runtime_top_k);
    free(runtime_top_p);
    free(rand_nums);

    return GreedyOutput{};
}
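
For intuition, the per-row nucleus step that batchTopPSampling is expected to perform on a softmax-normalized row can be sketched as below. This is a simplified illustration under assumptions, not the actual kernel: sampleTopPRow is a hypothetical helper, and the real kernel additionally writes the chosen token into token_ids at position step, honours skip_top_p_decode, and updates the optional log-prob buffers.

// Illustrative sketch only: sampleTopPRow and the rand_num * top_p threshold walk
// are assumptions about how a per-row nucleus sampler is commonly implemented,
// not the actual batchTopPSampling kernel.
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

static int sampleTopPRow(const float* probs, size_t vocab_size, float top_p, float rand_num) {
    // Order token ids by descending probability (probs is already softmax-normalized).
    std::vector<int> ids(vocab_size);
    std::iota(ids.begin(), ids.end(), 0);
    std::sort(ids.begin(), ids.end(), [&](int a, int b) { return probs[a] > probs[b]; });

    // Scale the uniform draw into the top_p mass and walk the sorted
    // distribution until the cumulative probability crosses it.
    const float threshold = rand_num * top_p;
    float       cum       = 0.0f;
    for (int id : ids) {
        cum += probs[id];
        if (cum >= threshold) {
            return id;
        }
    }
    return ids.back();  // Numerical-underflow fallback.
}

In the sketch the only source of randomness is rand_num, one of the values precomputed into rand_nums above, which is what makes per-request seeding through random_seed reproducible.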