maga_transformer/cpp/kernels/_fma.h (525 lines of code) (raw):

#pragma once //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float fma(float a, float b, float c) { return a * b + c; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 fma(float2 a, float2 b, float2 c) { float2 d; d.x = fma(a.x, b.x, c.x); d.y = fma(a.y, b.y, c.y); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 fma(float a, float2 b, float2 c) { float2 d; d.x = fma(a, b.x, c.x); d.y = fma(a, b.y, c.y); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float4 fma(float4 a, float4 b, float4 c) { float4 d; d.x = fma(a.x, b.x, c.x); d.y = fma(a.y, b.y, c.y); d.z = fma(a.z, b.z, c.z); d.w = fma(a.w, b.w, c.w); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float4_ fma(float4 a, Float4_ fb, Float4_ fc) { Float4_ fa, fd; fa = reinterpret_cast<Float4_&>(a); fd.x = fma(fa.x, fb.x, fc.x); fd.y = fma(fa.y, fb.y, fc.y); return fd; } ////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(Float8_ a, Float8_ b, Float8_ c) { Float8_ d; d.x = fma(a.x, b.x, c.x); d.y = fma(a.y, b.y, c.y); d.z = fma(a.z, b.z, c.z); d.w = fma(a.w, b.w, c.w); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float4 fma(float a, float4 b, float4 c) { float4 d; d.x = fma(a, b.x, c.x); d.y = fma(a, b.y, c.y); d.z = fma(a, b.z, c.z); d.w = fma(a, b.w, c.w); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float4 fma(float a, float4 b, Float4_ c) { float4 d; d.x = fma(a, b.x, c.x.x); d.y = fma(a, b.y, c.x.y); d.z = fma(a, b.z, c.y.x); d.w = fma(a, b.w, c.y.y); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { Float4_ d; d.x = fma(a, b.x, c.x); d.y = fma(a, b.y, c.y); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { Float8_ d; d.x = fma(a, b.x, c.x); d.y = fma(a, b.y, c.y); d.z = fma(a, b.z, c.z); d.w = fma(a, b.w, c.w); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { return fma(h0_h0(a), b, c); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { uint2 d; d.x = fma(a.x, b.x, c.x); d.y = fma(a.y, b.y, c.y); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { uint32_t s = h0_h0(a); uint2 d; d.x = fma(s, b.x, c.x); d.y = fma(s, b.y, c.y); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { uint4 d; d.x = fma(a.x, b.x, c.x); d.y = fma(a.y, b.y, c.y); d.z = fma(a.z, b.z, c.z); d.w = fma(a.w, b.w, c.w); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { uint32_t s = h0_h0(a); uint4 d; d.x = fma(s, b.x, c.x); d.y = fma(s, b.y, c.y); d.z = fma(s, b.z, c.z); d.w = fma(s, b.w, c.w); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float fma(uint16_t a, uint16_t b, float fc) { float fa = half_to_float(a); float fb = half_to_float(b); return fa * fb + fc; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) { float2 fa = half2_to_float2(a); float2 fb = half2_to_float2(b); return fma(fa, fb, fc); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 fma(float2 fa, uint32_t b, float2 fc) { float2 fb = half2_to_float2(b); return fma(fa, fb, fc); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { return fma(h0_h0(a), b, fc); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) { Float4_ fd; fd.x = fma(a.x, b.x, fc.x); fd.y = fma(a.y, b.y, fc.y); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) { uint32_t s = h0_h0(a); Float4_ fd; fd.x = fma(s, b.x, fc.x); fd.y = fma(s, b.y, fc.y); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) { Float8_ fd; fd.x = fma(a.x, b.x, fc.x); fd.y = fma(a.y, b.y, fc.y); fd.z = fma(a.z, b.z, fc.z); fd.w = fma(a.w, b.w, fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(Float8_ fa, uint4 b, Float8_ fc) { Float8_ fd; fd.x = fma(fa.x, b.x, fc.x); fd.y = fma(fa.y, b.y, fc.y); fd.z = fma(fa.z, b.z, fc.z); fd.w = fma(fa.w, b.w, fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { uint32_t s = h0_h0(a); Float8_ fd; fd.x = fma(s, b.x, fc.x); fd.y = fma(s, b.y, fc.y); fd.z = fma(s, b.z, fc.z); fd.w = fma(s, b.w, fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float fma(uint16_t a, float fb, float fc) { float fa = half_to_float(a); return fa * fb + fc; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 fma(uint32_t a, float2 fb, float2 fc) { float2 fa = half2_to_float2(a); return fma(fa, fb, fc); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 fma(uint16_t a, float2 fb, float2 fc) { float fa = half_to_float(a); float2 fd; fd.x = fma(fa, fb.x, fc.x); fd.y = fma(fa, fb.y, fc.y); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float4_ fma(uint2 a, Float4_ fb, Float4_ fc) { Float4_ fd; fd.x = fma(a.x, fb.x, fc.x); fd.y = fma(a.y, fb.y, fc.y); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float4_ fma(uint16_t a, Float4_ fb, Float4_ fc) { Float4_ fd; fd.x = fma(a, fb.x, fc.x); fd.y = fma(a, fb.y, fc.y); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(uint4 a, Float8_ fb, Float8_ fc) { Float8_ fd; fd.x = fma(a.x, fb.x, fc.x); fd.y = fma(a.y, fb.y, fc.y); fd.z = fma(a.z, fb.z, fc.z); fd.w = fma(a.w, fb.w, fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(uint16_t a, Float8_ fb, Float8_ fc) { Float8_ fd; fd.x = fma(a, fb.x, fc.x); fd.y = fma(a, fb.y, fc.y); fd.z = fma(a, fb.z, fc.z); fd.w = fma(a, fb.w, fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_BF16 inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { return bf16hfma2(a, b, c); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { return bf16hfma2(bf162bf162(a), b, c); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { bf16_4_t d; d.x = fma(a.x, b.x, c.x); d.y = fma(a.y, b.y, c.y); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) { __nv_bfloat162 s = bf162bf162(a); bf16_4_t d; d.x = fma(s, b.x, c.x); d.y = fma(s, b.y, c.y); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) { bf16_8_t d; d.x = fma(a.x, b.x, c.x); d.y = fma(a.y, b.y, c.y); d.z = fma(a.z, b.z, c.z); d.w = fma(a.w, b.w, c.w); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) { __nv_bfloat162 s = bf162bf162(a); bf16_8_t d; d.x = fma(s, b.x, c.x); d.y = fma(s, b.y, c.y); d.z = fma(s, b.z, c.z); d.w = fma(s, b.w, c.w); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) { return __bfloat162float(a) * __bfloat162float(b) + fc; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) { float2 fa = bf1622float2(a); float2 fb = bf1622float2(b); return fma(fa, fb, fc); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 fma(float2 fa, __nv_bfloat162 b, float2 fc) { float2 fb = bf1622float2(b); return fma(fa, fb, fc); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) { return fma(bf162bf162(a), b, fc); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) { Float4_ fd; fd.x = fma(a.x, b.x, fc.x); fd.y = fma(a.y, b.y, fc.y); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) { __nv_bfloat162 s = bf162bf162(a); Float4_ fd; fd.x = fma(s, b.x, fc.x); fd.y = fma(s, b.y, fc.y); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) { Float8_ fd; fd.x = fma(a.x, b.x, fc.x); fd.y = fma(a.y, b.y, fc.y); fd.z = fma(a.z, b.z, fc.z); fd.w = fma(a.w, b.w, fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(Float8_ fa, bf16_8_t b, Float8_ fc) { Float8_ fd; fd.x = fma(fa.x, b.x, fc.x); fd.y = fma(fa.y, b.y, fc.y); fd.z = fma(fa.z, b.z, fc.z); fd.w = fma(fa.w, b.w, fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { __nv_bfloat162 s = bf162bf162(a); Float8_ fd; fd.x = fma(s, b.x, fc.x); fd.y = fma(s, b.y, fc.y); fd.z = fma(s, b.z, fc.z); fd.w = fma(s, b.w, fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float fma(__nv_bfloat16 a, float fb, float fc) { float fa = __bfloat162float(a); return fa * fb + fc; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 fma(__nv_bfloat162 a, float2 fb, float2 fc) { float2 fa = bf1622float2(a); return fma(fa, fb, fc); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 fma(__nv_bfloat16 a, float2 fb, float2 fc) { float fa = __bfloat162float(a); return fma(fa, fb, fc); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float4_ fma(bf16_4_t a, Float4_ fb, Float4_ fc) { Float4_ fd; fd.x = fma(a.x, fb.x, fc.x); fd.y = fma(a.y, fb.y, fc.y); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float4_ fma(__nv_bfloat16 a, Float4_ fb, Float4_ fc) { Float4_ fd; fd.x = fma(a, fb.x, fc.x); fd.y = fma(a, fb.y, fc.y); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(bf16_8_t a, Float8_ fb, Float8_ fc) { Float8_ fd; fd.x = fma(a.x, fb.x, fc.x); fd.y = fma(a.y, fb.y, fc.y); fd.z = fma(a.z, fb.z, fc.z); fd.w = fma(a.w, fb.w, fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(__nv_bfloat16 a, Float8_ fb, Float8_ fc) { Float8_ fd; fd.x = fma(a, fb.x, fc.x); fd.y = fma(a, fb.y, fc.y); fd.z = fma(a, fb.z, fc.z); fd.w = fma(a, fb.w, fc.w); return fd; } #endif // ENABLE_BF16 #ifdef ENABLE_FP8 //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float4 fma(float4 a, fp8_4_t b, float4 fc) { float4 fd; union { fp8_4_t fp8_4; fp8_2_t fp8_2[2]; }; fp8_4 = b; float2 fb0 = float2(fp8_2[0]); float2 fb1 = float2(fp8_2[1]); fd.x = fma(a.x, fb0.x, fc.x); fd.y = fma(a.y, fb0.y, fc.y); fd.z = fma(a.z, fb1.x, fc.z); fd.w = fma(a.w, fb1.y, fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float4 fma(float a, fp8_4_t b, float4 fc) { float4 fd; union { fp8_4_t fp8_4; fp8_2_t fp8_2[2]; }; fp8_4 = b; float2 fb0 = float2(fp8_2[0]); float2 fb1 = float2(fp8_2[1]); fd.x = fma(a, fb0.x, fc.x); fd.y = fma(a, fb0.y, fc.y); fd.z = fma(a, fb1.x, fc.z); fd.w = fma(a, fb1.y, fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float4_ fma(float4 a, fp8_4_t b, Float4_ fc) { float4 fd; fd = fma(a, b, reinterpret_cast<float4&>(fc)); return reinterpret_cast<Float4_&>(fd); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(uint4 a, fp8_8_t b, Float8_ fc) { Float8_ fd; union { fp8_8_t fp8_8; fp8_2_t fp8_2[4]; }; fp8_8 = b; fd.x = fma(a.x, float2(fp8_2[0]), fc.x); fd.y = fma(a.y, float2(fp8_2[1]), fc.y); fd.z = fma(a.z, float2(fp8_2[2]), fc.z); fd.w = fma(a.w, float2(fp8_2[3]), fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(Float8_ fa, fp8_8_t b, Float8_ fc) { Float8_ fd; union { fp8_8_t fp8_8; fp8_2_t fp8_2[4]; }; fp8_8 = b; fd.x = fma(fa.x, float2(fp8_2[0]), fc.x); fd.y = fma(fa.y, float2(fp8_2[1]), fc.y); fd.z = fma(fa.z, float2(fp8_2[2]), fc.z); fd.w = fma(fa.w, float2(fp8_2[3]), fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(float a, fp8_8_t b, Float8_ fc) { Float8_ fd; union { fp8_8_t fp8_8; fp8_2_t fp8_2[4]; }; fp8_8 = b; fd.x = fma(a, float2(fp8_2[0]), fc.x); fd.y = fma(a, float2(fp8_2[1]), fc.y); fd.z = fma(a, float2(fp8_2[2]), fc.z); fd.w = fma(a, float2(fp8_2[3]), fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(uint16_t a, fp8_8_t b, Float8_ fc) { return fma(half_to_float(a), b, fc); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(bf16_8_t a, fp8_8_t b, Float8_ fc) { Float8_ fd; union { fp8_8_t fp8_8; fp8_2_t fp8_2[4]; }; fp8_8 = b; fd.x = fma(a.x, float2(fp8_2[0]), fc.x); fd.y = fma(a.y, float2(fp8_2[1]), fc.y); fd.z = fma(a.z, float2(fp8_2[2]), fc.z); fd.w = fma(a.w, float2(fp8_2[3]), fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(__nv_bfloat16 a, fp8_8_t b, Float8_ fc) { return fma(__bfloat162float(a), b, fc); } #endif // ENABLE_FP8 //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float4 fma(float4 a, int32_t b, float4 fc) { float4 fd; union { int32_t int32; ; int8_t int8[4]; }; int32 = b; fd.x = fma(a.x, int8[0], fc.x); fd.y = fma(a.y, int8[1], fc.y); fd.z = fma(a.z, int8[2], fc.z); fd.w = fma(a.w, int8[3], fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float4_ fma(float4 a, int32_t b, Float4_ fc) { float4 fd; fd = fma(a, b, reinterpret_cast<float4&>(fc)); return reinterpret_cast<Float4_&>(fd); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float4 fma(float a, int32_t b, float4 fc) { float4 fd; union { int32_t int32; ; int8_t int8[4]; }; int32 = b; fd.x = fma(a, int8[0], fc.x); fd.y = fma(a, int8[1], fc.y); fd.z = fma(a, int8[2], fc.z); fd.w = fma(a, int8[3], fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(uint4 a, int64_t b, Float8_ fc) { Float8_ fd; union { int64_t int64; int8_t int8[8]; }; int64 = b; fd.x = fma(a.x, make_float2(int8[0], int8[1]), fc.x); fd.y = fma(a.y, make_float2(int8[2], int8[3]), fc.y); fd.z = fma(a.z, make_float2(int8[4], int8[5]), fc.z); fd.w = fma(a.w, make_float2(int8[6], int8[7]), fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(Float8_ fa, int64_t b, Float8_ fc) { Float8_ fd; union { int64_t int64; int8_t int8[8]; }; int64 = b; fd.x = fma(fa.x, make_float2(int8[0], int8[1]), fc.x); fd.y = fma(fa.y, make_float2(int8[2], int8[3]), fc.y); fd.z = fma(fa.z, make_float2(int8[4], int8[5]), fc.z); fd.w = fma(fa.w, make_float2(int8[6], int8[7]), fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(float a, int64_t b, Float8_ fc) { Float8_ fd; float2 fa = make_float2(a, a); union { int64_t int64; int8_t int8[8]; }; int64 = b; fd.x = fma(fa, make_float2(int8[0], int8[1]), fc.x); fd.y = fma(fa, make_float2(int8[2], int8[3]), fc.y); fd.z = fma(fa, make_float2(int8[4], int8[5]), fc.z); fd.w = fma(fa, make_float2(int8[6], int8[7]), fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(uint16_t a, int64_t b, Float8_ fc) { return fma(half_to_float(a), b, fc); } //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_BF16 inline __device__ Float8_ fma(bf16_8_t a, int64_t b, Float8_ fc) { Float8_ fd; union { int64_t int64; int8_t int8[8]; }; int64 = b; fd.x = fma(a.x, make_float2(int8[0], int8[1]), fc.x); fd.y = fma(a.y, make_float2(int8[2], int8[3]), fc.y); fd.z = fma(a.z, make_float2(int8[4], int8[5]), fc.z); fd.w = fma(a.w, make_float2(int8[6], int8[7]), fc.w); return fd; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ fma(__nv_bfloat16 a, int64_t b, Float8_ fc) { return fma(__bfloat162float(a), b, fc); } #endif