maga_transformer/cpp/kernels/no_aux_tc_kernels.cu

/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "maga_transformer/cpp/cuda/cuda_type_utils.cuh"
#include "no_aux_tc_kernels.h"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

namespace rtp_llm {

constexpr unsigned FULL_WARP_MASK      = 0xffffffff;
constexpr int32_t  WARP_SIZE           = 32;
constexpr int32_t  BLOCK_SIZE          = 512;
constexpr int32_t  NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;

namespace warp_topk {

template <int size, typename T>
__host__ __device__ constexpr T round_up_to_multiple_of(T len) {
    if (len == 0) {
        return 0;
    }
    return ((len - 1) / size + 1) * size;
}

template <typename T>
constexpr __host__ __device__ bool isPowerOf2(T v) {
    return (v && !(v & (v - 1)));
}

template <bool greater, typename T>
__forceinline__ __device__ bool is_better_than(T val, T baseline) {
    return (val > baseline && greater) || (val < baseline && !greater);
}

template <bool greater, typename T, typename idxT>
__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index, idxT baseline_index) {
    bool res = (val > baseline && greater) || (val < baseline && !greater);
    if (val == baseline) {
        res = (index < baseline_index && greater) || (index < baseline_index && !greater);
    }
    return res;
}

template <typename T, typename idxT>
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
    int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
    int64_t n          = std::max<int>(num_of_warp / 2 * k, num_of_warp * WARP_SIZE);
    return max(cache_topk, round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
}
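// Illustrative sizing example (not from the original source; k = 8 is only an assumed
// value). With BLOCK_SIZE = 512 (16 warps), T = float and idxT = int32_t:
//   cache_topk = (4 + 4) * 16 * 8                      = 1024 bytes  (per-warp top-k staging)
//   n          = max(16 / 2 * 8, 16 * 32)              = 512 elements
//   buffers    = round_up_to_multiple_of<256>(512 * 4) + 512 * 4 = 4096 bytes
// so calc_smem_size_for_block_wide returns max(1024, 4096) = 4096 bytes of dynamic
// shared memory. The exact result depends on the template types and k.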
template <int size, bool ascending, bool reverse, typename T, typename idxT, bool is_stable>
struct BitonicMerge {
    // The input must be a bitonic sequence; this merge turns it into a monotonic (sorted) sequence.
    __device__ static void merge(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) {
        static_assert(isPowerOf2(size));
        static_assert(size >= 2 * WARP_SIZE);
        constexpr int arr_len = size / WARP_SIZE;

        constexpr int stride = arr_len / 2;
        for (int i = 0; i < stride; ++i) {
            int const other_i   = i + stride;
            T&        val       = val_arr[i];
            T&        other_val = val_arr[other_i];
            bool      is_better;
            if constexpr (is_stable) {
                is_better = is_better_than<ascending>(val, other_val, idx_arr[i], idx_arr[other_i]);
            } else {
                is_better = is_better_than<ascending>(val, other_val);
            }

            if (is_better) {
                T tmp     = val;
                val       = other_val;
                other_val = tmp;

                idxT tmp2        = idx_arr[i];
                idx_arr[i]       = idx_arr[other_i];
                idx_arr[other_i] = tmp2;
            }
        }

        BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(val_arr, idx_arr);
        BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(val_arr + arr_len / 2,
                                                                              idx_arr + arr_len / 2);
    }
};

template <int size, bool ascending, typename T, typename idxT, bool is_stable>
struct BitonicSort {
    __device__ static void sort(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) {
        static_assert(isPowerOf2(size));
        static_assert(size >= 2 * WARP_SIZE);
        constexpr int arr_len = size / WARP_SIZE;

        BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr);
        BitonicSort<size / 2, false, T, idxT, is_stable>::sort(val_arr + arr_len / 2, idx_arr + arr_len / 2);
        BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(val_arr, idx_arr);
    }
};

template <bool ascending, typename T, typename idxT, bool is_stable>
struct BitonicSort<32, ascending, T, idxT, is_stable> {
    __device__ static void sort(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) {
        int const lane = threadIdx.x % WARP_SIZE;

        // ascending doesn't matter before merging since all we need is a bitonic sequence
        for (int stage = 0; stage < 4; ++stage) {
            for (int stride = (1 << stage); stride > 0; stride /= 2) {
                bool reverse   = (lane >> stage) & 2;
                bool is_second = lane & stride;

                T    other     = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
                idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);

                bool is_better;
                if constexpr (is_stable) {
                    if constexpr (ascending) {
                        is_better = ((*val_arr > other) || ((*val_arr == other) && (*idx_arr < other_idx)))
                                    != (reverse != is_second);
                    } else {
                        is_better = ((*val_arr > other) || ((*val_arr == other) && (*idx_arr > other_idx)))
                                    != (reverse != is_second);
                    }
                } else {
                    // is_better = (*val_arr != other) && is_better_than(*val_arr, other, (reverse == is_second));
                    is_better = (*val_arr != other && (*val_arr > other) != (reverse != is_second));
                }

                if (is_better) {
                    *val_arr = other;
                    *idx_arr = other_idx;
                }
            }
        }

        BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr, idx_arr);
    }
};

template <bool ascending, bool reverse, typename T, typename idxT, bool is_stable>
struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
    __device__ static void merge(T* __restrict__ val_arr, idxT* __restrict__ idx_arr) {
        int const lane = threadIdx.x % WARP_SIZE;
        for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) {
            bool is_second = lane & stride;

            T&    val       = *val_arr;
            T     other     = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
            idxT& idx       = *idx_arr;
            idxT  other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);

            bool is_better;
            if constexpr (is_stable) {
                if constexpr (ascending) {
                    is_better = ((*val_arr > other) || ((*val_arr == other) && (*idx_arr < other_idx)))
                                == (reverse != is_second);  // for min
                } else {
                    is_better = ((*val_arr > other) || ((*val_arr == other) && (*idx_arr > other_idx)))
                                == (reverse != is_second);  // for max
                }
            } else {
                // is_better = (val != other) && (is_better_than(val, other, (reverse != is_second)));
                is_better = (val != other && ((val > other) == (ascending != is_second)));
            }

            if (is_better) {
                val = other;
                idx = other_idx;
            }
        }
    }
};
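// Illustrative note (added commentary): in the capacity == WARP_SIZE specializations
// above, each lane holds exactly one (value, index) pair in registers and exchanges
// with its partner lane (lane ^ stride) via __shfl_xor_sync. Each lane then keeps
// either its own pair or the partner's, depending on which half of the bitonic
// network it sits in; the stable variants break value ties on the smaller index.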
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
class WarpSort {
public:
    __device__ WarpSort(idxT k, T dummy)
        : lane_(threadIdx.x % WARP_SIZE)
        , k_(k)
        , dummy_(dummy) {
        static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));

        for (int i = 0; i < max_arr_len_; ++i) {
            val_arr_[i] = dummy_;
            idx_arr_[i] = 0;
        }
    }

    // load and merge k sorted values
    __device__ void load_sorted(T const* __restrict__ in, idxT const* __restrict__ in_idx, idxT start) {
        idxT idx = start + WARP_SIZE - 1 - lane_;
        for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
            if (idx < start + k_) {
                T    t = in[idx];
                bool is_better;
                if constexpr (is_stable) {
                    is_better = is_better_than<greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
                } else {
                    is_better = is_better_than<greater>(t, val_arr_[i]);
                }
                if (is_better) {
                    val_arr_[i] = t;
                    idx_arr_[i] = in_idx[idx];
                }
            }
        }

        BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(val_arr_, idx_arr_);
    }

    __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
        for (int i = 0; i < max_arr_len_; ++i) {
            idxT out_i = i * WARP_SIZE + lane_;
            if (out_i < k_) {
                out[out_i]     = val_arr_[i];
                out_idx[out_i] = idx_arr_[i];
            }
        }
    }

    __device__ void dumpIdx(idxT* __restrict__ out_idx) const {
        for (int i = 0; i < max_arr_len_; ++i) {
            idxT out_i = i * WARP_SIZE + lane_;
            if (out_i < k_) {
                out_idx[out_i] = idx_arr_[i];
            }
        }
    }

protected:
    static constexpr int max_arr_len_ = capacity / WARP_SIZE;

    T    val_arr_[max_arr_len_];
    idxT idx_arr_[max_arr_len_];

    int const  lane_;
    idxT const k_;
    T const    dummy_;
};  // end class WarpSort
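// Illustrative note (added commentary): WarpSelect below extends WarpSort with a
// per-warp staging buffer of WARP_SIZE candidates in shared memory. Elements that
// are not better than the current k-th best are filtered out cheaply with a ballot,
// and the bitonic sort/merge in merge_buf_() only runs when the buffer fills up or
// at done(), which keeps the expected per-element cost low.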
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
public:
    __device__ WarpSelect(idxT k, T dummy)
        : WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy)
        , k_th_(dummy)
        , k_th_lane_((k - 1) % WARP_SIZE) {
        extern __shared__ char smem_buf[];  // extern __shared__ T smem_buf[];

        int const num_of_warp = blockDim.x / WARP_SIZE;
        int const warp_id     = threadIdx.x / WARP_SIZE;
        val_smem_ = reinterpret_cast<T*>(smem_buf);
        val_smem_ += warp_id * WARP_SIZE;
        idx_smem_ = reinterpret_cast<idxT*>(smem_buf + round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE));
        idx_smem_ += warp_id * WARP_SIZE;
    }

    __device__ void add(T const* in, idxT start, idxT end) {
        idxT const end_for_fullwarp = round_up_to_multiple_of<WARP_SIZE>(end - start) + start;
        for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) {
            T val = (i < end) ? in[i] : dummy_;
            add(val, i);
        }
    }

    __device__ void add(T val, idxT idx) {
        bool do_add;
        if constexpr (is_stable) {
            do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_);
        } else {
            do_add = is_better_than<greater>(val, k_th_);
        }

        uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
        if (mask == 0) {
            return;
        }

        int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1));
        if (do_add && pos < WARP_SIZE) {
            val_smem_[pos] = val;
            idx_smem_[pos] = idx;
            do_add = false;
        }

        smem_buf_len_ += __popc(mask);
        if (smem_buf_len_ >= WARP_SIZE) {
            __syncwarp();
            merge_buf_(val_smem_[lane_], idx_smem_[lane_]);
            smem_buf_len_ -= WARP_SIZE;
        }

        if (do_add) {
            pos -= WARP_SIZE;
            val_smem_[pos] = val;
            idx_smem_[pos] = idx;
        }
        __syncwarp();
    }

    __device__ void done() {
        if (smem_buf_len_) {
            T    val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_;
            idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
            merge_buf_(val, idx);
        }

        // after done(), smem is used for merging results among warps
        __syncthreads();
    }

private:
    __device__ void set_k_th_() {
        k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
        if constexpr (is_stable) {
            k_th_idx_ = __shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
        }
    }

    __device__ void merge_buf_(T val, idxT idx) {
        BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx);

        T& old = val_arr_[max_arr_len_ - 1];

        bool is_better;
        if constexpr (is_stable) {
            is_better = is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
        } else {
            is_better = is_better_than<greater>(val, old);
        }

        if (is_better) {
            old = val;
            idx_arr_[max_arr_len_ - 1] = idx;
        }

        BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(val_arr_, idx_arr_);

        set_k_th_();
    }

    using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
    using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
    using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
    using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
    using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
    using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;

    T*    val_smem_;
    idxT* idx_smem_;
    int   smem_buf_len_ = 0;

    T         k_th_;
    idxT      k_th_idx_;
    int const k_th_lane_;
};  // end class WarpSelect

}  // namespace warp_topk

template <typename T>
__device__ void topk_with_k2(T* output, T const* input, cg::thread_block_tile<32> const& tile,
                             int32_t const lane_id, int const num_experts_per_group) {
    // Get the top2 per thread
    T largest        = -INFINITY;
    T second_largest = -INFINITY;

    if (num_experts_per_group > WARP_SIZE) {
        for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
            T value = input[i];
            if (value > largest) {
                second_largest = largest;
                largest        = value;
            } else if (value > second_largest) {
                second_largest = value;
            }
        }
    } else {
        for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
            largest = input[i];
        }
    }

    __syncwarp();  // Ensure all threads have valid data before reduction
    // Get the top2 warpwise
    T max1 = cg::reduce(tile, largest, cg::greater<T>());
    T max2 = max1;

    bool equal_to_max1 = (max1 == largest);

    int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1));

    if (count_max1 == 1) {
        largest = (largest == max1) ? second_largest : largest;
        max2    = cg::reduce(tile, largest, cg::greater<T>());
    }

    if (lane_id == 0) {
        *output = max1 + max2;
    }
}
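// Illustrative note (added commentary): the group score written by topk_with_k2 is
// top1 + top2 of the biased expert scores in that group. When the warp-wide maximum
// is unique, the lane holding it substitutes its local runner-up before the second
// reduction so max2 becomes the true second-largest value; when the maximum appears
// in more than one lane, max2 simply reuses max1.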
template <typename T>
__global__ void topk_with_k2_kernel(T* output, T* input, int64_t const num_tokens, int64_t const num_cases,
                                    int64_t const n_group, int64_t const num_experts_per_group) {
    int32_t warp_id = threadIdx.x / WARP_SIZE;
    int32_t lane_id = threadIdx.x % WARP_SIZE;

    int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
    if (case_id < num_cases) {
        input  += case_id * num_experts_per_group;
        output += case_id;

        cg::thread_block          block = cg::this_thread_block();
        cg::thread_block_tile<32> tile  = cg::tiled_partition<32>(block);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        asm volatile("griddepcontrol.wait;");
#endif

        topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.launch_dependents;");
#endif
}

template <typename T, typename IdxT>
__global__ void group_idx_and_topk_idx_kernel(T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices,
                                              T* scores_with_bias, int64_t const num_tokens, int64_t const n_group,
                                              int64_t const topk_group, int64_t const topk, int64_t const num_experts,
                                              int64_t const num_experts_per_group, int norm_node,
                                              double routed_scaling_factor) {
    // TODO: norm_node
    int32_t warp_id = threadIdx.x / WARP_SIZE;
    int32_t lane_id = threadIdx.x % WARP_SIZE;
    int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;  // one per token
    scores_with_bias += case_id * num_experts;
    scores           += case_id * num_experts;
    group_scores     += case_id * n_group;
    topk_values      += case_id * topk;
    topk_indices     += case_id * topk;

    int32_t align_num_experts_per_group = warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);

    cg::thread_block          block = cg::this_thread_block();
    cg::thread_block_tile<32> tile  = cg::tiled_partition<32>(block);

    extern __shared__ char smem_buf[];  // NOTE: reuse the shared memory here to store the target topk idx
    int32_t* s_topk_idx   = reinterpret_cast<int32_t*>(smem_buf);
    T*       s_topk_value = reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) + warp_id * topk;
    s_topk_idx += warp_id * topk;

    T       value            = cuda::std::numeric_limits<T>::min();
    T       topk_group_value = cuda::std::numeric_limits<T>::min();
    int32_t num_equalto_topkth_group;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.wait;");  // I think all prolog can be put before acqbulk because it's ptr arithmetic
#endif

    if (case_id < num_tokens) {
        // calculate group_idx
        int32_t target_num_min = WARP_SIZE - n_group + topk_group;
        if (lane_id < n_group
            && (isfinite(cuda_cast<float, T>(group_scores[lane_id]))))  // The check is necessary to avoid abnormal input
        {
            value = group_scores[lane_id];
        }

        int count_equal_to_top_value     = WARP_SIZE - n_group;
        int pre_count_equal_to_top_value = 0;
        // Use a loop to find the largest top_group
        while (count_equal_to_top_value < target_num_min) {
            __syncwarp();  // Ensure all threads have valid data before reduction
            topk_group_value = cg::reduce(tile, value, cg::greater<T>());
            if (value == topk_group_value) {
                value = cuda::std::numeric_limits<T>::min();
            }
            pre_count_equal_to_top_value = count_equal_to_top_value;
            count_equal_to_top_value =
                __popc(__ballot_sync(FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
        }
        num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
    }
    __syncthreads();
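    // Illustrative note (added commentary): at this point each warp knows the score of
    // its token's topk_group-th best group (topk_group_value) and how many groups tie at
    // that threshold may still be taken (num_equalto_topkth_group). The WarpSelect queue
    // below then streams only the experts of the selected groups and keeps a stable
    // top-k of their biased scores, breaking ties toward the smaller expert index.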
    warp_topk::WarpSelect</*capacity*/ WARP_SIZE, /*greater*/ true, T, int32_t, /*is_stable*/ true> queue(
        (int32_t) topk, -INFINITY);

    int  count_equalto_topkth_group = 0;
    bool if_proceed_next_topk       = (topk_group_value != cuda::std::numeric_limits<T>::min());
    if (case_id < num_tokens && if_proceed_next_topk) {
        for (int i_group = 0; i_group < n_group; i_group++) {
            if ((group_scores[i_group] > topk_group_value)
                || ((group_scores[i_group] == topk_group_value)
                    && (count_equalto_topkth_group < num_equalto_topkth_group))) {
                int32_t offset = i_group * num_experts_per_group;
                for (int32_t i = lane_id; i < align_num_experts_per_group; i += WARP_SIZE) {
                    T candidates = (i < num_experts_per_group) && isfinite(cuda_cast<float, T>(scores_with_bias[offset + i]))
                                       ? scores_with_bias[offset + i]
                                       : cuda::std::numeric_limits<T>::min();
                    queue.add(candidates, offset + i);
                }
                if (group_scores[i_group] == topk_group_value) {
                    count_equalto_topkth_group++;
                }
            }
        }
        queue.done();
        __syncwarp();
        // Get the topk_idx
        queue.dumpIdx(s_topk_idx);
        __syncwarp();
    }

    // Load the valid score value
    // Calculate the summation
    float topk_sum = norm_node == 0 ? 1.0 : 1e-20;
    if (case_id < num_tokens && if_proceed_next_topk) {
        for (int i = lane_id; i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk); i += WARP_SIZE) {
            T value = i < topk ? scores[s_topk_idx[i]] : cuda_cast<T, float>(0.0f);  // Load the valid value of expert
            if (i < topk) {
                s_topk_value[i] = value;
            }
            if (norm_node == 1) {
                topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
            }
        }
    }

    __syncthreads();

    if (case_id < num_tokens) {
        if (if_proceed_next_topk) {
            for (int i = lane_id; i < topk; i += WARP_SIZE) {
                float value     = cuda_cast<float, T>(s_topk_value[i]) / topk_sum * routed_scaling_factor;
                topk_indices[i] = s_topk_idx[i];
                topk_values[i]  = cuda_cast<T, float>(value);
            }
        } else {
            for (int i = lane_id; i < topk; i += WARP_SIZE) {
                topk_indices[i] = i;
                topk_values[i]  = cuda_cast<T, float>(1.0f / topk);
            }
        }
        // Note: when if_proceed_next_topk==false, choose the first topk experts as the default result.
        //@TODO: check if this default strategy is acceptable. Might need to leave it as nan array.
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.launch_dependents;");
#endif
}
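// Illustrative note (added commentary; the token count is only an example): the launcher
// below runs two kernels per call. topk_with_k2_kernel computes one group score per
// (token, group) pair, so e.g. 1024 tokens with n_group = 8 give 8192 cases and
// (8192 - 1) / NUM_WARPS_PER_BLOCK + 1 = 512 blocks of 512 threads. The second kernel
// assigns one warp per token and sizes its dynamic shared memory with
// warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK, topk).
// On cc >= 9 both launches enable programmatic stream serialization, matching the
// griddepcontrol instructions inside the kernels.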
template <typename T, typename IdxT>
void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values, IdxT* topk_indices, T* scores_with_bias,
                   int64_t const num_tokens, int64_t const num_experts, int64_t const n_group,
                   int64_t const topk_group, int64_t const topk, int norm_node,
                   double const routed_scaling_factor, cudaStream_t const stream) {
    int device;
    cudaGetDevice(&device);
    int cc_major;
    cudaDeviceGetAttribute(&cc_major, cudaDeviceAttr::cudaDevAttrComputeCapabilityMajor, device);

#define LAUNCH_KERNEL(kernel, grid_dim, block_dim, smem_size, stream, ...)                       \
    do {                                                                                         \
        if (cc_major >= 9) {                                                                     \
            cudaLaunchConfig_t config__;                                                         \
            config__.gridDim          = (grid_dim);                                              \
            config__.blockDim         = (block_dim);                                             \
            config__.dynamicSmemBytes = (smem_size);                                             \
            config__.stream           = (stream);                                                \
            cudaLaunchAttribute attrs__[1];                                                      \
            attrs__[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;                  \
            attrs__[0].val.programmaticStreamSerializationAllowed = true;                        \
            config__.numAttrs = 1;                                                               \
            config__.attrs    = attrs__;                                                         \
            cudaLaunchKernelEx(&config__, &(kernel), __VA_ARGS__);                               \
        } else {                                                                                 \
            (kernel)<<<(grid_dim), (block_dim), (smem_size), (stream)>>>(__VA_ARGS__);           \
        }                                                                                        \
    } while (0)

    int64_t num_cases               = num_tokens * n_group;
    int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
    LAUNCH_KERNEL((topk_with_k2_kernel<T>), topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream, group_scores,
                  scores_with_bias, num_tokens, num_cases, n_group, num_experts / n_group);
    sync_check_cuda_error();

    int64_t topk_with_k_group_num_blocks = (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
    size_t  dynamic_smem_in_bytes = warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK, topk);

    LAUNCH_KERNEL((group_idx_and_topk_idx_kernel<T, IdxT>), topk_with_k_group_num_blocks, BLOCK_SIZE,
                  dynamic_smem_in_bytes, stream, scores, group_scores, topk_values, topk_indices, scores_with_bias,
                  num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group, norm_node,
                  routed_scaling_factor);
    sync_check_cuda_error();

#undef LAUNCH_KERNEL
}

#define INSTANTIATE_NOAUX_TC(T, IdxT)                                                                        \
    template void invokeNoAuxTc<T, IdxT>(T * scores, T * group_scores, T * topk_values, IdxT * topk_indices, \
        T * scores_with_bias, int64_t const num_tokens, int64_t const num_experts, int64_t const n_group,    \
        int64_t const topk_group, int64_t const topk, int norm_node, double const routed_scaling_factor,     \
        cudaStream_t const stream);

INSTANTIATE_NOAUX_TC(float, int32_t);
// INSTANTIATE_NOAUX_TC(half, int32_t);
// #ifdef ENABLE_BF16
// INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
// #endif

}  // namespace rtp_llm
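// ---------------------------------------------------------------------------
// Illustrative usage sketch (commentary only; buffer names, shapes, and parameter
// values are assumptions for a DeepSeek-style no-aux-loss router, not part of the
// original file):
//
//   scores           : [num_tokens, num_experts]  router scores (read for output values)
//   scores_with_bias : [num_tokens, num_experts]  scores + per-expert bias (used for selection)
//   group_scores     : [num_tokens, n_group]      scratch written by topk_with_k2_kernel
//   topk_values      : [num_tokens, topk]
//   topk_indices     : [num_tokens, topk]
//
//   rtp_llm::invokeNoAuxTc<float, int32_t>(scores, group_scores, topk_values, topk_indices,
//                                          scores_with_bias, /*num_tokens=*/1024,
//                                          /*num_experts=*/256, /*n_group=*/8, /*topk_group=*/4,
//                                          /*topk=*/8, /*norm_node=*/1,
//                                          /*routed_scaling_factor=*/2.5, stream);
// ---------------------------------------------------------------------------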