kernels/silu.cu:

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <stdint.h>

// Type-dispatched exponential: expf for float, hexp intrinsics for
// half and bfloat16.
__device__ __forceinline__ float expg(float a) { return expf(a); }
__device__ __forceinline__ __half expg(__half a) { return hexp(a); }
__device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); }

// SiLU (swish): x * sigmoid(x) == x / (1 + e^-x).
// Note: __restrict__ is only meaningful on pointers, so it is not
// applied to the by-value scalar argument.
template <typename scalar_t>
__device__ __forceinline__ scalar_t silu(scalar_t x) {
  return x / (static_cast<scalar_t>(1) + expg(-x));
}

// Elementwise SiLU over numel values. The grid-stride loop lets any
// grid/block configuration cover the whole buffer.
template <typename scalar_t>
__global__ void silu_kernel(const scalar_t* __restrict__ x_ptr,
                            scalar_t* __restrict__ out_ptr,
                            const int numel) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
       i += blockDim.x * gridDim.x) {
    out_ptr[i] = silu<scalar_t>(x_ptr[i]);
  }
}

#define CALL_SILU(T)                          \
  silu_kernel<T><<<grid, block, 0, stream>>>( \
      reinterpret_cast<const T*>(x),          \
      reinterpret_cast<T*>(out),              \
      numel);

// C entry point; dispatches on a runtime dtype tag. Unknown dtype
// values are a no-op.
extern "C" void silu(
    void* x,
    void* out,
    int32_t num_blocks,
    int32_t num_threads,
    int32_t numel,
    uint32_t dtype  // 0 => f16; 1 => bf16; 2 => f32
) {
  dim3 grid(num_blocks);
  dim3 block(num_threads);
  const cudaStream_t stream = 0;  // default stream
  if (dtype == 0) {
    CALL_SILU(__half);
  } else if (dtype == 1) {
    CALL_SILU(__nv_bfloat16);
  } else if (dtype == 2) {
    CALL_SILU(float);
  }
}
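
A minimal sketch of how the exported entry point might be driven from host code, assuming the file above is compiled with nvcc and linked into the caller. The driver file name, main function, buffer sizes, and launch configuration are illustrative, not part of the source:

// example_silu_host.cu -- hypothetical driver, not part of the original file.
#include <cuda_runtime.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

// Matches the extern "C" definition in kernels/silu.cu.
extern "C" void silu(void* x, void* out, int32_t num_blocks,
                     int32_t num_threads, int32_t numel, uint32_t dtype);

int main() {
  const int32_t numel = 1 << 20;
  float *d_x, *d_out;
  cudaMalloc((void**)&d_x, numel * sizeof(float));
  cudaMalloc((void**)&d_out, numel * sizeof(float));

  // Fill the input with a simple ramp and copy it to the device.
  float* h_x = (float*)malloc(numel * sizeof(float));
  for (int32_t i = 0; i < numel; ++i) h_x[i] = 0.001f * i;
  cudaMemcpy(d_x, h_x, numel * sizeof(float), cudaMemcpyHostToDevice);

  // dtype 2 => f32. The grid/block sizes are arbitrary; the kernel's
  // grid-stride loop covers numel regardless.
  silu(d_x, d_out, /*num_blocks=*/256, /*num_threads=*/256, numel, /*dtype=*/2);
  cudaDeviceSynchronize();

  float first;
  cudaMemcpy(&first, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("silu(0) = %f\n", first);  // expect 0.0, since silu(0) = 0

  cudaFree(d_x);
  cudaFree(d_out);
  free(h_x);
  return 0;
}

Because the kernel strides by blockDim.x * gridDim.x, num_blocks * num_threads does not need to match numel; any positive launch configuration processes every element.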