maga_transformer/cpp/kernels/_sum_dot_zero.h (80 lines of code) (raw):

#pragma once //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float sum(float v) { return v; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float sum(float2 v) { return v.x + v.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float sum(Float4_ v) { return v.x.x + v.x.y + v.y.x + v.y.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float sum(Float8_ v) { float out = 0.f; out += sum(v.x); out += sum(v.y); out += sum(v.z); out += sum(v.w); return out; } //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_BF16 inline __device__ float sum(__nv_bfloat162 v) { float2 vf = bf1622float2(v); return vf.x + vf.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float sum(bf16_4_t v) { return sum(v.x) + sum(v.y); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float sum(bf16_8_t v) { return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); } #endif // ENABLE_BF16 //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float sum(uint16_t v) { return half_to_float(v); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float sum(uint32_t v) { float2 tmp = half2_to_float2(v); return tmp.x + tmp.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float sum(uint2 v) { uint32_t c = add(v.x, v.y); return sum(c); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float sum(uint4 v) { #if 1 uint32_t c = add(v.x, v.y); c = add(c, v.z); c = add(c, v.w); #else uint32_t c = add(v.x, v.y); uint32_t d = add(v.z, v.w); c = add(c, d); #endif return sum(c); } //////////////////////////////////////////////////////////////////////////////////////////////////// template<typename T> inline __device__ float dot(T a, T b) { return sum(mul<T, T, T>(a, b)); } //////////////////////////////////////////////////////////////////////////////////////////////////// template<typename A, typename T> inline __device__ float dot(T a, T b) { return sum(mul<A, T, T>(a, b)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } //////////////////////////////////////////////////////////////////////////////////////////////////// template<typename T> inline __device__ void zero(T& dst) { constexpr int WORDS = sizeof(T) / 4; union { T raw; uint32_t words[WORDS]; } tmp; #pragma unroll for (int ii = 0; ii < WORDS; ++ii) { tmp.words[ii] = 0u; } dst = tmp.raw; }