maga_transformer/cpp/models/Sampler.cc (164 lines of code) (raw):

#include "maga_transformer/cpp/models/Sampler.h" #include "maga_transformer/cpp/devices/utils/DebugUtils.h" #include <unordered_set> using namespace std; namespace rtp_llm { Sampler::Sampler(const SamplerInitParams& params) : device_(params.device) { RTP_LLM_LOG_INFO("sampler max_batch_size: %ld", params.max_batch_size); const auto max_batch_size = params.max_batch_size; eos_ids_host_ = device_->allocateBuffer( {DataType::TYPE_INT32, {max_batch_size}, AllocationType::HOST}); std::fill_n(eos_ids_host_->data<int32_t>(), max_batch_size, params.eos_id); eos_ids_ = device_->allocateBuffer( {DataType::TYPE_INT32, {max_batch_size}, AllocationType::DEVICE}, {"eos_id"}); device_->copy({*eos_ids_, *eos_ids_host_}); }; SamplerOutput Sampler::forward(const SamplerInputs& inputs) { RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); size_t from_batch_idx = 0; size_t sample_to_batch_idx = 0; size_t from_seq_idx = 0; // accumulates batch_size * num_beams auto beam_sizes = inputs.num_beams->data<uint64_t>(); auto current_beam_size = beam_sizes[0]; const auto& input_tokens = *inputs.token_ids; auto success = device_->allocateBuffer({DataType::TYPE_BOOL, {inputs.batch_size}, AllocationType::HOST}); preprocessLogits(inputs); do { while (sample_to_batch_idx + 1 < inputs.batch_size && beam_sizes[sample_to_batch_idx + 1] == current_beam_size) { sample_to_batch_idx++; } // now from_batch_idx to sample_to_batch_idx have the same beam size, sample once. const auto sample_batch_size = sample_to_batch_idx - from_batch_idx + 1; const auto sample_seq_num = sample_batch_size; const auto sample_to_seq_idx = from_seq_idx + sample_seq_num; auto sample_tokens = input_tokens.view(from_seq_idx, sample_seq_num); auto sample_logits = inputs.logits->view(from_seq_idx, sample_seq_num); auto input_lengths = inputs.input_lengths->view(from_batch_idx, sample_batch_size); const auto decoder_batch_size = inputs.sequence_lengths->shape()[0]; auto sequence_lengths = from_batch_idx < decoder_batch_size ? inputs.sequence_lengths->view(from_batch_idx, min(sample_batch_size, decoder_batch_size - from_batch_idx)) : Buffer(MemoryType::MEMORY_CPU, DataType::TYPE_INT32, {0}, nullptr); BufferPtr sample_cum_log_probs; if (inputs.cum_log_probs) { sample_cum_log_probs = device_->allocateBuffer( {inputs.cum_log_probs->type(), {sample_seq_num}}); device_->copy({*sample_cum_log_probs, inputs.cum_log_probs->view(from_seq_idx, sample_seq_num)}); } #define MAY_GET_BUFFER_VIEW(buffer_ptr) \ (buffer_ptr.get() ? buffer_ptr->view(from_batch_idx, sample_batch_size) : Buffer::emptyBuffer()) if (current_beam_size == 1) { auto random_seeds = MAY_GET_BUFFER_VIEW(inputs.random_seeds); auto repetition_penalty = MAY_GET_BUFFER_VIEW(inputs.repetition_penalty); auto min_lengths = MAY_GET_BUFFER_VIEW(inputs.min_lengths); auto no_repeat_ngram_size = MAY_GET_BUFFER_VIEW(inputs.no_repeat_ngram_size); auto all_probs = (inputs.all_probs.get() ? inputs.all_probs->view(from_batch_idx, sample_seq_num) : Buffer::emptyBuffer()); auto greedy_output = device_->sampleGreedy({ sample_logits, input_lengths, sequence_lengths, sample_tokens, inputs.step, *inputs.top_k, *inputs.top_p, *inputs.temperature, inputs.random_seeds ? (OptionalBufferRef)random_seeds : nullopt, inputs.repetition_penalty ? (OptionalBufferRef)repetition_penalty : nullopt, inputs.min_lengths ? (OptionalBufferRef)min_lengths : nullopt, *eos_ids_, inputs.no_repeat_ngram_size ? (OptionalBufferRef)no_repeat_ngram_size : nullopt, inputs.cum_log_probs ? (OptionalBufferRef)*sample_cum_log_probs : nullopt, nullopt, // output_log_probs inputs.all_probs ? (OptionalBufferRef) all_probs: nullopt }); if (greedy_output.success) { device_->copy({success->view(from_seq_idx, sample_seq_num), *greedy_output.success}); } else { std::fill(success->dataWithOffset<bool>(from_seq_idx), success->dataWithOffset<bool>(from_seq_idx) + sample_seq_num, true); } } else { size_t beam_batch_size = (size_t)(sample_batch_size / current_beam_size); RTP_LLM_LOG_DEBUG("current_beam_size is %d", current_beam_size); RTP_LLM_LOG_DEBUG("current_beam_batch is %d", beam_batch_size); RTP_LLM_CHECK_WITH_INFO((sample_batch_size % current_beam_size == 0), "sample_batch_size[%d] must devide by current_beam_size[%d]"); auto beam_search_sequence_lengths = inputs.beam_search_sequence_lengths->view(from_batch_idx, sample_batch_size); auto beam_index = inputs.beam_index->view(from_batch_idx, sample_batch_size); auto org_sample_logits_shape = sample_logits.shape(); auto org_sample_tokens_shape = sample_tokens.shape(); auto org_input_lengths_shape = input_lengths.shape(); auto org_sequence_lengths_shape = beam_search_sequence_lengths.shape(); auto org_sample_cum_log_probs_shape = sample_cum_log_probs->shape(); auto org_beam_index_shape = beam_index.shape(); sample_logits.updateShape({beam_batch_size, (size_t)current_beam_size, (size_t)inputs.logits->shape()[1]}); sample_tokens.updateShape({beam_batch_size, (size_t)current_beam_size, (size_t)input_tokens.shape()[1]}); beam_search_sequence_lengths.updateShape({beam_batch_size, (size_t)current_beam_size}); sample_cum_log_probs->updateShape({beam_batch_size, (size_t)current_beam_size}); input_lengths.updateShape({beam_batch_size, (size_t)current_beam_size}); beam_index.updateShape({beam_batch_size, (size_t)current_beam_size}); auto sample_logits_device = device_->clone({sample_logits, AllocationType::DEVICE}); auto sample_tokens_device = device_->clone({sample_tokens, AllocationType::DEVICE}); auto input_lengths_device = device_->clone({input_lengths, AllocationType::DEVICE}); auto beam_search_sequence_lengths_device = device_->clone({beam_search_sequence_lengths, AllocationType::DEVICE}); auto sample_cum_log_probs_device = device_->clone({*sample_cum_log_probs, AllocationType::DEVICE}); auto beam_index_device = device_->clone({beam_index, AllocationType::DEVICE}); device_->sampleBeamSearch({*sample_logits_device, *sample_tokens_device, *input_lengths_device, *beam_search_sequence_lengths_device, *sample_cum_log_probs_device, *beam_index_device}); device_->copy({sample_logits, *sample_logits_device}); device_->copy({sample_tokens, *sample_tokens_device}); device_->copy({input_lengths, *input_lengths_device}); device_->copy({beam_search_sequence_lengths, *beam_search_sequence_lengths_device}); device_->copy({*sample_cum_log_probs, *sample_cum_log_probs_device}); device_->copy({beam_index, *beam_index_device}); sample_logits.updateShape(org_sample_logits_shape); sample_tokens.updateShape(org_sample_tokens_shape); beam_search_sequence_lengths.updateShape(org_sequence_lengths_shape); sample_cum_log_probs->updateShape(org_sample_cum_log_probs_shape); input_lengths.updateShape(org_input_lengths_shape); beam_index.updateShape(org_beam_index_shape); std::fill(success->dataWithOffset<bool>(from_seq_idx), success->dataWithOffset<bool>(from_seq_idx) + sample_seq_num, true); } if (inputs.cum_log_probs) { device_->copy({inputs.cum_log_probs->view(from_seq_idx, sample_seq_num), *sample_cum_log_probs}); } from_batch_idx = sample_to_batch_idx + 1; sample_to_batch_idx = from_batch_idx; from_seq_idx = sample_to_seq_idx; } while (from_batch_idx < inputs.batch_size); updateGrammarStatus(inputs); // TODO(xinfei.sxf) 优化copy token_ids return SamplerOutput({move(inputs.token_ids), move(inputs.cum_log_probs), move(inputs.all_probs), move(inputs.beam_index), move(success)}); } void Sampler::preprocessLogits(const SamplerInputs& inputs) { for (auto grammar: inputs.grammars) { grammar->process(inputs); } } void Sampler::updateGrammarStatus(const SamplerInputs& inputs) { for (auto grammar: inputs.grammars) { grammar->updateStatus(inputs); } } } // namespace rtp_llm