// maga_transformer/cpp/kernels/activation_kernels.cu

/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "maga_transformer/cpp/kernels/activation_kernels.h"
#include "maga_transformer/cpp/cuda/cuda_type_utils.cuh"
#if USING_ROCM
using namespace rtp_llm::rocm;
#endif
#include "maga_transformer/cpp/cuda/memory_utils.h"

#if USING_CUDA
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#endif
#endif

namespace rtp_llm {

/* Gelu Activation */

__forceinline__ __device__ float copysignf_pos(float a, float b) {
    float r;
    r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
    return r;
}

__inline__ __device__ float tanh_opt(float x) {
#if (__CUDA_ARCH__ >= 750 && CUDART_VERSION >= 11000)
    float r;
    asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x));
    return r;
#else
    const float exp_val = -1.f * fabs(2 * x);
    return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#endif
}

template<typename T>
struct GeluActivation {
    using return_type = T;
    static __device__ __forceinline__ T apply(const T& val) {
        const float cdf = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (val + 0.044715f * val * val * val))));
        return val * cdf;
    }
};

template<typename T>
struct GeluActivationNoneApproximate {
    using return_type = T;
    static __device__ __forceinline__ T apply(const T& val) {
        return 0.5f * val * (1.0f + std::erf(val * M_SQRT1_2));
    }
};

template<>
struct GeluActivation<half2> {
    using return_type = half2;
    static __device__ __forceinline__ half2 apply(const half2& val) {
        half2  val_pow3 = __hmul2(val, __hmul2(val, val));
        float2 tmp_pow  = __half22float2(val_pow3);
        float2 tmp      = __half22float2(val);

        tmp.x = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));
        tmp.y = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));
        return __hmul2(val, __float22half2_rn(tmp));
    }
};

template<>
struct GeluActivationNoneApproximate<half2> {
    using return_type = half2;
    static __device__ __forceinline__ half2 apply(const half2& val) {
        half2  val_pow3 = __hmul2(val, __hmul2(val, val));
        float2 tmp_pow  = __half22float2(val_pow3);
        float2 tmp      = __half22float2(val);

        tmp.x = 0.5f * (1.0f + std::erf(tmp.x * M_SQRT1_2));
        tmp.y = 0.5f * (1.0f + std::erf(tmp.y * M_SQRT1_2));
        return __hmul2(val, __float22half2_rn(tmp));
    }
};

#ifdef ENABLE_BF16
template<>
struct GeluActivation<__nv_bfloat162> {
    using return_type = __nv_bfloat162;
    static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val) {
        __nv_bfloat162 val_pow3 = bf16hmul2(val, bf16hmul2(val, val));
        float2         tmp_pow  = bf1622float2(val_pow3);
        float2         tmp      = bf1622float2(val);

        tmp.x = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));
        tmp.y = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));
        return bf16hmul2(val, __floats2bfloat162_rn(tmp.x, tmp.y));
    }
};

template<>
struct GeluActivationNoneApproximate<__nv_bfloat162> {
    using return_type = __nv_bfloat162;
    static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val) {
        __nv_bfloat162 val_pow3 = bf16hmul2(val, bf16hmul2(val, val));
        float2         tmp_pow  = bf1622float2(val_pow3);
        float2         tmp      = bf1622float2(val);

        tmp.x = 0.5f * (1.0f + std::erf(tmp.x * M_SQRT1_2));
        tmp.y = 0.5f * (1.0f + std::erf(tmp.y * M_SQRT1_2));
        return bf16hmul2(val, __floats2bfloat162_rn(tmp.x, tmp.y));
    }
};
#endif
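// Reference sketch (illustrative addition, not part of the original kernels):
// the tanh-approximate GELU above corresponds to this scalar host-side form,
// handy for sanity-checking device outputs against the CPU:
//
//   float gelu_tanh_ref(float x) {
//       // 0.7978845608028654f ~= sqrt(2 / pi)
//       const float cdf = 0.5f * (1.0f + tanhf(0.7978845608028654f * (x + 0.044715f * x * x * x)));
//       return x * cdf;
//   }
//
//   float gelu_exact_ref(float x) {  // the "NoneApproximate" (erf) variant
//       return 0.5f * x * (1.0f + erff(x * (float)M_SQRT1_2));
//   }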
/* Relu Activation */

template<typename T>
struct ReluActivation {
    using return_type = T;
    static __device__ __forceinline__ T apply(const T& val) {
        return val > static_cast<T>(0.0f) ? val : static_cast<T>(0.0f);
    }
};

template<>
struct ReluActivation<half2> {
    using return_type = half2;
    static __device__ __forceinline__ half2 apply(const half2& val) {
        const half zero_half = static_cast<half>(0.0f);
        return make_half2(val.x > zero_half ? val.x : zero_half, val.y > zero_half ? val.y : zero_half);
    }
};

#ifdef ENABLE_BF16
template<>
struct ReluActivation<__nv_bfloat162> {
    using return_type = __nv_bfloat162;
    static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val) {
        const __nv_bfloat16 zero_bf16 = static_cast<__nv_bfloat16>(0.0f);
        return make_bfloat162(val.x > zero_bf16 ? val.x : zero_bf16, val.y > zero_bf16 ? val.y : zero_bf16);
    }
};
#endif

/* Silu Activation */

template<typename T>
struct SiluActivation {
    using return_type = T;
    static __device__ __forceinline__ T apply(const T& val) {
        return (T)((float)val / (1.0f + __expf((float)-val)));
    }
};

template<>
struct SiluActivation<half2> {
    using return_type = float2;
    static __device__ __forceinline__ float2 apply(const half2& val) {
        return make_float2(SiluActivation<float>::apply(val.x), SiluActivation<float>::apply(val.y));
    }
};

#ifdef ENABLE_BF16
template<>
struct SiluActivation<__nv_bfloat162> {
    using return_type = float2;
    static __device__ __forceinline__ float2 apply(const __nv_bfloat162& val) {
        return make_float2(SiluActivation<float>::apply(val.x), SiluActivation<float>::apply(val.y));
    }
};
#endif  // ENABLE_BF16

/* Identity Activation (= no activation) */

template<typename T>
struct IdentityActivation {
    using return_type = T;
    static __device__ __forceinline__ T apply(const T& val) {
        return val;
    }
};

// clang-format off
template<template<typename T> class Activation, typename T, typename BT>
__global__ void generic_activation(T*                      up_out,
                                   const BT* __restrict    bias,
                                   const T* __restrict     gate,
                                   const BT* __restrict    gate_bias,
                                   const int* __restrict   ia3_tasks,
                                   const T* __restrict     ia3_weights,
                                   const int               int8_mode,
                                   const float* __restrict activation_in,
                                   const float* __restrict activation_out,
                                   const BT* __restrict    activation_scale,
                                   const int* __restrict   padding_offset,
                                   const int               seq_len,
                                   int                     m,
                                   int                     n,
                                   int                     total) {
    constexpr size_t packed_elems = num_elems<T>::value;

    const bool with_bias      = bias != nullptr;
    const bool with_gate      = gate != nullptr;
    const bool with_ia3       = ia3_tasks != nullptr;
    const bool with_act_scale = activation_scale != nullptr;

    using Act_T         = typename Activation<T>::return_type;
    using Float_T       = typename packed_as<float, packed_elems>::type;
    using Packed_Int8_t = typename packed_as<int8_t, packed_elems>::type;

    for (int64_t id = blockIdx.x * blockDim.x + threadIdx.x; id < total; id += blockDim.x * gridDim.x) {
        T val;
        if (int8_mode == 2) {
            val = cuda_cast<T>(cuda_cast<Float_T>(reinterpret_cast<Packed_Int8_t*>(up_out)[id]) * activation_in[0]);
        } else {
            val = up_out[id];
        }

        T gate_val;
        if (with_gate) {
            gate_val = gate[id];
        }

        if (with_bias) {
            const T reg_bias = static_cast<T>(bias[id % n]);
            val              = val + reg_bias;
            if (with_gate) {
                const T reg_gated_bias = static_cast<T>(gate_bias[id % n]);
                gate_val               = gate_val + reg_gated_bias;
            }
        }

        if (with_gate) {
            val = cuda_cast<T>(Activation<T>::apply(gate_val) * cuda_cast<Act_T>(val));
        } else {
            val = cuda_cast<T>(Activation<T>::apply(val));
        }

        if (with_ia3) {
            const int word_id  = id / n;
            const int offset   = padding_offset == nullptr ? 0 : padding_offset[word_id];
            const int batch_id = (word_id + offset) / seq_len;
            const int task     = ia3_tasks[batch_id];
            val                = val * ia3_weights[task * n + (id % n)];
        }

        if (with_act_scale) {
            const T reg_activation = static_cast<T>(activation_scale[id % n]);
            val                    = val / reg_activation;
        }

        if (int8_mode != 2) {
            up_out[id] = val;
        } else {
            reinterpret_cast<Packed_Int8_t*>(up_out)[id] =
                cuda_cast<Packed_Int8_t>(cuda_cast<Float_T>(val) * activation_out[0]);
        }
    }
}
// clang-format on
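// Scalar summary of the fused element-wise pass above (a readability sketch;
// the kernel itself operates on packed types via a grid-stride loop):
//
//   val      = (int8_mode == 2) ? int8_in[id] * activation_in[0] : up_out[id];
//   val     += bias[id % n];                       // if bias != nullptr
//   gate_val = gate[id] + gate_bias[id % n];       // if gate != nullptr
//   val      = with_gate ? Act(gate_val) * val : Act(val);
//   val     *= ia3_weights[task * n + id % n];     // if ia3_tasks != nullptr
//   val     /= activation_scale[id % n];           // if activation_scale != nullptr
//   out      = (int8_mode == 2) ? quantize(val * activation_out[0]) : val;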
template<template<typename T> class Activation, typename T, typename BT>
void invokeGenericActivation(T*           up_out,
                             const BT*    bias,
                             const T*     gate,
                             const BT*    gate_bias,
                             const int*   ia3_tasks,
                             const T*     ia3_weights,
                             const int    m,
                             const int    n,
                             const int    int8_mode,
                             const float* activation_in,
                             const float* activation_out,
                             const BT*    activation_scale,
                             const int*   padding_offset,
                             const int    seq_len,
                             cudaStream_t stream) {
    using PT                   = typename packed_type_2<T>::type;
    constexpr int packed_elems = num_elems<PT>::value;
    using PBT                  = typename packed_as<BT, packed_elems>::type;

    // should be even
    int  temp_n = n + n % 2;
    dim3 block, grid;
    constexpr int max_threads_per_block = 1024;
    constexpr int elems_per_thread      = 4 * packed_elems;
    if (temp_n / elems_per_thread <= max_threads_per_block) {
        block.x = temp_n / elems_per_thread;
        grid.x  = m;
    } else {
        block.x = max_threads_per_block;
        constexpr int elems_per_block = max_threads_per_block * elems_per_thread;
        grid.x                        = (m * temp_n + elems_per_block - 1) / elems_per_block;
    }
    generic_activation<Activation><<<grid, block, 0, stream>>>(reinterpret_cast<PT*>(up_out),
                                                               reinterpret_cast<const PBT*>(bias),
                                                               reinterpret_cast<const PT*>(gate),
                                                               reinterpret_cast<const PBT*>(gate_bias),
                                                               ia3_tasks,
                                                               reinterpret_cast<const PT*>(ia3_weights),
                                                               int8_mode,
                                                               activation_in,
                                                               activation_out,
                                                               reinterpret_cast<const PBT*>(activation_scale),
                                                               padding_offset,
                                                               seq_len,
                                                               m,
                                                               temp_n / packed_elems,
                                                               m * temp_n / packed_elems);
}

#define INSTANTIATE_GENERIC_ACTIVATION(Activation, T, BT)                                                             \
    template void invokeGenericActivation<Activation, T, BT>(T * up_out,                                              \
                                                             const BT*    bias,                                       \
                                                             const T*     gate,                                       \
                                                             const BT*    gate_bias,                                  \
                                                             const int*   ia3_tasks,                                  \
                                                             const T*     ia3_weights,                                \
                                                             const int    m,                                          \
                                                             const int    n,                                          \
                                                             const int    int8_mode,                                  \
                                                             const float* activation_in,                              \
                                                             const float* activation_out,                             \
                                                             const BT*    activation_scale,                           \
                                                             const int*   padding_offset,                             \
                                                             const int    seq_len,                                    \
                                                             cudaStream_t stream);

INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, float, float);
INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, half, half);
#ifdef ENABLE_BF16
INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, __nv_bfloat16, __nv_bfloat16);
#endif

INSTANTIATE_GENERIC_ACTIVATION(GeluActivationNoneApproximate, float, float);
INSTANTIATE_GENERIC_ACTIVATION(GeluActivationNoneApproximate, half, half);
#ifdef ENABLE_BF16
INSTANTIATE_GENERIC_ACTIVATION(GeluActivationNoneApproximate, __nv_bfloat16, __nv_bfloat16);
#endif

INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, float, float);
INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, half, half);
#ifdef ENABLE_BF16
INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, __nv_bfloat16, __nv_bfloat16);
#endif

INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, float, float);
INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, half, half);
#ifdef ENABLE_BF16
INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, __nv_bfloat16, __nv_bfloat16);
#endif

INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, float, float);
INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, half, half);
INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, float, half);
#ifdef ENABLE_BF16
INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, __nv_bfloat16, __nv_bfloat16);
INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, float, __nv_bfloat16);
#endif
#undef INSTANTIATE_GENERIC_ACTIVATION
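// Usage sketch (hypothetical caller; `up`, `gate`, `m`, `n`, and `stream` are
// assumed to be set up elsewhere): a fused SiLU-gated FFN activation computing
// up[i] = silu(gate[i]) * up[i] in place, with bias/ia3/int8 paths disabled:
//
//   rtp_llm::invokeGenericActivation<rtp_llm::SiluActivation>(
//       up, /*bias=*/(const half*)nullptr, gate, /*gate_bias=*/(const half*)nullptr,
//       /*ia3_tasks=*/nullptr, /*ia3_weights=*/(const half*)nullptr,
//       m, n, /*int8_mode=*/0,
//       /*activation_in=*/nullptr, /*activation_out=*/nullptr,
//       /*activation_scale=*/(const half*)nullptr,
//       /*padding_offset=*/nullptr, /*seq_len=*/0, stream);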
template<typename T>
__global__ void add_bias_tanh(T* out, const T* __restrict bias, int m, int n) {
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
        T val = out[id];
        if (bias != nullptr) {
            val = val + ldg(&bias[id % n]);
        }
        out[id] = tanhf(val);
    }
}

template<>
__global__ void add_bias_tanh(half* out, const half* __restrict bias, int m, int n) {
    half2*       out_ptr  = (half2*)out;
    const half2* bias_ptr = (half2*)bias;

    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
        half2 val = out_ptr[id];
        if (bias != nullptr) {
            val = val + __ldg(&bias_ptr[id % n]);
        }
        val.x       = tanhf(val.x);
        val.y       = tanhf(val.y);
        out_ptr[id] = val;
    }
}

#ifdef ENABLE_BF16
template<>
__global__ void add_bias_tanh(__nv_bfloat16* out, const __nv_bfloat16* __restrict bias, int m, int n) {
    __nv_bfloat162*       out_ptr  = (__nv_bfloat162*)out;
    const __nv_bfloat162* bias_ptr = (__nv_bfloat162*)bias;

    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
        __nv_bfloat162 val = out_ptr[id];
        if (bias != nullptr) {
            val = bf16hadd2(val, ldg(&bias_ptr[id % n]));
        }
        val.x       = tanhf(val.x);
        val.y       = tanhf(val.y);
        out_ptr[id] = val;
    }
}
#endif

template<typename T>
void invokeAddBiasTanh(T* out, const T* bias, const int m, const int n, cudaStream_t stream) {
    const int data_type_factor = 4 / sizeof(T);  // 1 for fp32, 2 for fp16 and bf16
    dim3      block, grid;
    if (n / 4 / data_type_factor <= 1024) {
        block.x = n / 4 / data_type_factor;
        grid.x  = m;
    } else {
        block.x = 1024;
        grid.x  = ceil(m * n / 1024.);
    }
    add_bias_tanh<T><<<grid, block, 0, stream>>>(out, bias, m, n / data_type_factor);
}

template void invokeAddBiasTanh(float* out, const float* bias, const int m, const int n, cudaStream_t stream);
template void invokeAddBiasTanh(half* out, const half* bias, const int m, const int n, cudaStream_t stream);
#ifdef ENABLE_BF16
template void
invokeAddBiasTanh(__nv_bfloat16* out, const __nv_bfloat16* bias, const int m, const int n, cudaStream_t stream);
#endif
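// Launch-math note (added commentary): for fp16/bf16, data_type_factor == 2,
// so the specialized kernels above index half2 / __nv_bfloat162 elements and
// receive n / 2 as their column count; the block-size computation implicitly
// assumes n is a multiple of 4 * data_type_factor. A hypothetical call over an
// [m, n] fp16 buffer (n even) would be:
//
//   rtp_llm::invokeAddBiasTanh(out_fp16, bias_fp16, m, n, stream);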
template<typename T2, int N>
__global__ void addBiasGeluV2(T2*                  out,
                              const T2* __restrict bias,
                              const int*           ia3_tasks,
                              const T2*            ia3_weights,
                              const int            size,
                              const int*           padding_offset,
                              const int            seq_len) {
    const bool with_ia3 = ia3_tasks != nullptr;
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < size; id += blockDim.x * gridDim.x) {
        T2 val = out[id];
        if (bias != nullptr) {
            T2 reg_bias = ldg(&bias[id % N]);
            val         = hadd2(val, reg_bias);
        }
        val = GeluActivation<T2>::apply(val);
        if (with_ia3) {
            const int word_id  = id / N;
            const int offset   = padding_offset == nullptr ? 0 : padding_offset[word_id];
            const int batch_id = (word_id + offset) / seq_len;
            const int task     = ia3_tasks[batch_id];
            val                = val * ia3_weights[task * N + (id % N)];
        }
        out[id] = val;
    }
}

template<typename T2, int N, int ELEMENT_PER_ROUND>
__global__ void addBiasGeluV3(T2*                  out,
                              const T2* __restrict bias,
                              const int*           ia3_tasks,
                              const T2*            ia3_weights,
                              const int            size,
                              const int*           padding_offset,
                              const int            seq_len) {
    const bool with_ia3 = ia3_tasks != nullptr;
    T2         buffer[ELEMENT_PER_ROUND];
    T2         tmp_bias[ELEMENT_PER_ROUND];
    for (int id = blockIdx.x * blockDim.x * ELEMENT_PER_ROUND + threadIdx.x * ELEMENT_PER_ROUND; id < size;
         id += blockDim.x * gridDim.x * ELEMENT_PER_ROUND) {
#pragma unroll
        for (int i = 0; i < ELEMENT_PER_ROUND; i++) {
            buffer[i] = out[id + i];
            if (bias != nullptr) {
                tmp_bias[i] = ldg(&bias[(id + i) % N]);
            }
        }
#pragma unroll
        for (int i = 0; i < ELEMENT_PER_ROUND; i++) {
            if (bias != nullptr) {
                buffer[i] = hadd2(buffer[i], tmp_bias[i]);
            }
            buffer[i] = GeluActivation<T2>::apply(buffer[i]);
            if (with_ia3) {
                const int word_id  = (id + i) / N;
                const int offset   = padding_offset == nullptr ? 0 : padding_offset[word_id];
                const int batch_id = (word_id + offset) / seq_len;
                const int task     = ia3_tasks[batch_id];
                buffer[i]          = buffer[i] * ia3_weights[task * N + ((id + i) % N)];
            }
            out[id + i] = buffer[i];
        }
    }
}

#define ADD_BIAS_GELU(HALF_N, ELEMENT_PER_ROUND)                                                                      \
    case HALF_N:                                                                                                      \
        if (ELEMENT_PER_ROUND > 1) {                                                                                  \
            grid.x = grid.x / ELEMENT_PER_ROUND;                                                                      \
            addBiasGeluV3<T2, HALF_N, ELEMENT_PER_ROUND><<<grid, block, 0, stream>>>(                                 \
                (T2*)out, (const T2*)bias, ia3_tasks, (T2*)ia3_weights, m * half_n, padding_offset, seq_len);         \
        } else {                                                                                                      \
            addBiasGeluV2<T2, HALF_N><<<grid, block, 0, stream>>>(                                                    \
                (T2*)out, (const T2*)bias, ia3_tasks, (T2*)ia3_weights, m * half_n, padding_offset, seq_len);         \
        }                                                                                                             \
        break;

template<typename T>
void invokeAddBiasGeluV2(T*           out,
                         const T*     bias,
                         const int*   ia3_tasks,
                         const T*     ia3_weights,
                         const int*   padding_offset,
                         const int    seq_len,
                         const int    m,
                         const int    n,
                         cudaStream_t stream) {
    if (n % 2 == 0 && sizeof(T) == 2) {
        const int half_n = n / 2;
        dim3      block, grid;
        block.x = std::min(half_n, 512);
        grid.x  = (m * half_n + (block.x - 1)) / block.x;
        using T2 = typename TypeConverter<T>::Type;

        if (grid.x >= 512) {
            switch (half_n) {
                ADD_BIAS_GELU(256, 1)
                ADD_BIAS_GELU(512, 1)
                ADD_BIAS_GELU(1024, 1)
                ADD_BIAS_GELU(1536, 1)
                ADD_BIAS_GELU(2048, 1)
                ADD_BIAS_GELU(4096, 2)
                ADD_BIAS_GELU(8192, 2)
                ADD_BIAS_GELU(16384, 2)
                ADD_BIAS_GELU(24576, 2)
                ADD_BIAS_GELU(40960, 4)
                default:
                    invokeGenericActivation<GeluActivation>(out, bias, (T*)nullptr, (T*)nullptr, ia3_tasks,
                                                            ia3_weights, m, n, 0, (float*)nullptr, (float*)nullptr,
                                                            (T*)nullptr, padding_offset, seq_len, stream);
                    break;
            }
        } else {
            switch (half_n) {
                ADD_BIAS_GELU(256, 1)
                ADD_BIAS_GELU(512, 1)
                ADD_BIAS_GELU(1024, 1)
                ADD_BIAS_GELU(1536, 1)
                ADD_BIAS_GELU(2048, 1)
                ADD_BIAS_GELU(4096, 1)
                ADD_BIAS_GELU(8192, 2)
                ADD_BIAS_GELU(16384, 2)
                ADD_BIAS_GELU(24576, 2)
                ADD_BIAS_GELU(40960, 2)
                default:
                    invokeGenericActivation<GeluActivation>(out, bias, (T*)nullptr, (T*)nullptr, ia3_tasks,
                                                            ia3_weights, m, n, 0, (float*)nullptr, (float*)nullptr,
                                                            (T*)nullptr, padding_offset, seq_len, stream);
                    break;
            }
        }
    } else {
        invokeGenericActivation<GeluActivation>(out, bias, (T*)nullptr, (T*)nullptr, ia3_tasks, ia3_weights, m, n, 0,
                                                (float*)nullptr, (float*)nullptr, (T*)nullptr, padding_offset, seq_len,
                                                stream);
    }
}
#undef ADD_BIAS_GELU

template void invokeAddBiasGeluV2(float* out, const float* bias, const int* ia3_tasks, const float* ia3_weights,
                                  const int* padding_offset, const int seq_len, const int m, const int n,
                                  cudaStream_t stream);
template void invokeAddBiasGeluV2(half* out, const half* bias, const int* ia3_tasks, const half* ia3_weights,
                                  const int* padding_offset, const int seq_len, const int m, const int n,
                                  cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeAddBiasGeluV2(__nv_bfloat16* out, const __nv_bfloat16* bias, const int* ia3_tasks,
                                  const __nv_bfloat16* ia3_weights, const int* padding_offset, const int seq_len,
                                  const int m, const int n, cudaStream_t stream);
#endif  // ENABLE_BF16
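// Dispatch note (added commentary): for 16-bit types with even n, the switch
// above selects a kernel specialized on half_n = n / 2 (256 ... 40960) plus an
// unroll factor chosen by grid size (grid.x >= 512 or not); any other shape
// falls back to invokeGenericActivation<GeluActivation>. A hypothetical call
// for an [m, 4096] half buffer:
//
//   rtp_llm::invokeAddBiasGeluV2(out, bias, /*ia3_tasks=*/nullptr,
//                                /*ia3_weights=*/(const half*)nullptr,
//                                /*padding_offset=*/nullptr, /*seq_len=*/0,
//                                m, 4096, stream);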
template<typename T>
__global__ void sigmoid_kernel(T* data, const int size, const float scale) {
    const int index = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x;
    if (index < size) {
        float val   = cuda_cast<float>(data[index]);
        val         = 1.0f / (1.0f + exp(-val)) * scale;
        data[index] = T(val);
    }
}

template<>
__global__ void sigmoid_kernel(half2* data, const int size, const float scale) {
    const int index = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x;
    if (index < size / 2) {
        half2  val        = data[index];
        float2 val_float2 = cuda_cast<float2>(val);
        val_float2.x      = 1.0f / (1.0f + exp(-val_float2.x)) * scale;
        val_float2.y      = 1.0f / (1.0f + exp(-val_float2.y)) * scale;
        data[index]       = cuda_cast<half2>(val_float2);
    }
}

#ifdef ENABLE_BF16
template<>
__global__ void sigmoid_kernel(__nv_bfloat162* data, const int size, const float scale) {
    const int index = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x;
    if (index < size / 2) {
        __nv_bfloat162 val        = data[index];
        float2         val_float2 = cuda_cast<float2>(val);
        val_float2.x              = 1.0f / (1.0f + exp(-val_float2.x)) * scale;
        val_float2.y              = 1.0f / (1.0f + exp(-val_float2.y)) * scale;
        data[index]               = cuda_cast<__nv_bfloat162>(val_float2);
    }
}
#endif

template<typename T>
void invokeSigmoid(T* data, const int size, const float scale, cudaStream_t stream) {
    if (std::is_same<T, half>::value && (size % 2 == 0)) {
        dim3 block(128);
        dim3 grid((size + 255) / 256);
        sigmoid_kernel<<<grid, block, 0, stream>>>((half2*)data, size, scale);
    }
#ifdef ENABLE_BF16
    else if (std::is_same<T, __nv_bfloat16>::value && (size % 2 == 0)) {
        dim3 block(128);
        dim3 grid((size + 255) / 256);
        sigmoid_kernel<<<grid, block, 0, stream>>>((__nv_bfloat162*)data, size, scale);
    }
#endif
    else {
        dim3 block(128);
        dim3 grid((size + 127) / 128);
        sigmoid_kernel<<<grid, block, 0, stream>>>(data, size, scale);
    }
}

template void invokeSigmoid(float* data, const int size, const float scale, cudaStream_t stream);
template void invokeSigmoid(half* data, const int size, const float scale, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeSigmoid(__nv_bfloat16* data, const int size, const float scale, cudaStream_t stream);
#endif
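// Usage sketch (hypothetical caller): invokeSigmoid applies
// x -> scale * sigmoid(x) in place over `size` elements, switching to the
// 2-wide vectorized kernels for half/bf16 whenever size is even:
//
//   rtp_llm::invokeSigmoid(logits, size, /*scale=*/1.0f, stream);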
template<typename T>
__global__ void scaledot_kernel(T* out, const T* in, const T* scale, const int m, const int n) {
    const int size        = m * n;
    const int index       = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x;
    const int scale_index = index / n;
    if (index < size) {
        float val       = cuda_cast<float>(in[index]);
        float scale_val = cuda_cast<float>(scale[scale_index]);
        out[index]      = T(val * scale_val);
    }
}

template<typename T>
void invokeScaledDot(T* out, const T* input, const T* scale, const int m, const int n, cudaStream_t stream) {
    int  temp_n = n + n % 2;
    dim3 block, grid;
    if (temp_n <= 1024) {
        block.x = temp_n;
        grid.x  = m;
    } else {
        block.x = 1024;
        grid.x  = ceil(m * temp_n / 1024.);
    }
    scaledot_kernel<<<grid, block, 0, stream>>>(out, input, scale, m, n);
}

template void
invokeScaledDot(float* out, const float* input, const float* scale, const int m, const int n, cudaStream_t stream);
template void
invokeScaledDot(half* out, const half* input, const half* scale, const int m, const int n, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeScaledDot(__nv_bfloat16* out, const __nv_bfloat16* input, const __nv_bfloat16* scale, const int m,
                              const int n, cudaStream_t stream);
#endif

}  // namespace rtp_llm
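// Usage sketch for the row-wise scaling above (hypothetical caller): each row
// i of the [m, n] `input` is multiplied by scale[i] and written to `out`;
// note the launch pads n to an even temp_n for grid sizing while the kernel
// still bounds-checks against m * n:
//
//   rtp_llm::invokeScaledDot(out, input, scale, m, n, stream);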