maga_transformer/cpp/kernels/_convert_to_fp8.h (63 lines of code) (raw):

#pragma once //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_FP8 inline __device__ void convert_to_fp8(__nv_fp8_e4m3* v, const __nv_bfloat16 u) { v[0] = __nv_fp8_e4m3(u); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void convert_to_fp8(fp8_2_t* v, const __nv_bfloat162 u) { v[0] = fp8_2_t(u); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void convert_to_fp8(fp8_4_t* v, const bf16_4_t u) { reinterpret_cast<fp8_2_t*>(v)[0] = fp8_2_t(u.x); reinterpret_cast<fp8_2_t*>(v)[1] = fp8_2_t(u.y); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void convert_to_fp8(fp8_8_t* v, const bf16_8_t u) { v[0].x = fp8_2_t(u.x); v[0].y = fp8_2_t(u.y); v[0].z = fp8_2_t(u.z); v[0].w = fp8_2_t(u.w); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void convert_to_fp8(__nv_fp8_e4m3* v, const half u) { v[0] = __nv_fp8_e4m3(u); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void convert_to_fp8(__nv_fp8_e4m3* v, const uint16_t u) { v[0] = __nv_fp8_e4m3(reinterpret_cast<const half&>(u)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void convert_to_fp8(fp8_2_t* v, const uint32_t u) { v[0] = fp8_2_t(reinterpret_cast<const half2&>(u)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void convert_to_fp8(fp8_4_t* v, const uint2 u) { union { uint2 u2; half2 h2[2]; }; u2 = u; reinterpret_cast<fp8_2_t*>(v)[0] = fp8_2_t(h2[0]); reinterpret_cast<fp8_2_t*>(v)[1] = fp8_2_t(h2[1]); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void convert_to_fp8(fp8_8_t* v, const uint4 u) { union { uint4 u4; half2 h2[4]; }; u4 = u; v[0].x = fp8_2_t(h2[0]); v[0].y = fp8_2_t(h2[1]); v[0].z = fp8_2_t(h2[2]); v[0].w = fp8_2_t(h2[3]); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void convert_to_fp8(__nv_fp8_e4m3* v, const float u) { v[0] = __nv_fp8_e4m3(u); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void convert_to_fp8(fp8_2_t* v, const float2 u) { v[0] = fp8_2_t(u); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void convert_to_fp8(fp8_4_t* v, const float4 u) { v[0] = fp8_4_t(u); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void convert_to_fp8(fp8_8_t* v, const Float8_ u) { v[0].x = fp8_2_t(u.x); v[0].y = fp8_2_t(u.y); v[0].z = fp8_2_t(u.z); v[0].w = fp8_2_t(u.w); } #endif // ENABLE_FP8