maga_transformer/cpp/kernels/_convert_to_float.h (75 lines of code) (raw):
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename A>
inline __device__ typename packed_type<float, num_elems<A>::value>::type convert_to_float(A u) {
return {};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float4 convert_to_float(float4 u) {
return u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 convert_to_float(float2 u) {
return u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float convert_to_float(float u) {
return u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ convert_to_float(uint4 u) {
Float8_ f8;
f8.x = half2_to_float2(u.x);
f8.y = half2_to_float2(u.y);
f8.z = half2_to_float2(u.z);
f8.w = half2_to_float2(u.w);
return f8;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float4 convert_to_float(uint2 u) {
float4 ret;
float2 f2x = half2_to_float2(u.x);
float2 f2y = half2_to_float2(u.y);
ret.x = f2x.x;
ret.y = f2x.y;
ret.z = f2y.x;
ret.w = f2y.y;
return ret;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 convert_to_float(uint32_t u) {
return half2_to_float2(u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float convert_to_float(half u) {
return static_cast<float>(u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
template<>
inline __device__ float convert_to_float(__nv_bfloat16 u) {
return static_cast<float>(u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 convert_to_float(__nv_bfloat162 u) {
return bf1622float2(u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float4 convert_to_float(bf16_4_t u) {
float4 ret;
float2 f2x = bf1622float2(u.x);
float2 f2y = bf1622float2(u.y);
ret.x = f2x.x;
ret.y = f2x.y;
ret.z = f2y.x;
ret.w = f2y.y;
return ret;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ convert_to_float(bf16_8_t u) {
Float8_ f8;
f8.x = bf1622float2(u.x);
f8.y = bf1622float2(u.y);
f8.z = bf1622float2(u.z);
f8.w = bf1622float2(u.w);
return f8;
}
#endif