// GreedyOutput ArmCpuDevice::sampleGreedy()
//
// from maga_transformer/cpp/devices/arm_impl/ArmSampleOp.cc [206:358]

// Runs one CPU decode step of top-k / top-p sampling for a whole batch.
// Applies temperature / repetition / min-length penalties to `params.logits`
// in place, then samples one token per row via the top-k kernel.
// Note: per-row outputs are written through the buffers inside `params`
// (token_ids); the returned GreedyOutput is empty.
GreedyOutput ArmCpuDevice::sampleGreedy(const GreedyParams& params) {
    const auto& logits = params.logits;  // [batch_size, vocab_size_padded], modified in place
    const auto batch_size = logits.shape()[0];
    // FIX: was `batch_size < max_batch_size`, which rejected a batch exactly at
    // the device limit even though the message says "exceeded".
    RUNTIME_ASSERT_OP_ARG(batch_size <= init_params_.max_batch_size,
                          "batch_size exceeded device limit %ld: %ld",
                          init_params_.max_batch_size, batch_size);
    const auto vocab_size_padded = logits.shape()[1];
    const auto step = params.step;
    RUNTIME_ASSERT_OP_ARG(batch_size == params.token_ids.shape()[0],
                          "logits.shape[0] should equal to token_ids.shape[0], but %ld vs %ld",
                          batch_size, params.token_ids.shape()[0]);
    RUNTIME_ASSERT_OP_ARG((step == params.token_ids.shape()[1] - 1),
                          "step should equal to token_ids.shape[1] - 1, but %ld vs %ld",
                          step, params.token_ids.shape()[1] - 1);
    auto& tokens = params.token_ids;

    // 1. validate per-request sampling parameters
    auto& top_k = params.top_k;
    auto& top_p = params.top_p;
    auto& temperature = params.temperature;
    auto& random_seed = params.random_seed;
    RTP_LLM_CHECK(top_k.size() == batch_size);
    RTP_LLM_CHECK(top_p.size() == batch_size);
    RTP_LLM_CHECK(temperature.size() == batch_size);
    // FIX: eos ids are dereferenced unconditionally in topk_sampling() below;
    // fail fast here instead of throwing bad_optional_access mid-sampling.
    RTP_LLM_CHECK(params.eos_ids.has_value());

    auto default_top_k = top_k.data<uint32_t>()[0];
    auto default_top_p = top_p.data<float>()[0];

    // top_k == 0 is treated as plain greedy decode, i.e. top_k = 1.
    if (default_top_k == 0) {
        default_top_k = 1;
    }

    auto max_top_k = *std::max_element(top_k.data<uint32_t>(), top_k.dataWithOffset<uint32_t>(top_k.size()));
    if (max_top_k == 0) {
        // for safety. TopKSamplingLayer handles a case of top_k=0 and top_p=0 as
        // a greedy decode, i.e. top_k=1, although such case has max_top_k=0.
        max_top_k = 1;
    }
    auto max_top_p = *std::max_element(top_p.data<float>(), top_p.dataWithOffset<float>(top_p.size()));
    // FIX: max_top_k is unsigned; %d was the wrong specifier.
    RTP_LLM_LOG_DEBUG("max_top_k: %u, max_top_p: %f", max_top_k, max_top_p);

    // scratch buffers sized for the widest top-k across the batch
    auto skip_top_k_decode_buf = allocateBuffer({DataType::TYPE_BOOL, {batch_size}});
    auto topk_tmp_val_buf = allocateBuffer({DataType::TYPE_FP32, {batch_size * max_top_k}});
    auto topk_tmp_id_buf = allocateBuffer({DataType::TYPE_INT32, {batch_size * max_top_k}});

    // 2. per-request RNGs. Default: one random_device seed, offset per row so
    // rows differ. Caller-provided seeds (one broadcast, or one per row) win.
    std::vector<std::mt19937> generator_lists(batch_size);
    unsigned one_seed = std::random_device{}();
    for (size_t i = 0; i < batch_size; i++) {
        generator_lists[i] = std::mt19937(one_seed + i);
    }
    if (random_seed) {
        auto& seeds = random_seed.value().get();
        uint64_t* seedsPtr = seeds.data<uint64_t>();
        if (seeds.size() == 1) {
            // single seed broadcast to every row of the batch
            for (size_t i = 0; i < batch_size; i++) {
                generator_lists[i] = std::mt19937(seedsPtr[0]);
            }
        } else {
            RUNTIME_ASSERT_OP_ARG((seeds.size() == batch_size),
                                  "random_seed.size() should equal to batch_size, but %ld vs %ld",
                                  seeds.size(), batch_size);
            for (size_t i = 0; i < batch_size; i++) {
                generator_lists[i] = std::mt19937(seedsPtr[i]);
            }
        }
    }

    // 3. logits penalties (all modify `logits` in place)
    // 3.1 temperature — skip the whole pass when every temperature is 1.0
    if (std::any_of(temperature.data<float>(),
                    temperature.data<float>() + batch_size,
                    [&](auto t) { return t != 1.0f; }))
    {
        temperaturePenalty(
                logits.data<float>(),
                temperature.data<float>(),
                batch_size,
                vocab_size_padded
        );
    }

    const auto decoder_batch_size = params.sequence_lengths.shape()[0];
    if (decoder_batch_size) {
        // 3.2 repetition penalty (only meaningful past the first decode steps).
        // Redundant re-check of decoder_batch_size removed — already guarded above.
        if (step > 1 && params.repetition_penalty) {
            auto& repetition_penalty = params.repetition_penalty->get();
            if (std::any_of(repetition_penalty.data<float>(),
                            repetition_penalty.data<float>() + batch_size,
                            [&](auto t) { return t != 1.0f; }))
            {
                repetitionPenalty(logits.data<float>(),
                                  repetition_penalty.data<float>(),
                                  tokens.data<int32_t>(),
                                  batch_size,
                                  vocab_size_padded,
                                  params.sequence_lengths.data<int32_t>(),
                                  step + 1,
                                  step);
            }
        }
        // 3.3 min-length penalty (suppresses eos until min_length is reached).
        // The former duplicate `min_lengths.has_value()` outer check is folded
        // into this single condition.
        if (params.min_lengths && params.eos_ids) {
            minLengthPenaltyNew(logits.data<float>(),
                                params.min_lengths.value().get().data<int32_t>(),
                                params.eos_ids.value().get().data<int32_t>(),
                                params.sequence_lengths.data<int32_t>(),
                                params.input_lengths.data<int32_t>(),
                                batch_size,
                                vocab_size_padded);
        }
    }

    // 4. run sampling
    // 4.1 resolve per-row top_k/top_p and mark rows that may skip top-k decode
    setup_topk_runtime_args(batch_size,
                            default_top_k,
                            top_k.data<uint32_t>(),
                            batch_size,
                            default_top_p,
                            top_p.data<float>(),
                            batch_size,
                            skip_top_k_decode_buf->data<bool>());

    float* topk_logs = topk_tmp_val_buf->data<float>();
    int* topk_logs_indices = topk_tmp_id_buf->data<int>();

    // 4.2 gather the max_top_k best logits (and their vocab indices) per row
    topKKernel(topk_logs,
               topk_logs_indices,
               logits.data<float>(), batch_size,
               vocab_size_padded,
               max_top_k,
               top_k.data<uint32_t>());

    // 4.3 sample from the top-k candidates into `tokens`.
    // NOTE(review): `step + 1` presumably selects the write column in
    // token_ids for this step — confirm against topk_sampling's signature.
    topk_sampling(batch_size, topk_logs_indices, topk_logs, generator_lists,
                  tokens.data<int32_t>(),
                  nullptr, // sequence_length
                  nullptr, // finished
                  nullptr, // cum_log_probs
                  nullptr, // output_log_probs
                  nullptr, // token_id_for_index_prob
                  max_top_k,
                  top_k.data<uint32_t>(),
                  1.0f,
                  top_p.data<float>(),
                  params.eos_ids.value().get().data<int32_t>(),
                  vocab_size_padded,
                  skip_top_k_decode_buf->data<bool>(),
                  step + 1);

    return GreedyOutput{};
}