// maga_transformer/cpp/kernels/sampling_penalty_kernels.cu

/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <assert.h>
#include <float.h>
//#include "driver_types.h"
#include "maga_transformer/cpp/kernels/sampling_penalty_kernels.h"

namespace rtp_llm {

// TODO Add half2 implementation
template<typename T>
__global__ void applyTemperaturePenalty(T*          logits,
                                        const T*    bias,
                                        const float temperature_inverse,
                                        const int   m,
                                        const int   vocab_size,
                                        const int   vocab_size_padd)
{
    const bool IS_FP16   = std::is_same<T, half>::value;
    const T    MAX_T_VAL = (IS_FP16) ? 65504.F : FLT_MAX;
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < m * vocab_size_padd;
         index += blockDim.x * gridDim.x) {
        T bias_val = bias == nullptr ? (T)(0.0f) : bias[index % vocab_size_padd];
        if (index % vocab_size_padd < vocab_size) {
            logits[index] = (logits[index] + bias_val) * (T)temperature_inverse;
        }
        else {
            logits[index] = -MAX_T_VAL;
        }
    }
}

template<>
__global__ void applyTemperaturePenalty(half2*       logits,
                                        const half2* bias,
                                        const float  temperature_inverse,
                                        const int    batch_size,
                                        const int    vocab_size,
                                        const int    vocab_size_padded)
{
    assert(vocab_size % 2 == 0);
    assert(vocab_size_padded % 2 == 0);
    const half2 mask_val = __float2half2_rn(-65504.0f);
    const half2 temp_inv = __float2half2_rn(temperature_inverse);

    const int half_vocab_size        = vocab_size / 2;
    const int half_vocab_size_padded = vocab_size_padded / 2;
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * half_vocab_size_padded;
         index += blockDim.x * gridDim.x) {
        int   vocab_idx = index % half_vocab_size_padded;
        half2 logit     = vocab_idx < half_vocab_size ? __ldg(&logits[index]) : mask_val;
        if (vocab_idx < half_vocab_size) {
            if (bias != nullptr) {
                logit = __hadd2(logit, bias[vocab_idx]);
            }
            logit = __hmul2(logit, temp_inv);
        }
        // Store unconditionally so padded vocab entries receive mask_val, matching the
        // masking behavior of the generic kernel above.
        logits[index] = logit;
    }
}

template<typename T>
void invokeApplyTemperaturePenalty(T*           logits,
                                   const T*     bias,
                                   const float  temperature,
                                   const int    batch_size,
                                   const int    vocab_size,
                                   const int    vocab_size_padd,
                                   cudaStream_t stream)
{
    dim3 block(min(vocab_size_padd, 1024));
    dim3 grid(min(batch_size * vocab_size_padd / block.x, 65536));
    const T temperature_inverse = (T)(1.f / (temperature + 1e-6f));
    if (std::is_same<T, half>::value && vocab_size % 2 == 0 && vocab_size_padd % 2 == 0) {
        applyTemperaturePenalty<<<grid, block, 0, stream>>>(reinterpret_cast<half2*>(logits),
                                                            reinterpret_cast<const half2*>(bias),
                                                            temperature_inverse,
                                                            batch_size,
                                                            vocab_size,
                                                            vocab_size_padd);
    }
    else {
        applyTemperaturePenalty<T>
            <<<grid, block, 0, stream>>>(logits, bias, temperature_inverse, batch_size, vocab_size, vocab_size_padd);
    }
}

template void invokeApplyTemperaturePenalty(
    float* logits, const float* bias, const float temperature, const int batch_size,
    const int vocab_size, const int vocab_size_padd, cudaStream_t stream);

template void invokeApplyTemperaturePenalty(
    half* logits, const half* bias, const float temperature, const int batch_size,
    const int vocab_size, const int vocab_size_padd, cudaStream_t stream);
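
// Illustrative usage sketch, not part of this file's API: how a host-side sampling
// layer might call the single-temperature launcher. The buffer names, sizes, and the
// SAMPLING_PENALTY_USAGE_EXAMPLES guard are hypothetical; only the launcher call
// itself reflects the code above. Compiled out by default.
#ifdef SAMPLING_PENALTY_USAGE_EXAMPLES
static void exampleApplyTemperaturePenalty(cudaStream_t stream)
{
    const int batch_size        = 4;
    const int vocab_size        = 32000;
    const int vocab_size_padded = 32256;  // padded for alignment

    // [batch_size, vocab_size_padded] logits, assumed to be filled by the model.
    float* d_logits = nullptr;
    cudaMalloc(&d_logits, sizeof(float) * batch_size * vocab_size_padded);

    // Scale every logit by 1 / temperature; passing bias == nullptr skips the bias add.
    invokeApplyTemperaturePenalty(d_logits, /*bias=*/nullptr, /*temperature=*/0.7f,
                                  batch_size, vocab_size, vocab_size_padded, stream);

    cudaFree(d_logits);
}
#endif  // SAMPLING_PENALTY_USAGE_EXAMPLES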

template<typename T>
__global__ void batchApplyTemperaturePenalty(T*           logits,
                                             const T*     bias,
                                             const float* temperatures,
                                             const int    batch_size,
                                             const int    vocab_size,
                                             const int    vocab_size_padd)
{
    // TODO: Add macro or device function to get MAX_T_VAL.
    const bool IS_FP16   = std::is_same<T, half>::value;
    const T    MAX_T_VAL = (IS_FP16) ? 65504.F : FLT_MAX;
    extern __shared__ float inv_temperatures[];
    // Note: this cache assumes batch_size <= blockDim.x; a larger batch would leave
    // entries of inv_temperatures uninitialized.
    if (threadIdx.x < batch_size) {
        inv_temperatures[threadIdx.x] = 1.0f / (temperatures[threadIdx.x] + 1e-6f);
    }
    __syncthreads();

    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * vocab_size_padd;
         index += blockDim.x * gridDim.x) {
        int batch_idx = index / vocab_size_padd;
        int vocab_idx = index % vocab_size_padd;
        T   logit     = (vocab_idx < vocab_size) ? logits[index] : -MAX_T_VAL;
        if (vocab_idx < vocab_size) {
            if (bias != nullptr) {
                logit += bias[vocab_idx];
            }
            logit *= inv_temperatures[batch_idx];
        }
        logits[index] = logit;
    }
}

__global__ void batchApplyTemperaturePenalty_h2(half2*       logits,
                                                const half2* bias,
                                                const float* temperatures,
                                                const int    batch_size,
                                                const int    vocab_size,
                                                const int    vocab_size_padded)
{
    assert(vocab_size % 2 == 0);
    assert(vocab_size_padded % 2 == 0);
    extern __shared__ half2 h2_inv_temperatures[];
    if (threadIdx.x < batch_size) {
        h2_inv_temperatures[threadIdx.x] = __float2half2_rn(1.f / (temperatures[threadIdx.x] + 1e-6f));
    }
    __syncthreads();

    const half2 mask_val               = __float2half2_rn(-65504.0f);
    const int   half_vocab_size        = vocab_size / 2;
    const int   half_vocab_size_padded = vocab_size_padded / 2;
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * half_vocab_size_padded;
         index += blockDim.x * gridDim.x) {
        int   batch_idx = index / half_vocab_size_padded;
        int   vocab_idx = index % half_vocab_size_padded;
        half2 logit     = vocab_idx < half_vocab_size ? __ldg(&logits[index]) : mask_val;
        if (vocab_idx < half_vocab_size) {
            if (bias != nullptr) {
                logit = __hadd2(logit, bias[vocab_idx]);
            }
            logit = __hmul2(logit, h2_inv_temperatures[batch_idx]);
        }
        // Store unconditionally so padded vocab entries are masked, matching the
        // generic kernel above.
        logits[index] = logit;
    }
}

template<typename T>
void invokeBatchApplyTemperaturePenalty(T*           logits,
                                        const T*     bias,
                                        const float* temperatures,
                                        const int    batch_size,
                                        const int    vocab_size,
                                        const int    vocab_size_padd,
                                        cudaStream_t stream)
{
    dim3 block(min(vocab_size_padd, 1024));
    dim3 grid(min(batch_size * vocab_size_padd / block.x, 65536));
    if (std::is_same<T, half>::value && vocab_size % 2 == 0 && vocab_size_padd % 2 == 0) {
        size_t smem_size = sizeof(half2) * batch_size;
        batchApplyTemperaturePenalty_h2<<<grid, block, smem_size, stream>>>(reinterpret_cast<half2*>(logits),
                                                                            reinterpret_cast<const half2*>(bias),
                                                                            temperatures,
                                                                            batch_size,
                                                                            vocab_size,
                                                                            vocab_size_padd);
    }
    else {
        size_t smem_size = sizeof(float) * batch_size;
        batchApplyTemperaturePenalty<T>
            <<<grid, block, smem_size, stream>>>(logits, bias, temperatures, batch_size, vocab_size, vocab_size_padd);
    }
}

template void invokeBatchApplyTemperaturePenalty(
    float* logits, const float* bias, const float* temperatures, const int batch_size,
    const int vocab_size, const int vocab_size_padd, cudaStream_t stream);

template void invokeBatchApplyTemperaturePenalty(
    half* logits, const half* bias, const float* temperatures, const int batch_size,
    const int vocab_size, const int vocab_size_padd, cudaStream_t stream);
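
// Illustrative sketch (hypothetical names, compiled out by default): the batch variant
// takes one temperature per request, resident in device memory.
#ifdef SAMPLING_PENALTY_USAGE_EXAMPLES
static void exampleBatchApplyTemperaturePenalty(cudaStream_t stream)
{
    const int   batch_size = 3, vocab_size = 32000, vocab_size_padded = 32256;
    const float h_temperatures[3] = {0.7f, 1.0f, 1.3f};  // one temperature per request

    float* d_logits       = nullptr;  // [batch_size, vocab_size_padded]
    float* d_temperatures = nullptr;  // [batch_size]
    cudaMalloc(&d_logits, sizeof(float) * batch_size * vocab_size_padded);
    cudaMalloc(&d_temperatures, sizeof(float) * batch_size);
    cudaMemcpyAsync(d_temperatures, h_temperatures, sizeof(float) * batch_size,
                    cudaMemcpyHostToDevice, stream);

    invokeBatchApplyTemperaturePenalty(d_logits, /*bias=*/nullptr, d_temperatures,
                                       batch_size, vocab_size, vocab_size_padded, stream);

    cudaFree(d_temperatures);
    cudaFree(d_logits);
}
#endif  // SAMPLING_PENALTY_USAGE_EXAMPLES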

template<typename T, RepetitionPenaltyType penalty_type>
__global__ void applyRepetitionPenalty(T*          logits,
                                       const float penalty,
                                       const int*  start_ids,
                                       int*        output_ids,
                                       const int   batch_size,
                                       const int   local_batch_size,
                                       const int   vocab_size,
                                       const int   vocab_size_padd,
                                       const int*  input_lengths,
                                       const int   max_input_len,
                                       const int   step)
{
    extern __shared__ float penalty_logits[];
    int* penalty_indices = (int*)(penalty_logits + step);

    logits                 = logits + blockIdx.x * vocab_size_padd;
    const int input_length = input_lengths != nullptr ? input_lengths[blockIdx.x] : max_input_len;
    for (int index = threadIdx.x; index < step; index += blockDim.x) {
        // Initialize to an out-of-range index so entries skipped below are ignored in
        // phase 2 instead of being read uninitialized.
        penalty_indices[index] = vocab_size;
        if (index >= input_length && index < max_input_len) {
            continue;
        }
        // output_ids shape: (input_len + output_len, batch_size)
        int penalty_index = output_ids[index * batch_size + blockIdx.x];
        if (penalty_index >= vocab_size || penalty_index < 0) {
            continue;
        }
        penalty_indices[index] = penalty_index;
        float logit            = (float)logits[penalty_index];
        if (penalty_type == RepetitionPenaltyType::Additive) {
            penalty_logits[index] = logit - penalty;
        }
        else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
            penalty_logits[index] = logit < 0.0f ? logit * penalty : logit / penalty;
        }
        else if (penalty_type == RepetitionPenaltyType::None) {
            penalty_logits[index] = logit;
        }
        else {
            // Unsupported type
            assert(false);
        }
    }
    if (blockDim.x > 32) {
        __syncthreads();
    }
    for (int index = threadIdx.x; index < step; index += blockDim.x) {
        if (index >= input_length && index < max_input_len) {
            continue;
        }
        // output_ids shape: (input_len + output_len, batch_size)
        if (penalty_indices[index] >= vocab_size) {
            continue;
        }
        logits[penalty_indices[index]] = penalty_logits[index];
    }
}
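
// Why the repetition-penalty kernels work in two phases: a token id can occur several
// times in output_ids but must be penalized exactly once. Phase 1 reads every affected
// logit before any write happens and stages the penalized values; after the block-wide
// barrier, phase 2 scatters them back, so duplicate occurrences all store the same
// once-penalized value instead of compounding the penalty.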

template<typename T>
void invokeApplyRepetitionPenalty(T*                          logits,
                                  const float                 penalty,
                                  const int*                  start_ids,
                                  int*                        output_ids,
                                  const int                   batch_size,
                                  const int                   local_batch_size,
                                  const int                   vocab_size,
                                  const int                   vocab_size_padd,
                                  const int*                  input_lengths,
                                  const int                   max_input_len,
                                  const int                   step,
                                  const RepetitionPenaltyType penalty_type,
                                  cudaStream_t                stream)
{
    dim3   block(min(step, 1024));
    dim3   grid(local_batch_size);
    size_t smem_size = step * (sizeof(float) + sizeof(int));
    if (penalty_type == RepetitionPenaltyType::Additive) {
        applyRepetitionPenalty<T, RepetitionPenaltyType::Additive>
            <<<grid, block, smem_size, stream>>>(logits, penalty, start_ids, output_ids, batch_size,
                local_batch_size, vocab_size, vocab_size_padd, input_lengths, max_input_len, step);
    }
    else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
        applyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>
            <<<grid, block, smem_size, stream>>>(logits, penalty, start_ids, output_ids, batch_size,
                local_batch_size, vocab_size, vocab_size_padd, input_lengths, max_input_len, step);
    }
    else if (penalty_type == RepetitionPenaltyType::None) {
        // do nothing
    }
}

template void invokeApplyRepetitionPenalty(
    float* logits, const float penalty, const int* start_ids, int* output_ids,
    const int batch_size, const int local_batch_size, const int vocab_size,
    const int vocab_size_padd, const int* input_lengths, const int max_input_len,
    const int step, const RepetitionPenaltyType penalty_type, cudaStream_t stream);

template void invokeApplyRepetitionPenalty(
    half* logits, const float penalty, const int* start_ids, int* output_ids,
    const int batch_size, const int local_batch_size, const int vocab_size,
    const int vocab_size_padd, const int* input_lengths, const int max_input_len,
    const int step, const RepetitionPenaltyType penalty_type, cudaStream_t stream);

template<typename T, RepetitionPenaltyType penalty_type>
__global__ void batchApplyRepetitionPenalty(T*           logits,
                                            const float* penalties,
                                            const int*   output_ids,
                                            const int    batch_size,
                                            const int    vocab_size,
                                            const int*   input_lengths,
                                            const int    max_input_length,
                                            const int    step)
{
    extern __shared__ float penalty_logits[];
    int*        penalty_indices = (int*)(penalty_logits + step);
    const int   batch_idx       = blockIdx.x;
    const float penalty         = penalties[batch_idx];
    const int   input_length    = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;

    logits += batch_idx * vocab_size;

    if (penalty == 1.0f && penalty_type == RepetitionPenaltyType::Multiplicative) {
        return;
    }

    // Phase 1. Find indices to penalize and keep the penalized values.
    // A vocab id can appear multiple times but should be penalized once.
    for (int index = threadIdx.x; index < step; index += blockDim.x) {
        // Initialize to an invalid index so entries skipped below are ignored in phase 2
        // instead of being read uninitialized.
        penalty_indices[index] = -1;
        // Skip the padding tokens in input sequences.
        if (index >= input_length && index < max_input_length) {
            continue;
        }
        // output_ids shape: (input_len + output_len, batch_size)
        int penalty_index = output_ids[index * batch_size + batch_idx];
        if (penalty_index >= vocab_size || penalty_index < 0) {
            continue;
        }
        float logit            = (float)logits[penalty_index];
        penalty_indices[index] = penalty_index;
        if (penalty_type == RepetitionPenaltyType::Additive) {
            penalty_logits[index] = logit - penalty;
        }
        else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
            penalty_logits[index] = logit < 0.0f ? logit * penalty : logit / penalty;
        }
        else if (penalty_type == RepetitionPenaltyType::None) {
            penalty_logits[index] = logit;
        }
        else {
            // Unsupported type
            assert(false);
        }
    }
    if (blockDim.x > 32) {
        __syncthreads();
    }
    // Phase 2. Replace a logit value by the penalized one.
    for (int index = threadIdx.x; index < step; index += blockDim.x) {
        // Skip the padding tokens in input sequences.
        if (index >= input_length && index < max_input_length) {
            continue;
        }
        if (penalty_indices[index] >= 0 && penalty_indices[index] < vocab_size) {
            logits[penalty_indices[index]] = penalty_logits[index];
        }
    }
}
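
// For long sequences the staging buffers no longer fit in static shared memory (the
// launcher below falls back once step * (sizeof(float) + sizeof(int)) reaches 46 KB),
// so the following variant stages the penalized values in a caller-provided global
// scratch buffer penalty_logits of shape [step, batch_size] instead.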

template<typename T, RepetitionPenaltyType penalty_type>
__global__ void batchApplyRepetitionPenaltyLongSeq(T*           logits,
                                                   T*           penalty_logits,
                                                   const float* penalties,
                                                   const int*   output_ids,
                                                   const int    batch_size,
                                                   const int    vocab_size,
                                                   const int*   input_lengths,
                                                   const int    max_input_length,
                                                   const int    step)
{
    const int   batch_idx    = blockIdx.x;
    const float penalty      = penalties[batch_idx];
    const int   input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;

    if (penalty == 1.0f && penalty_type == RepetitionPenaltyType::Multiplicative) {
        return;
    }

    logits += batch_idx * vocab_size;

    // Phase 1. Find indices to penalize and keep the penalized values.
    // A vocab id can appear multiple times but should be penalized once.
    for (int index = threadIdx.x; index < step; index += blockDim.x) {
        // Skip the padding tokens in input sequences.
        if (index >= input_length && index < max_input_length) {
            continue;
        }
        // output_ids shape: (input_len + output_len, batch_size)
        int penalty_index = output_ids[index * batch_size + batch_idx];
        if (penalty_index >= vocab_size || penalty_index < 0) {
            continue;
        }
        float logit = (float)logits[penalty_index];
        if (penalty_type == RepetitionPenaltyType::Additive) {
            penalty_logits[index * batch_size + batch_idx] = logit - penalty;
        }
        else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
            penalty_logits[index * batch_size + batch_idx] = logit < 0.0f ? logit * penalty : logit / penalty;
        }
        else if (penalty_type == RepetitionPenaltyType::None) {
            penalty_logits[index * batch_size + batch_idx] = logit;
        }
        else {
            // Unsupported type
            assert(false);
        }
    }

    // Barrier between the phases: every thread must finish reading original logits
    // before any thread scatters penalized values back, or a token that repeats across
    // threads could be penalized twice.
    __syncthreads();

    // Phase 2. Replace a logit value by the penalized one.
    for (int index = threadIdx.x; index < step; index += blockDim.x) {
        // Skip the padding tokens in input sequences.
        if (index >= input_length && index < max_input_length) {
            continue;
        }
        int penalty_index = output_ids[index * batch_size + batch_idx];
        if (penalty_index < vocab_size && penalty_index >= 0) {
            logits[penalty_index] = penalty_logits[index * batch_size + batch_idx];
        }
    }
}

template<typename T>
void invokeBatchApplyRepetitionPenalty(T*                    logits,
                                       T*                    penalty_logits,
                                       const float*          penalties,
                                       const int*            output_ids,
                                       const int             batch_size,
                                       const int             local_batch_size,
                                       const int             vocab_size,
                                       const int*            input_lengths,
                                       const int             max_input_length,
                                       const int             step,
                                       RepetitionPenaltyType penalty_type,
                                       cudaStream_t          stream)
{
    // Inputs
    //   logits        [local_batch_size, vocab_size] : logit values.
    //   penalties     [local_batch_size] : repetition penalty factors.
    //   output_ids    [step, batch_size] : output token ids (with offset ite * local_batch_size).
    //   input_lengths [local_batch_size] : input lengths (optional).
    // Padding tokens at [input_length, max_input_length) of input will not be penalized.
    dim3   block(min(step, 1024));
    dim3   grid(local_batch_size);
    size_t smem_size = step * (sizeof(float) + sizeof(int));
    if (penalty_type == RepetitionPenaltyType::Additive) {
        if (smem_size < 46 * 1024) {
            batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>(
                logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
        }
        else {
            // The staging buffers would exceed the static shared memory budget; fall
            // back to the global-memory variant.
            batchApplyRepetitionPenaltyLongSeq<T, RepetitionPenaltyType::Additive><<<grid, block, 0, stream>>>(
                logits, penalty_logits, penalties, output_ids, batch_size, vocab_size, input_lengths,
                max_input_length, step);
        }
    }
    else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
        if (smem_size < 46 * 1024) {
            batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative><<<grid, block, smem_size, stream>>>(
                logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
        }
        else {
            batchApplyRepetitionPenaltyLongSeq<T, RepetitionPenaltyType::Multiplicative><<<grid, block, 0, stream>>>(
                logits, penalty_logits, penalties, output_ids, batch_size, vocab_size, input_lengths,
                max_input_length, step);
        }
    }
    else if (penalty_type == RepetitionPenaltyType::None) {
        // do nothing
    }
}

template<typename T>
__global__ void ApplyCopyLogits(float*    output_logits_buf,
                                int*      logit_index_buf,
                                T*        runtime_logits_buf,
                                bool*     skip_decode_buf_,
                                const int local_batch_size,
                                const int vocab_size_padded_)
{
    int bid = blockIdx.x;
    if (skip_decode_buf_[bid]) {
        return;
    }
    // One block per batch entry (grid = local_batch_size): copy the log of the
    // selected token's logit.
    output_logits_buf[bid] = logf((float)runtime_logits_buf[bid * vocab_size_padded_ + logit_index_buf[bid]]);
}

template<typename T>
void invokeCopyLogits(float*       output_logits_buf,
                      int*         logit_index_buf,
                      T*           runtime_logits_buf,
                      bool*        skip_decode_buf_,
                      const int    local_batch_size,
                      const int    vocab_size_padded_,
                      cudaStream_t stream)
{
    dim3 block(1);
    dim3 grid(local_batch_size);
    ApplyCopyLogits<<<grid, block, 0, stream>>>(
        output_logits_buf, logit_index_buf, runtime_logits_buf, skip_decode_buf_, local_batch_size,
        vocab_size_padded_);
}

#define INSTANTINVOKECOPYLOGITS(T)                                                                                     \
    template void invokeCopyLogits(float*       output_logits_buf,                                                     \
                                   int*         logit_index_buf,                                                       \
                                   T*           runtime_logits_buf,                                                    \
                                   bool*        skip_decode_buf_,                                                      \
                                   const int    local_batch_size,                                                      \
                                   const int    vocab_size_padded_,                                                    \
                                   cudaStream_t stream);
INSTANTINVOKECOPYLOGITS(float);
INSTANTINVOKECOPYLOGITS(half);

template void invokeBatchApplyRepetitionPenalty(
    float* logits, float* penalty_logits, const float* penalties, const int* output_ids,
    const int batch_size, const int local_batch_size, const int vocab_size,
    const int* input_lengths, const int max_input_length, const int step,
    RepetitionPenaltyType penalty_type, cudaStream_t stream);

template void invokeBatchApplyRepetitionPenalty(
    half* logits, half* penalty_logits, const float* penalties, const int* output_ids,
    const int batch_size, const int local_batch_size, const int vocab_size,
    const int* input_lengths, const int max_input_length, const int step,
    RepetitionPenaltyType penalty_type, cudaStream_t stream);
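
// Illustrative sketch (hypothetical names, compiled out by default): applying a
// multiplicative repetition penalty per request. d_scratch backs penalty_logits and is
// only read when the launcher falls back to the long-sequence path.
#ifdef SAMPLING_PENALTY_USAGE_EXAMPLES
static void exampleBatchApplyRepetitionPenalty(float*       d_logits,         // [batch, vocab_size]
                                               const float* d_penalties,      // [batch]
                                               const int*   d_output_ids,     // [step, batch]
                                               const int*   d_input_lengths,  // [batch]
                                               float*       d_scratch,        // [step, batch]
                                               int          batch_size,
                                               int          vocab_size,
                                               int          max_input_length,
                                               int          step,
                                               cudaStream_t stream)
{
    invokeBatchApplyRepetitionPenalty(d_logits, d_scratch, d_penalties, d_output_ids,
                                      batch_size, /*local_batch_size=*/batch_size,
                                      vocab_size, d_input_lengths, max_input_length, step,
                                      RepetitionPenaltyType::Multiplicative, stream);
}
#endif  // SAMPLING_PENALTY_USAGE_EXAMPLES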

template<typename T>
__global__ void batchApplyMinLengthPenaltyNew(T*         logits,
                                              const int* min_lengths,
                                              const int* end_ids,
                                              const int* sequence_lengths,
                                              const int* input_lengths,
                                              const int  vocab_size_padded,
                                              const int  decoder_batch_size,
                                              const int  batch_size)
{
    int bid = threadIdx.x + blockIdx.x * blockDim.x;  // batch index
    // Guard against the rounded-up launch size indexing past the parameter arrays.
    if (bid >= batch_size) {
        return;
    }
    if ((bid < decoder_batch_size && sequence_lengths[bid] - input_lengths[bid] < min_lengths[bid])
        || (bid >= decoder_batch_size && min_lengths[bid] > 1)) {
        T mask_val = (std::is_same<T, half>::value) ? -65504.0f : -FLT_MAX;
        logits[bid * vocab_size_padded + end_ids[bid]] = mask_val;
    }
}

template<typename T>
void invokeMinLengthPenaltyNew(T*           logits,
                               const int*   min_lengths,
                               const int*   end_ids,
                               const int*   sequence_lengths,
                               const int*   input_lengths,
                               const int    decoder_batch_size,
                               const int    batch_size,
                               const int    vocab_size_padded,
                               cudaStream_t stream)
{
    const int block_size = min(batch_size, 1024);
    const int grid_size  = (batch_size + block_size - 1) / block_size;
    batchApplyMinLengthPenaltyNew<<<grid_size, block_size, 0, stream>>>(
        logits, min_lengths, end_ids, sequence_lengths, input_lengths, vocab_size_padded, decoder_batch_size,
        batch_size);
}

template void invokeMinLengthPenaltyNew(
    float* logits, const int* min_lengths, const int* end_ids, const int* sequence_lengths,
    const int* input_lengths, const int decoder_batch_size, const int batch_size,
    const int vocab_size_padded, cudaStream_t stream);

}  // namespace rtp_llm
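
// Illustrative sketch (hypothetical names, compiled out by default): masking the EOS
// logit until each request has produced at least min_length tokens. All pointers are
// device buffers owned by the caller.
#ifdef SAMPLING_PENALTY_USAGE_EXAMPLES
static void exampleMinLengthPenalty(float*       d_logits,            // [batch, vocab_size_padded]
                                    const int*   d_min_lengths,       // [batch]
                                    const int*   d_end_ids,           // [batch]
                                    const int*   d_sequence_lengths,  // [decoder batch]
                                    const int*   d_input_lengths,     // [decoder batch]
                                    int          decoder_batch_size,
                                    int          batch_size,
                                    int          vocab_size_padded,
                                    cudaStream_t stream)
{
    rtp_llm::invokeMinLengthPenaltyNew(d_logits, d_min_lengths, d_end_ids, d_sequence_lengths,
                                       d_input_lengths, decoder_batch_size, batch_size,
                                       vocab_size_padded, stream);
}
#endif  // SAMPLING_PENALTY_USAGE_EXAMPLES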