maga_transformer/cpp/kernels/_sum_dot_zero.h (80 lines of code) (raw):
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(float v) {
return v;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(float2 v) {
return v.x + v.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(float4 v) {
return v.x + v.y + v.z + v.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(Float4_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(Float8_ v) {
float out = 0.f;
out += sum(v.x);
out += sum(v.y);
out += sum(v.z);
out += sum(v.w);
return out;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ float sum(__nv_bfloat162 v) {
float2 vf = bf1622float2(v);
return vf.x + vf.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(bf16_4_t v) {
return sum(v.x) + sum(v.y);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(bf16_8_t v) {
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint16_t v) {
return half_to_float(v);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint32_t v) {
float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint2 v) {
uint32_t c = add(v.x, v.y);
return sum(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint4 v) {
#if 1
uint32_t c = add(v.x, v.y);
c = add(c, v.z);
c = add(c, v.w);
#else
uint32_t c = add(v.x, v.y);
uint32_t d = add(v.z, v.w);
c = add(c, d);
#endif
return sum(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<T, T, T>(a, b));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename A, typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<A, T, T>(a, b));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void zero(uint16_t& dst) {
dst = uint16_t(0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ void zero(T& dst) {
constexpr int WORDS = sizeof(T) / 4;
union {
T raw;
uint32_t words[WORDS];
} tmp;
#pragma unroll
for (int ii = 0; ii < WORDS; ++ii) {
tmp.words[ii] = 0u;
}
dst = tmp.raw;
}